diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index c3682329cd68..996d022d0c2f 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -139,8 +139,8 @@ object desugar { * def x: Int = expr * def x_=($1: ): Unit = () */ - def valDef(vdef: ValDef)(implicit ctx: Context): Tree = { - val ValDef(name, tpt, rhs) = vdef + def valDef(vdef0: ValDef)(implicit ctx: Context): Tree = { + val vdef @ ValDef(name, tpt, rhs) = transformQuotedPatternName(vdef0) val mods = vdef.mods val setterNeeded = (mods is Mutable) && ctx.owner.isClass && (!(mods is PrivateLocal) || (ctx.owner is Trait)) @@ -197,8 +197,8 @@ object desugar { * ==> * inline def f(x: Boolean): Any = (if (x) 1 else ""): Any */ - private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(implicit ctx: Context): Tree = { - val DefDef(_, tparams, vparamss, tpt, rhs) = meth + private def defDef(meth0: DefDef, isPrimaryConstructor: Boolean = false)(implicit ctx: Context): Tree = { + val meth @ DefDef(_, tparams, vparamss, tpt, rhs) = transformQuotedPatternName(meth0) val methName = normalizeName(meth, tpt).asTermName val mods = meth.mods val epbuf = new ListBuffer[ValDef] @@ -272,6 +272,32 @@ object desugar { } } + /** Transforms a definition with a name starting with a `$` in a quoted pattern into a `quoted.binding.Binding` splice. + * + * The desugaring consists in renaming the the definition and adding the `@patternBindHole` annotation. This + * annotation is used during typing to perform the full transformation. + * + * A definition + * ```scala + * case '{ def $a(...) = ... a() ...; ... a() ... } + * ``` + * into + * ```scala + * case '{ @patternBindHole def a(...) = ... a() ...; ... a() ... } + * ``` + */ + def transformQuotedPatternName(tree: ValOrDefDef)(implicit ctx: Context): ValOrDefDef = { + if (ctx.mode.is(Mode.QuotedPattern) && !tree.isBackquoted && tree.name != nme.ANON_FUN && tree.name.startsWith("$")) { + val name = tree.name.toString.substring(1).toTermName + val newTree: ValOrDefDef = tree match { + case tree: ValDef => cpy.ValDef(tree)(name) + case tree: DefDef => cpy.DefDef(tree)(name) + } + val mods = tree.mods.withAddedAnnotation(New(ref(defn.InternalQuoted_patternBindHoleAnnot.typeRef)).withSpan(tree.span)) + newTree.withMods(mods) + } else tree + } + // Add all evidence parameters in `params` as implicit parameters to `meth` */ private def addEvidenceParams(meth: DefDef, params: List[ValDef])(implicit ctx: Context): DefDef = params match { diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 65e604d8fce0..04a336075ad0 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -357,10 +357,14 @@ object Trees { /** A ValDef or DefDef tree */ abstract class ValOrDefDef[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends MemberDef[T] with WithLazyField[Tree[T]] { + type ThisTree[-T >: Untyped] <: ValOrDefDef[T] def name: TermName def tpt: Tree[T] def unforcedRhs: LazyTree = unforced def rhs(implicit ctx: Context): Tree[T] = forceIfLazy + + /** Is this a `BackquotedValDef` or `BackquotedDefDef` ? */ + def isBackquoted: Boolean = false } // ----------- Tree case classes ------------------------------------ @@ -706,6 +710,12 @@ object Trees { protected def force(x: AnyRef): Unit = preRhs = x } + class BackquotedValDef[-T >: Untyped] private[ast] (name: TermName, tpt: Tree[T], preRhs: LazyTree)(implicit @constructorOnly src: SourceFile) + extends ValDef[T](name, tpt, preRhs) { + override def isBackquoted: Boolean = true + override def productPrefix: String = "BackquotedValDef" + } + /** mods def name[tparams](vparams_1)...(vparams_n): tpt = rhs */ case class DefDef[-T >: Untyped] private[ast] (name: TermName, tparams: List[TypeDef[T]], vparamss: List[List[ValDef[T]]], tpt: Tree[T], private var preRhs: LazyTree)(implicit @constructorOnly src: SourceFile) @@ -716,6 +726,13 @@ object Trees { protected def force(x: AnyRef): Unit = preRhs = x } + class BackquotedDefDef[-T >: Untyped] private[ast] (name: TermName, tparams: List[TypeDef[T]], + vparamss: List[List[ValDef[T]]], tpt: Tree[T], preRhs: LazyTree)(implicit @constructorOnly src: SourceFile) + extends DefDef[T](name, tparams, vparamss, tpt, preRhs) { + override def isBackquoted: Boolean = true + override def productPrefix: String = "BackquotedDefDef" + } + /** mods class name template or * mods trait name template or * mods type name = rhs or @@ -932,7 +949,9 @@ object Trees { type Alternative = Trees.Alternative[T] type UnApply = Trees.UnApply[T] type ValDef = Trees.ValDef[T] + type BackquotedValDef = Trees.BackquotedValDef[T] type DefDef = Trees.DefDef[T] + type BackquotedDefDef = Trees.BackquotedDefDef[T] type TypeDef = Trees.TypeDef[T] type Template = Trees.Template[T] type Import = Trees.Import[T] @@ -1125,10 +1144,16 @@ object Trees { case _ => finalize(tree, untpd.UnApply(fun, implicits, patterns)(sourceFile(tree))) } def ValDef(tree: Tree)(name: TermName, tpt: Tree, rhs: LazyTree)(implicit ctx: Context): ValDef = tree match { + case tree: BackquotedValDef => + if ((name == tree.name) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs)) tree + else finalize(tree, untpd.BackquotedValDef(name, tpt, rhs)(sourceFile(tree))) case tree: ValDef if (name == tree.name) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs) => tree case _ => finalize(tree, untpd.ValDef(name, tpt, rhs)(sourceFile(tree))) } def DefDef(tree: Tree)(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit ctx: Context): DefDef = tree match { + case tree: BackquotedDefDef => + if ((name == tree.name) && (tparams eq tree.tparams) && (vparamss eq tree.vparamss) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs)) tree + else finalize(tree, untpd.BackquotedDefDef(name, tparams, vparamss, tpt, rhs)(sourceFile(tree))) case tree: DefDef if (name == tree.name) && (tparams eq tree.tparams) && (vparamss eq tree.vparamss) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs) => tree case _ => finalize(tree, untpd.DefDef(name, tparams, vparamss, tpt, rhs)(sourceFile(tree))) } diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 55a1cbfeff73..94ca12203539 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -321,7 +321,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def Alternative(trees: List[Tree])(implicit src: SourceFile): Alternative = new Alternative(trees) def UnApply(fun: Tree, implicits: List[Tree], patterns: List[Tree])(implicit src: SourceFile): UnApply = new UnApply(fun, implicits, patterns) def ValDef(name: TermName, tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): ValDef = new ValDef(name, tpt, rhs) + def BackquotedValDef(name: TermName, tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): ValDef = new BackquotedValDef(name, tpt, rhs) def DefDef(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): DefDef = new DefDef(name, tparams, vparamss, tpt, rhs) + def BackquotedDefDef(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): DefDef = new BackquotedDefDef(name, tparams, vparamss, tpt, rhs) def TypeDef(name: TypeName, rhs: Tree)(implicit src: SourceFile): TypeDef = new TypeDef(name, rhs) def Template(constr: DefDef, parents: List[Tree], derived: List[Tree], self: ValDef, body: LazyTreeList)(implicit src: SourceFile): Template = if (derived.isEmpty) new Template(constr, parents, self, body) @@ -406,8 +408,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def makeAndType(left: Tree, right: Tree)(implicit ctx: Context): AppliedTypeTree = AppliedTypeTree(ref(defn.andType.typeRef), left :: right :: Nil) - def makeParameter(pname: TermName, tpe: Tree, mods: Modifiers = EmptyModifiers)(implicit ctx: Context): ValDef = - ValDef(pname, tpe, EmptyTree).withMods(mods | Param) + def makeParameter(pname: TermName, tpe: Tree, mods: Modifiers = EmptyModifiers, isBackquoted: Boolean = false)(implicit ctx: Context): ValDef = { + val vdef = + if (isBackquoted) BackquotedValDef(pname, tpe, EmptyTree) + else ValDef(pname, tpe, EmptyTree) + vdef.withMods(mods | Param) + } def makeSyntheticParameter(n: Int = 1, tpt: Tree = null, flags: FlagSet = EmptyFlags)(implicit ctx: Context): ValDef = ValDef(nme.syntheticParamName(n), if (tpt == null) TypeTree() else tpt, EmptyTree) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 973e025d9c0f..a08676b0a596 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -722,6 +722,11 @@ class Definitions { def InternalQuoted_typeQuote(implicit ctx: Context): Symbol = InternalQuoted_typeQuoteR.symbol lazy val InternalQuoted_patternHoleR: TermRef = InternalQuotedModule.requiredMethodRef("patternHole") def InternalQuoted_patternHole(implicit ctx: Context): Symbol = InternalQuoted_patternHoleR.symbol + lazy val InternalQuoted_patternBindHoleAnnot: ClassSymbol = InternalQuotedModule.requiredClass("patternBindHole") + lazy val InternalQuoted_patternMatchBindHoleModuleR: TermRef = InternalQuotedModule.requiredValueRef("patternMatchBindHole".toTermName) + def InternalQuoted_patternMatchBindHoleModule: Symbol = InternalQuoted_patternMatchBindHoleModuleR.symbol + lazy val InternalQuoted_patternMatchBindHole_unapplyR: TermRef = InternalQuoted_patternMatchBindHoleModule.requiredMethodRef("unapply") + def InternalQuoted_patternMatchBindHole_unapply(implicit ctx: Context): Symbol = InternalQuoted_patternMatchBindHole_unapplyR.symbol lazy val InternalQuotedMatcherModuleRef: TermRef = ctx.requiredModuleRef("scala.internal.quoted.Matcher") def InternalQuotedMatcherModule(implicit ctx: Context): Symbol = InternalQuotedMatcherModuleRef.symbol @@ -741,6 +746,9 @@ class Definitions { lazy val QuotedTypeModuleRef: TermRef = ctx.requiredModuleRef("scala.quoted.Type") def QuotedTypeModule(implicit ctx: Context): Symbol = QuotedTypeModuleRef.symbol + lazy val QuotedMatchingBindingType: TypeRef = ctx.requiredClassRef("scala.quoted.matching.Bind") + def QuotedMatchingBindingClass(implicit ctx: Context): ClassSymbol = QuotedMatchingBindingType.symbol.asClass + def Unpickler_unpickleExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleExpr") def Unpickler_liftedExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.liftedExpr") def Unpickler_unpickleType: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleType") diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 143d14745554..c13f0924b0b4 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -406,10 +406,12 @@ object Parsers { /** Convert tree to formal parameter */ def convertToParam(tree: Tree, expected: String = "formal parameter"): ValDef = tree match { - case Ident(name) => - makeParameter(name.asTermName, TypeTree()).withSpan(tree.span) - case Typed(Ident(name), tpt) => - makeParameter(name.asTermName, tpt).withSpan(tree.span) + case id @ Ident(name) => + makeParameter(name.asTermName, TypeTree(), isBackquoted = id.isBackquoted).withSpan(tree.span) + case Typed(id @ Ident(name), tpt) => + makeParameter(name.asTermName, tpt, isBackquoted = id.isBackquoted).withSpan(tree.span) + case Typed(Splice(Ident(name)), tpt) => + makeParameter(("$" + name).toTermName, tpt).withSpan(tree.span) case _ => syntaxError(s"not a legal $expected", tree.span) makeParameter(nme.ERROR, tree) @@ -2370,7 +2372,9 @@ object Parsers { } } else EmptyTree lhs match { - case (id @ Ident(name: TermName)) :: Nil => { + case (id: BackquotedIdent) :: Nil if id.name.isTermName => + finalizeDef(BackquotedValDef(id.name.asTermName, tpt, rhs), mods, start) + case Ident(name: TermName) :: Nil => { finalizeDef(ValDef(name, tpt, rhs), mods, start) } case _ => PatDef(mods, lhs, tpt, rhs) @@ -2414,10 +2418,10 @@ object Parsers { else (Nil, Method) val mods1 = addFlag(mods, flags) - val name = ident() + val ident = termIdent() val tparams = typeParamClauseOpt(ParamOwner.Def) val vparamss = paramClauses() match { - case rparams :: rparamss if leadingParamss.nonEmpty && !isLeftAssoc(name) => + case rparams :: rparamss if leadingParamss.nonEmpty && !isLeftAssoc(ident.name) => rparams :: leadingParamss ::: rparamss case rparamss => leadingParamss ::: rparamss @@ -2447,7 +2451,9 @@ object Parsers { accept(EQUALS) expr() } - finalizeDef(DefDef(name, tparams, vparamss, tpt, rhs), mods1, start) + + if (ident.isBackquoted) finalizeDef(BackquotedDefDef(ident.name.asTermName, tparams, vparamss, tpt, rhs), mods1, start) + else finalizeDef(DefDef(ident.name.asTermName, tparams, vparamss, tpt, rhs), mods1, start) } } diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala b/compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala index 084ae22a93f6..9d3f80e4dcb1 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala @@ -1810,6 +1810,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util. def Definitions_TupleClass(arity: Int): Symbol = defn.TupleType(arity).classSymbol.asClass def Definitions_InternalQuoted_patternHole: Symbol = defn.InternalQuoted_patternHole + def Definitions_InternalQuoted_patternBindHoleAnnot: Symbol = defn.InternalQuoted_patternBindHoleAnnot // Types diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index baaf81b32eb3..d8983de80b71 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1013,7 +1013,17 @@ trait Applications extends Compatibility { self: Typer with Dynamic => tree } - def typedUnApply(tree: untpd.Apply, selType: Type)(implicit ctx: Context): Tree = track("typedUnApply") { + def typedUnApply(tree0: untpd.Apply, selType: Type)(implicit ctx: Context): Tree = track("typedUnApply") { + val tree = + if (ctx.mode.is(Mode.QuotedPattern)) { // TODO move to desugar + val Apply(qual0, args0) = tree0 + val args1 = args0 map { + case arg: untpd.Ident if arg.name.startsWith("$") => + untpd.Apply(untpd.ref(defn.InternalQuoted_patternMatchBindHoleModuleR), untpd.Ident(arg.name.toString.substring(1).toTermName) :: Nil) + case arg => arg + } + untpd.cpy.Apply(tree0)(qual0, args1) + } else tree0 val Apply(qual, args) = tree def notAnExtractor(tree: Tree) = diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 148ba097b2af..a70c6727a12a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1416,6 +1416,10 @@ class Typer extends Namer } def typedBind(tree: untpd.Bind, pt: Type)(implicit ctx: Context): Tree = track("typedBind") { + if (ctx.mode.is(Mode.QuotedPattern) && tree.name.startsWith("$")) { + val bind1 = untpd.cpy.Bind(tree)(tree.name.toString.substring(1).toTermName, tree.body) + return typed(untpd.Apply(untpd.ref(defn.InternalQuoted_patternMatchBindHoleModuleR), bind1 :: Nil).withSpan(tree.span), pt) + } val pt1 = fullyDefinedType(pt, "pattern variable", tree.span) val body1 = typed(tree.body, pt1) body1 match { @@ -1959,6 +1963,14 @@ class Typer extends Namer } def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (Tree, List[Tree]) = { + val ctx0 = ctx + + def bindExpr(name: Name, tpe: Type, span: Span): Tree = { + val exprTpe = AppliedType(defn.QuotedMatchingBindingType, tpe :: Nil) + val sym = ctx0.newPatternBoundSymbol(name, exprTpe, span) + Bind(sym, untpd.Ident(nme.WILDCARD).withType(exprTpe)).withSpan(span) + } + object splitter extends tpd.TreeMap { val patBuf = new mutable.ListBuffer[Tree] override def transform(tree: Tree)(implicit ctx: Context) = tree match { @@ -1973,6 +1985,25 @@ class Typer extends Namer val pat1 = if (patType eq patType1) pat else pat.withType(patType1) patBuf += pat1 } + case ddef: ValOrDefDef => + if (ddef.symbol.annotations.exists(_.symbol == defn.InternalQuoted_patternBindHoleAnnot)) { + val tpe = ddef.symbol.info match { + case t: ExprType => t.resType + case t: MethodType => t.toFunctionType() + case t: PolyType => + HKTypeLambda(t.paramNames)( + x => t.paramInfos.mapConserve(_.subst(t, x).asInstanceOf[TypeBounds]), + x => t.resType.subst(t, x).toFunctionType()) + case t => t + } + val exprTpe = AppliedType(defn.QuotedMatchingBindingType, tpe :: Nil) + val sym = ctx0.newPatternBoundSymbol(ddef.name, exprTpe, ddef.span) + patBuf += Bind(sym, untpd.Ident(nme.WILDCARD).withType(exprTpe)).withSpan(ddef.span) + } + super.transform(tree) + case tree @ UnApply(_, _, (bind: Bind) :: Nil) if tree.fun.symbol == defn.InternalQuoted_patternMatchBindHole_unapply => + patBuf += bindExpr(bind.name, bind.tpe.widen, bind.span) + cpy.UnApply(tree)(patterns = untpd.Ident(nme.WILDCARD).withType(bind.tpe.widen) :: Nil) case _ => super.transform(tree) } diff --git a/library/src-bootstrapped/scala/internal/Quoted.scala b/library/src-bootstrapped/scala/internal/Quoted.scala index 4e6122e7006d..025c89837490 100644 --- a/library/src-bootstrapped/scala/internal/Quoted.scala +++ b/library/src-bootstrapped/scala/internal/Quoted.scala @@ -1,5 +1,6 @@ package scala.internal +import scala.annotation.Annotation import scala.quoted._ object Quoted { @@ -19,4 +20,14 @@ object Quoted { /** A splice in a quoted pattern is desugared by the compiler into a call to this method */ def patternHole[T]: T = throw new Error("Internal error: this method call should have been replaced by the compiler") + + /** A splice of a name in a quoted pattern is desugared by adding this annotation */ + class patternBindHole extends Annotation + + /** A splice of a name in a quoted pattern in pattern position is desugared by wrapping it in this extractor */ + object patternMatchBindHole { + def unapply(x: Any): Some[x.type] = + throw new Error("Internal error: this method call should have been replaced by the compiler") + } + } diff --git a/library/src-bootstrapped/scala/internal/quoted/Matcher.scala b/library/src-bootstrapped/scala/internal/quoted/Matcher.scala index 355101b0851e..fc29ef501423 100644 --- a/library/src-bootstrapped/scala/internal/quoted/Matcher.scala +++ b/library/src-bootstrapped/scala/internal/quoted/Matcher.scala @@ -3,6 +3,7 @@ package scala.internal.quoted import scala.annotation.internal.sharable import scala.quoted._ +import scala.quoted.matching.Bind import scala.tasty._ object Matcher { @@ -30,9 +31,14 @@ object Matcher { * @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]`` */ def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = { - import reflection._ + import reflection.{Bind => BindPattern, _} + // TODO improve performance + /** Create a new matching with the resulting binding for the symbol */ + def bindingMatched(sym: Symbol) = + Some(Tuple1(new Binding(sym.name, sym))) + /** Check that the trees match and return the contents from the pattern holes. * Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes. * @@ -51,6 +57,18 @@ object Matcher { sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable) } + def bindingMatch(sym: Symbol) = + Some(Tuple1(new Bind(sym.name, sym))) + + def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match { + case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), ""), Nil)) => true + case Annotated(tpt2, _) => hasBindTypeAnnotation(tpt2) + case _ => false + } + + def hasBindAnnotation(sym: Symbol) = + sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),""),List()) => true; case _ => true } + def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] = if (scrutinees.size != patterns.size) None else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*) @@ -142,24 +160,30 @@ object Matcher { foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2)) case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() => + val bindMatch = + if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol) + else Some(()) val returnTptMatch = treeMatches(tpt1, tpt2) val rhsEnv = env + (scrutinee.symbol -> pattern.symbol) val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv) - foldMatchings(returnTptMatch, rhsMatchings) + foldMatchings(bindMatch, returnTptMatch, rhsMatchings) case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) => val typeParmasMatch = treesMatch(typeParams1, typeParams2) val paramssMatch = if (paramss1.size != paramss2.size) None else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*) + val bindMatch = + if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol) + else Some(()) val tptMatch = treeMatches(tpt1, tpt2) val rhsEnv = env + (scrutinee.symbol -> pattern.symbol) ++ - typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++ - paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol) + typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++ + paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol) val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv) - foldMatchings(typeParmasMatch, paramssMatch, tptMatch, rhsMatch) + foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch) case (Lambda(_, tpt1), Lambda(_, tpt2)) => // TODO match tpt1 with tpt2? @@ -180,6 +204,10 @@ object Matcher { val finalizerMatch = treeOptMatches(finalizer1, finalizer2) foldMatchings(bodyMacth, casesMatch, finalizerMatch) + // Ignore type annotations + case (Annotated(tpt, _), _) => treeMatches(tpt, pattern) + case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt) + // No Match case _ => if (debug) @@ -229,9 +257,19 @@ object Matcher { * `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. */ def patternMatches(scrutinee: Pattern, pattern: Pattern)(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = (scrutinee, pattern) match { - case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil)) - if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" => - (env, Some(Tuple1(v1.seal))) +// case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil)) +// if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" => +// (env, Some(Tuple1(v1.seal))) + + case (Pattern.Bind(name1, pat1), Pattern.Unapply(Select(Ident("patternMatchBindHole"), "unapply"), Nil, List(Pattern.Bind(name2, pat2)))) +// TODO if pattern.symbol == ... => + => + val (env1, patMatch) = patternMatches(pat1, pat2) + (env1, foldMatchings(bindingMatched(scrutinee.symbol), patMatch)) + + case (Pattern.Value(Ident("_")), Pattern.Value(Ident("_"))) => // TODO add Wildcard to patterns + val bindEnv = env + (scrutinee.symbol -> pattern.symbol) + (bindEnv, Some(())) case (Pattern.Value(v1), Pattern.Value(v2)) => (env, treeMatches(v1, v2)) diff --git a/library/src-bootstrapped/scala/quoted/matching/Bind.scala b/library/src-bootstrapped/scala/quoted/matching/Bind.scala new file mode 100644 index 000000000000..d36b8b24f0de --- /dev/null +++ b/library/src-bootstrapped/scala/quoted/matching/Bind.scala @@ -0,0 +1,34 @@ +package scala.quoted +package matching + +import scala.tasty.Reflection // TODO do not depend on reflection directly + +/** Bind of an Expr[T] used to know if some Expr[T] is a reference to the binding + * + * @param name string name of this binding + * @param id unique id used for equality + */ +class Bind[T <: AnyKind] private[scala](val name: String, private[Bind] val id: Object) { self => + + override def equals(obj: Any): Boolean = obj match { + case obj: Bind[_] => obj.id == id + case _ => false + } + + override def hashCode(): Int = id.hashCode() + +} + +object Bind { + + def unapply[T](expr: Expr[T])(implicit reflect: Reflection): Option[Bind[T]] = { + import reflect.{Bind => BindPattern, _} + expr.unseal match { + case IsIdent(ref) => + val sym = ref.symbol + Some(new Bind[T](sym.name, sym)) + case _ => None + } + } + +} diff --git a/library/src-non-bootstrapped/scala/internal/Quoted.scala b/library/src-non-bootstrapped/scala/internal/Quoted.scala index 5c6dcfcd6d10..8255b7463ab0 100644 --- a/library/src-non-bootstrapped/scala/internal/Quoted.scala +++ b/library/src-non-bootstrapped/scala/internal/Quoted.scala @@ -1,5 +1,6 @@ package scala.internal +import scala.annotation.Annotation import scala.quoted._ object Quoted { @@ -16,4 +17,11 @@ object Quoted { def typeQuote[T/* <: AnyKind */]: Type[T] = throw new Error("Internal error: this method call should have been replaced by the compiler") + /** A splice in a quoted pattern is desugared by the compiler into a call to this method */ + def patternHole[T]: T = + throw new Error("Internal error: this method call should have been replaced by the compiler") + + /** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */ + class patternBindHole extends Annotation + } diff --git a/library/src-non-bootstrapped/scala/quoted/matching/Bind.scala b/library/src-non-bootstrapped/scala/quoted/matching/Bind.scala new file mode 100644 index 000000000000..9a81ac587c95 --- /dev/null +++ b/library/src-non-bootstrapped/scala/quoted/matching/Bind.scala @@ -0,0 +1,34 @@ +package scala.quoted +package matching + +import scala.tasty.Reflection // TODO do not depend on reflection directly + +/** Bind of an Expr[T] used to know if some Expr[T] is a reference to the binding + * + * @param name string name of this binding + * @param id unique id used for equality + */ +class Bind[T /*<: AnyKind*/] private[scala](val name: String, private[Bind] val id: Object) { self => + + override def equals(obj: Any): Boolean = obj match { + case obj: Bind[_] => obj.id == id + case _ => false + } + + override def hashCode(): Int = id.hashCode() + +} + +object Bind { + + def unapply[T](expr: Expr[T])(implicit reflect: Reflection): Option[Bind[T]] = { + import reflect.{Bind => BindPattern, _} + expr.unseal match { + case IsIdent(ref) => + val sym = ref.symbol + Some(new Bind[T](sym.name, sym)) + case _ => None + } + } + +} diff --git a/library/src/scala/tasty/reflect/Kernel.scala b/library/src/scala/tasty/reflect/Kernel.scala index d7ca80d08f0f..7d9fc1e08830 100644 --- a/library/src/scala/tasty/reflect/Kernel.scala +++ b/library/src/scala/tasty/reflect/Kernel.scala @@ -1469,9 +1469,12 @@ trait Kernel { def Definitions_TupleClass(arity: Int): Symbol - /** Symbol of scala.runtime.Quoted.patternHole */ + /** Symbol of scala.internal.Quoted.patternHole */ def Definitions_InternalQuoted_patternHole: Symbol + /** Symbol of scala.internal.Quoted.patternBindHole */ + def Definitions_InternalQuoted_patternBindHoleAnnot: Symbol + def Definitions_UnitType: Type def Definitions_ByteType: Type def Definitions_ShortType: Type diff --git a/tests/neg/quotedPatterns-3.scala b/tests/neg/quotedPatterns-3.scala new file mode 100644 index 000000000000..cce479de5055 --- /dev/null +++ b/tests/neg/quotedPatterns-3.scala @@ -0,0 +1,11 @@ +object Test { + def test(x: quoted.Expr[Int]) given tasty.Reflection = x match { + case '{ val `$y`: Int = 2; 1 } => + y // error: Not found: y + case '{ ((`$y`: Int) => 3); 2 } => + y // error: Not found: y + case '{ def `$f`: Int = 8; 2 } => + f // error: Not found: f + case _ => + } +} diff --git a/tests/pos/quotedPatterns.scala b/tests/pos/quotedPatterns.scala index dba426299225..3927b9c6c60f 100644 --- a/tests/pos/quotedPatterns.scala +++ b/tests/pos/quotedPatterns.scala @@ -11,10 +11,33 @@ object Test { case '{g($y, $z)} => '{$y * $z} case '{ ((a: Int) => 3)($y) } => y case '{ 1 + ($y: Int)} => y + case '{ val a = 1 + ($y: Int); 3 } => y // currently gives an unreachable case warning // but only when used in conjunction with the others. // I believe this is because implicit arguments are not taken // into account when checking whether we have already seen an `unapply` before. + case '{ val $y: Int = $z; 1 } => + val a: quoted.matching.Bind[Int] = y + z + case '{ (($y: Int) => 1 + y + ($z: Int))(2) } => + val a: quoted.matching.Bind[Int] = y + z + case '{ def $ff: Int = $z; ff } => + val a: quoted.matching.Bind[Int] = ff + z + case '{ def $ff(i: Int): Int = $z; 2 } => + val a: quoted.matching.Bind[Int => Int] = ff + z + case '{ def $ff(i: Int)(j: Int): Int = $z; 2 } => + val a: quoted.matching.Bind[Int => Int => Int] = ff + z + case '{ def $ff[T](i: T): Int = $z; 2 } => + val a: quoted.matching.Bind[[T] => T => Int] = ff + z + case '{ Option(1) match { case $a @ Some(_) => $z } } => z + case '{ Option(1) match { case $b: Some[_] => $z } } => z + // case '{ Option(1) match { case Some($n @ _) => $z } } => z +// case '{ Option(1) match { case $c => $z } } => z case _ => '{1} } } \ No newline at end of file diff --git a/tests/run-with-compiler/quote-matcher-runtime.check b/tests/run-with-compiler/quote-matcher-runtime.check index 972b113a2074..8f11b5f0a5c0 100644 --- a/tests/run-with-compiler/quote-matcher-runtime.check +++ b/tests/run-with-compiler/quote-matcher-runtime.check @@ -236,6 +236,10 @@ Scrutinee: ((x: scala.Int) => "abc").apply(4) Pattern: scala.internal.Quoted.patternHole[scala.Function1[scala.Int, scala.Predef.String]].apply(4) Result: Some(List(Expr(((x: scala.Int) => "abc")))) +Scrutinee: ((x: scala.Int) => "abc") +Pattern: ((x: scala.Int @scala.internal.Quoted.patternBindHole) => scala.internal.Quoted.patternHole[scala.Predef.String]) +Result: Some(List(Bind(x), Expr("abc"))) + Scrutinee: scala.StringContext.apply(("abc", "xyz": scala.[scala.Predef.String])) Pattern: scala.StringContext.apply(("abc", "xyz": scala.[scala.Predef.String])) Result: Some(List()) @@ -258,6 +262,16 @@ Pattern: { } Result: Some(List()) +Scrutinee: { + val a: scala.Int = 45 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole val a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] + () +} +Result: Some(List(Bind(a), Expr(45))) + Scrutinee: { val a: scala.Int = 45 () @@ -278,6 +292,16 @@ Pattern: { } Result: None +Scrutinee: { + val a: scala.Int = 45 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole var a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] + () +} +Result: None + Scrutinee: { lazy val a: scala.Int = 45 () @@ -308,6 +332,26 @@ Pattern: { } Result: None +Scrutinee: { + lazy val a: scala.Int = 45 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole val a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] + () +} +Result: None + +Scrutinee: { + lazy val a: scala.Int = 45 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole var a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] + () +} +Result: None + Scrutinee: { var a: scala.Int = 45 () @@ -338,6 +382,26 @@ Pattern: { } Result: Some(List()) +Scrutinee: { + var a: scala.Int = 45 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole val a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] + () +} +Result: None + +Scrutinee: { + var a: scala.Int = 45 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole lazy val a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] + () +} +Result: None + Scrutinee: { scala.Predef.println() scala.Predef.println() @@ -398,6 +462,16 @@ Pattern: { } Result: Some(List()) +Scrutinee: { + def a: scala.Int = 45 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole def a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] + () +} +Result: Some(List(Bind(a), Expr(45))) + Scrutinee: { def a(x: scala.Int): scala.Int = 45 () @@ -458,6 +532,26 @@ Pattern: { } Result: Some(List()) +Scrutinee: { + def a(x: scala.Int): scala.Int = 45 + () +} +Pattern: { + def a(x: scala.Int @scala.internal.Quoted.patternBindHole): scala.Int = 45 + () +} +Result: Some(List(Bind(x))) + +Scrutinee: { + def a(x: scala.Int): scala.Int = 45 + () +} +Pattern: { + def a(x: scala.Int @scala.internal.Quoted.patternBindHole): scala.Int = 45 + () +} +Result: Some(List(Bind(x))) + Scrutinee: { def a(x: scala.Int): scala.Int = x () @@ -498,6 +592,16 @@ Pattern: 1 match { } Result: Some(List()) +Scrutinee: 1 match { + case _ => + 2 +} +Pattern: scala.internal.Quoted.patternHole[scala.Int] match { + case _ => + scala.internal.Quoted.patternHole[scala.Int] +} +Result: Some(List(Expr(1), Expr(2))) + Scrutinee: scala.Predef.??? match { case scala.None => 2 @@ -518,6 +622,26 @@ Pattern: scala.Predef.??? match { } Result: Some(List()) +Scrutinee: scala.Predef.??? match { + case scala.Some(n) => + 2 +} +Pattern: scala.Predef.??? match { + case scala.Some(scala.internal.Quoted.patternMatchBindHole(n)) => + 2 +} +Result: Some(List(Binding(n))) + +Scrutinee: scala.Predef.??? match { + case scala.Some(n @ scala.Some(m)) => + 2 +} +Pattern: scala.Predef.??? match { + case scala.Some(scala.internal.Quoted.patternMatchBindHole(n @ scala.Some(scala.internal.Quoted.patternMatchBindHole(m)))) => + 2 +} +Result: Some(List(Binding(n), Binding(m))) + Scrutinee: try 1 catch { case _ => 2 @@ -538,3 +662,23 @@ Pattern: try 1 finally { } Result: Some(List()) +Scrutinee: try 1 catch { + case _ => + 2 +} +Pattern: try scala.internal.Quoted.patternHole[scala.Int] catch { + case _ => + scala.internal.Quoted.patternHole[scala.Int] +} +Result: Some(List(Expr(1), Expr(2))) + +Scrutinee: try 1 finally { + 2 + () +} +Pattern: try scala.internal.Quoted.patternHole[scala.Int] finally { + scala.internal.Quoted.patternHole[scala.Int] + () +} +Result: Some(List(Expr(1), Expr(2))) + diff --git a/tests/run-with-compiler/quote-matcher-runtime/quoted_1.scala b/tests/run-with-compiler/quote-matcher-runtime/quoted_1.scala index 5204b23e35c8..8e2ef4d8a071 100644 --- a/tests/run-with-compiler/quote-matcher-runtime/quoted_1.scala +++ b/tests/run-with-compiler/quote-matcher-runtime/quoted_1.scala @@ -8,7 +8,7 @@ object Macros { inline def matches[A, B](a: => A, b: => B): Unit = ${impl('a, 'b)} private def impl[A, B](a: Expr[A], b: Expr[B])(implicit reflect: Reflection): Expr[Unit] = { - import reflect._ + import reflect.{Bind => _, _} val res = scala.internal.quoted.Matcher.unapply[Tuple](a)(b, reflect).map { tup => tup.toArray.toList.map { @@ -16,6 +16,8 @@ object Macros { s"Expr(${r.unseal.show})" case r: quoted.Type[_] => s"Type(${r.unseal.show})" + case r: Bind[_] => + s"Bind(${r.name})" } } diff --git a/tests/run-with-compiler/quote-matcher-runtime/quoted_2.scala b/tests/run-with-compiler/quote-matcher-runtime/quoted_2.scala index 5caec627a51e..7565dc0600aa 100644 --- a/tests/run-with-compiler/quote-matcher-runtime/quoted_2.scala +++ b/tests/run-with-compiler/quote-matcher-runtime/quoted_2.scala @@ -3,7 +3,7 @@ import Macros._ import scala.internal.quoted.Matcher._ -import scala.internal.Quoted.patternHole +import scala.internal.Quoted.{patternHole, patternBindHole, patternMatchBindHole} object Test { @@ -81,55 +81,54 @@ object Test { matches((() => "abc")(), (patternHole[() => String]).apply()) matches((x: Int) => "abc", patternHole[Int=> String]) matches(((x: Int) => "abc")(4), (patternHole[Int => String]).apply(4)) - // matches((x: Int) => "abc", (x: bindHole[Int]) => patternHole[String]) + matches((x: Int) => "abc", (x: Int @patternBindHole) => patternHole[String]) matches(StringContext("abc", "xyz"), StringContext("abc", "xyz")) matches(StringContext("abc", "xyz"), StringContext(patternHole, patternHole)) matches(StringContext("abc", "xyz"), StringContext(patternHole[Seq[String]]: _*)) matches({ val a: Int = 45 }, { val a: Int = 45 }) - // matches({ val a: Int = 45 }, { val a: bindHole[Int] = patternHole }) + matches({ val a: Int = 45 }, { @patternBindHole val a: Int = patternHole }) matches({ val a: Int = 45 }, { lazy val a: Int = 45 }) matches({ val a: Int = 45 }, { var a: Int = 45 }) - // matches({ val a: Int = 45 }, { var a: bindHole[Int] = patternHole }) + matches({ val a: Int = 45 }, { @patternBindHole var a: Int = patternHole }) matches({ lazy val a: Int = 45 }, { val a: Int = 45 }) matches({ lazy val a: Int = 45 }, { lazy val a: Int = 45 }) matches({ lazy val a: Int = 45 }, { var a: Int = 45 }) - // matches({ lazy val a: Int = 45 }, { val a: bindHole[Int] = patternHole }) - // matches({ lazy val a: Int = 45 }, { var a: bindHole[Int] = patternHole }) + matches({ lazy val a: Int = 45 }, { @patternBindHole val a: Int = patternHole }) + matches({ lazy val a: Int = 45 }, { @patternBindHole var a: Int = patternHole }) matches({ var a: Int = 45 }, { val a: Int = 45 }) matches({ var a: Int = 45 }, { lazy val a: Int = 45 }) matches({ var a: Int = 45 }, { var a: Int = 45 }) - // matches({ var a: Int = 45 }, { val a: bindHole[Int] = patternHole }) - // matches({ var a: Int = 45 }, { lazy val a: bindHole[Int] = patternHole }) + matches({ var a: Int = 45 }, { @patternBindHole val a: Int = patternHole }) + matches({ var a: Int = 45 }, { @patternBindHole lazy val a: Int = patternHole }) matches({ println(); println() }, { println(); println() }) matches({ { println() }; println() }, { println(); println() }) matches({ println(); { println() } }, { println(); println() }) matches({ println(); println() }, { println(); { println() } }) matches({ println(); println() }, { { println() }; println() }) matches({ def a: Int = 45 }, { def a: Int = 45 }) - // matches({ def a: Int = 45 }, { def a: bindHole[Int] = patternHole[Int] }) + matches({ def a: Int = 45 }, { @patternBindHole def a: Int = patternHole[Int] }) matches({ def a(x: Int): Int = 45 }, { def a(x: Int): Int = 45 }) matches({ def a(x: Int): Int = 45 }, { def a(x: Int, y: Int): Int = 45 }) matches({ def a(x: Int): Int = 45 }, { def a(x: Int)(y: Int): Int = 45 }) matches({ def a(x: Int, y: Int): Int = 45 }, { def a(x: Int): Int = 45 }) matches({ def a(x: Int)(y: Int): Int = 45 }, { def a(x: Int): Int = 45 }) matches({ def a(x: String): Int = 45 }, { def a(x: String): Int = 45 }) - // matches({ def a(x: Int): Int = 45 }, { def a(x: bindHole[Int]): Int = 45 }) - // matches({ def a(x: Int): Int = 45 }, { def a(x: bindHole[Int]): bindHole[Int] = 45 }) + matches({ def a(x: Int): Int = 45 }, { def a(x: Int @patternBindHole): Int = 45 }) + matches({ def a(x: Int): Int = 45 }, { def a(x: Int @patternBindHole): Int = 45 }) matches({ def a(x: Int): Int = x }, { def b(y: Int): Int = y }) matches({ def a: Int = a }, { def b: Int = b }) matches({ lazy val a: Int = a }, { lazy val b: Int = b }) matches(1 match { case _ => 2 }, 1 match { case _ => 2 }) - // matches(1 match { case _ => 2 }, patternHole[Int] match { case _ => patternHole[Int] }) + matches(1 match { case _ => 2 }, patternHole[Int] match { case _ => patternHole[Int] }) matches(??? match { case None => 2 }, ??? match { case None => 2 }) matches(??? match { case Some(1) => 2 }, ??? match { case Some(1) => 2 }) // matches(??? match { case Some(1) => 2 }, ??? match { case Some(patternMatchHole()) => 2 }) - // matches(??? match { case Some(n) => 2 }, ??? match { case Some(patternBindHole(n)) => 2 }) - // matches(??? match { case Some(n @ Some(m)) => 2 }, ??? match { case Some(patternBindHole(n @ Some(patternBindHole(m)))) => 2 }) + matches(??? match { case Some(n) => 2 }, ??? match { case Some(patternMatchBindHole(n)) => 2 }) + matches(??? match { case Some(n @ Some(m)) => 2 }, ??? match { case Some(patternMatchBindHole(n @ Some(patternMatchBindHole(m)))) => 2 }) matches(try 1 catch { case _ => 2 }, try 1 catch { case _ => 2 }) matches(try 1 finally 2, try 1 finally 2) - // matches(try 1 catch { case _ => 2 }, try patternHole[Int] catch { case _ => patternHole[Int] }) - // matches(try 1 finally 2, try patternHole[Int] finally patternHole[Int]) - + matches(try 1 catch { case _ => 2 }, try patternHole[Int] catch { case _ => patternHole[Int] }) + matches(try 1 finally 2, try patternHole[Int] finally patternHole[Int]) } } diff --git a/tests/run-with-compiler/quote-matcher-symantics-1/quoted_1.scala b/tests/run-with-compiler/quote-matcher-symantics-1/quoted_1.scala index da66bbc38e5a..01fed1d67835 100644 --- a/tests/run-with-compiler/quote-matcher-symantics-1/quoted_1.scala +++ b/tests/run-with-compiler/quote-matcher-symantics-1/quoted_1.scala @@ -14,15 +14,15 @@ object Macros { def lift(e: Expr[DSL]): Expr[T] = e match { case '{ LitDSL(${ Literal(c) }) } => - // case scala.runtime.quoted.Matcher.unapply[Tuple1[Expr[Int]]](Tuple1(Literal(c)))(/*implicits*/ '{ LitDSL(patternHole[Int]) }, reflect) => + // case scala.internal.quoted.Matcher.unapply[Tuple1[Expr[Int]]](Tuple1(Literal(c)))(/*implicits*/ '{ LitDSL(patternHole[Int]) }, reflect) => '{ $sym.value(${c.toExpr}) } case '{ ($x: DSL) + ($y: DSL) } => - // case scala.runtime.quoted.Matcher.unapply[Tuple2[Expr[DSL], Expr[DSL]]](Tuple2(x, y))(/*implicits*/ '{ patternHole[DSL] + patternHole[DSL] }, reflect) => + // case scala.internal.quoted.Matcher.unapply[Tuple2[Expr[DSL], Expr[DSL]]](Tuple2(x, y))(/*implicits*/ '{ patternHole[DSL] + patternHole[DSL] }, reflect) => '{ $sym.plus(${lift(x)}, ${lift(y)}) } case '{ ($x: DSL) * ($y: DSL) } => - // case scala.runtime.quoted.Matcher.unapply[Tuple2[Expr[DSL], Expr[DSL]]](Tuple2(x, y))(/*implicits*/ '{ patternHole[DSL] * patternHole[DSL] }, reflect) => + // case scala.internal.quoted.Matcher.unapply[Tuple2[Expr[DSL], Expr[DSL]]](Tuple2(x, y))(/*implicits*/ '{ patternHole[DSL] * patternHole[DSL] }, reflect) => '{ $sym.times(${lift(x)}, ${lift(y)}) } case _ => diff --git a/tests/run-with-compiler/quote-matcher-symantics-2.check b/tests/run-with-compiler/quote-matcher-symantics-2.check new file mode 100644 index 000000000000..416ca0ef5f81 --- /dev/null +++ b/tests/run-with-compiler/quote-matcher-symantics-2.check @@ -0,0 +1,23 @@ +1 +1 +LitAST(1) + +1 + 2 +3 +PlusAST(LitAST(1),LitAST(2)) + +1 * 2 +2 +TimesAST(LitAST(1),LitAST(2)) + +1 + 3 * 4 +13 +PlusAST(LitAST(1),TimesAST(LitAST(3),LitAST(4))) + +2 + 5 +7 +AppAST(, LitAST(5)) + +2 + 2 +4 +PlusAST(LitAST(2),LitAST(2)) diff --git a/tests/run-with-compiler/quote-matcher-symantics-2/quoted_1.scala b/tests/run-with-compiler/quote-matcher-symantics-2/quoted_1.scala new file mode 100644 index 000000000000..985c7acd34ee --- /dev/null +++ b/tests/run-with-compiler/quote-matcher-symantics-2/quoted_1.scala @@ -0,0 +1,103 @@ +import scala.quoted._ +import scala.quoted.matching._ + +import scala.tasty.Reflection + +object Macros { + + inline def liftString(a: => DSL): String = ${impl(StringNum, 'a)} + + inline def liftCompute(a: => DSL): Int = ${impl(ComputeNum, 'a)} + + inline def liftAST(a: => DSL): ASTNum = ${impl(ASTNum, 'a)} + + private def impl[T: Type](sym: Symantics[T], a: Expr[DSL])(implicit reflect: Reflection): Expr[T] = { + + def lift(e: Expr[DSL])(implicit env: Map[Bind[DSL], Expr[T]]): Expr[T] = e match { + + case '{ LitDSL(${Literal(c)}) } => sym.value(c) + + case '{ ($x: DSL) + ($y: DSL) } => sym.plus(lift(x), lift(y)) + + case '{ ($x: DSL) * ($y: DSL) } => sym.times(lift(x), lift(y)) + + case '{ ($f: DSL => DSL)($x: DSL) } => sym.app(liftFun(f), lift(x)) + + case '{ val $x: DSL = $value; $body: DSL } => lift(body)(env + (x -> lift(value))) + + case Bind(b) if env.contains(b) => env(b) + + case _ => + import reflect._ + error("Expected explicit DSL", e.unseal.pos) + ??? + } + + def liftFun(e: Expr[DSL => DSL])(implicit env: Map[Bind[DSL], Expr[T]]): Expr[T => T] = e match { + case '{ ($x: DSL) => ($body: DSL) } => + sym.lam((y: Expr[T]) => lift(body)(env + (x -> y))) + + case _ => + import reflect._ + error("Expected explicit DSL => DSL", e.unseal.pos) + ??? + } + + lift(a)(Map.empty) + } + +} + +// +// DSL in which the user write the code +// + +trait DSL { + def + (x: DSL): DSL = ??? + def * (x: DSL): DSL = ??? +} +case class LitDSL(x: Int) extends DSL + +// +// Interpretation of the DSL +// + +trait Symantics[Num] { + def value(x: Int): Expr[Num] + def plus(x: Expr[Num], y: Expr[Num]): Expr[Num] + def times(x: Expr[Num], y: Expr[Num]): Expr[Num] + def app(f: Expr[Num => Num], x: Expr[Num]): Expr[Num] + def lam(body: Expr[Num] => Expr[Num]): Expr[Num => Num] +} + +object StringNum extends Symantics[String] { + def value(x: Int): Expr[String] = x.toString.toExpr + def plus(x: Expr[String], y: Expr[String]): Expr[String] = '{ s"${$x} + ${$y}" } // '{ x + " + " + y } + def times(x: Expr[String], y: Expr[String]): Expr[String] = '{ s"${$x} * ${$y}" } + def app(f: Expr[String => String], x: Expr[String]): Expr[String] = f(x) // functions are beta reduced + def lam(body: Expr[String] => Expr[String]): Expr[String => String] = '{ (x: String) => ${body('x)} } +} + +object ComputeNum extends Symantics[Int] { + def value(x: Int): Expr[Int] = x.toExpr + def plus(x: Expr[Int], y: Expr[Int]): Expr[Int] = '{ $x + $y } + def times(x: Expr[Int], y: Expr[Int]): Expr[Int] = '{ $x * $y } + def app(f: Expr[Int => Int], x: Expr[Int]): Expr[Int] = '{ $f($x) } + def lam(body: Expr[Int] => Expr[Int]): Expr[Int => Int] = '{ (x: Int) => ${body('x)} } +} + +object ASTNum extends Symantics[ASTNum] { + def value(x: Int): Expr[ASTNum] = '{ LitAST(${x.toExpr}) } + def plus(x: Expr[ASTNum], y: Expr[ASTNum]): Expr[ASTNum] = '{ PlusAST($x, $y) } + def times(x: Expr[ASTNum], y: Expr[ASTNum]): Expr[ASTNum] = '{ TimesAST($x, $y) } + def app(f: Expr[ASTNum => ASTNum], x: Expr[ASTNum]): Expr[ASTNum] = '{ AppAST($f, $x) } + def lam(body: Expr[ASTNum] => Expr[ASTNum]): Expr[ASTNum => ASTNum] = '{ (x: ASTNum) => ${body('x)} } +} + +trait ASTNum +case class LitAST(x: Int) extends ASTNum +case class PlusAST(x: ASTNum, y: ASTNum) extends ASTNum +case class TimesAST(x: ASTNum, y: ASTNum) extends ASTNum +case class AppAST(x: ASTNum => ASTNum, y: ASTNum) extends ASTNum { + override def toString: String = s"AppAST(, $y)" +} diff --git a/tests/run-with-compiler/quote-matcher-symantics-2/quoted_2.scala b/tests/run-with-compiler/quote-matcher-symantics-2/quoted_2.scala new file mode 100644 index 000000000000..1b26bd657837 --- /dev/null +++ b/tests/run-with-compiler/quote-matcher-symantics-2/quoted_2.scala @@ -0,0 +1,31 @@ +import Macros._ + +object Test { + + def main(args: Array[String]): Unit = { + println(liftString(LitDSL(1))) + println(liftCompute(LitDSL(1))) + println(liftAST(LitDSL(1))) + println() + println(liftString(LitDSL(1) + LitDSL(2))) + println(liftCompute(LitDSL(1) + LitDSL(2))) + println(liftAST(LitDSL(1) + LitDSL(2))) + println() + println(liftString(LitDSL(1) * LitDSL(2))) + println(liftCompute(LitDSL(1) * LitDSL(2))) + println(liftAST(LitDSL(1) * LitDSL(2))) + println() + println(liftString(LitDSL(1) + LitDSL(3) * LitDSL(4))) + println(liftCompute(LitDSL(1) + LitDSL(3) * LitDSL(4))) + println(liftAST(LitDSL(1) + LitDSL(3) * LitDSL(4))) + println() + println(liftString(((x: DSL) => LitDSL(2) + x).apply(LitDSL(5)))) + println(liftCompute(((x: DSL) => LitDSL(2) + x).apply(LitDSL(5)))) + println(liftAST(((x: DSL) => LitDSL(2) + x).apply(LitDSL(5)))) + println() + println(liftString({ val x: DSL = LitDSL(2); x + x })) + println(liftCompute({ val x: DSL = LitDSL(2); x + x })) + println(liftAST({ val x: DSL = LitDSL(2); x + x })) + } + +}