diff --git a/library/src-bootstrapped/scala/quoted/util/ExprMap.scala b/library/src-bootstrapped/scala/quoted/util/ExprMap.scala new file mode 100644 index 000000000000..bbd37f11ea5f --- /dev/null +++ b/library/src-bootstrapped/scala/quoted/util/ExprMap.scala @@ -0,0 +1,159 @@ +package scala.quoted.util + +import scala.quoted._ + +trait ExprMap { + + /** Map an expression `e` with a type `tpe` */ + def transform[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T] + + /** Map subexpressions an expression `e` with a type `tpe` */ + def transformChildren[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T] = { + import qctx.tasty.{_, given} + final class MapChildren() { + + def transformStatement(tree: Statement)(given ctx: Context): Statement = { + def localCtx(definition: Definition): Context = definition.symbol.localContext + tree match { + case tree: Term => + transformTerm(tree, defn.AnyType) + case tree: Definition => + transformDefinition(tree) + case tree: Import => + tree + } + } + + def transformDefinition(tree: Definition)(given ctx: Context): Definition = { + def localCtx(definition: Definition): Context = definition.symbol.localContext + tree match { + case tree: ValDef => + implicit val ctx = localCtx(tree) + val rhs1 = tree.rhs.map(x => transformTerm(x, tree.tpt.tpe)) + ValDef.copy(tree)(tree.name, tree.tpt, rhs1) + case tree: DefDef => + implicit val ctx = localCtx(tree) + DefDef.copy(tree)(tree.name, tree.typeParams, tree.paramss, tree.returnTpt, tree.rhs.map(x => transformTerm(x, tree.returnTpt.tpe))) + case tree: TypeDef => + tree + case tree: ClassDef => + val newBody = transformStats(tree.body) + ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, newBody) + } + } + + def transformTermChildren(tree: Term, tpe: Type)(given ctx: Context): Term = tree match { + case Ident(name) => + tree + case Select(qualifier, name) => + Select.copy(tree)(transformTerm(qualifier, qualifier.tpe), name) + case This(qual) => + tree + case Super(qual, mix) => + tree + case tree @ Apply(fun, args) => + val MethodType(_, tpes, _) = fun.tpe.widen + Apply.copy(tree)(transformTerm(fun, defn.AnyType), transformTerms(args, tpes)) + case TypeApply(fun, args) => + TypeApply.copy(tree)(transformTerm(fun, defn.AnyType), args) + case _: Literal => + tree + case New(tpt) => + New.copy(tree)(transformTypeTree(tpt)) + case Typed(expr, tpt) => + val tp = tpt.tpe match + // TODO improve code + case AppliedType(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), ""), List(tp0: Type)) => + type T + val a = tp0.seal.asInstanceOf[quoted.Type[T]] + '[Seq[$a]].unseal.tpe + case tp => tp + Typed.copy(tree)(transformTerm(expr, tp), transformTypeTree(tpt)) + case tree: NamedArg => + NamedArg.copy(tree)(tree.name, transformTerm(tree.value, tpe)) + case Assign(lhs, rhs) => + Assign.copy(tree)(lhs, transformTerm(rhs, lhs.tpe.widen)) + case Block(stats, expr) => + Block.copy(tree)(transformStats(stats), transformTerm(expr, tpe)) + case If(cond, thenp, elsep) => + If.copy(tree)( + transformTerm(cond, defn.BooleanType), + transformTerm(thenp, tpe), + transformTerm(elsep, tpe)) + case _: Closure => + tree + case Match(selector, cases) => + Match.copy(tree)(transformTerm(selector, selector.tpe), transformCaseDefs(cases, tpe)) + case Return(expr) => + // FIXME + // ctx.owner seems to be set to the wrong symbol + // Return.copy(tree)(transformTerm(expr, expr.tpe)) + tree + case While(cond, body) => + While.copy(tree)(transformTerm(cond, defn.BooleanType), transformTerm(body, defn.AnyType)) + case Try(block, cases, finalizer) => + Try.copy(tree)(transformTerm(block, tpe), transformCaseDefs(cases, defn.AnyType), finalizer.map(x => transformTerm(x, defn.AnyType))) + case Repeated(elems, elemtpt) => + Repeated.copy(tree)(transformTerms(elems, elemtpt.tpe), elemtpt) + case Inlined(call, bindings, expansion) => + Inlined.copy(tree)(call, transformDefinitions(bindings), transformTerm(expansion, tpe)/*()call.symbol.localContext)*/) + } + + def transformTerm(tree: Term, tpe: Type)(given ctx: Context): Term = + tree match { + case _: Closure => + tree + case _: Inlined => + transformTermChildren(tree, tpe) + case _ => + tree.tpe.widen match { + case _: MethodType | _: PolyType => + transformTermChildren(tree, tpe) + case _ => + type X + val expr = tree.seal.asInstanceOf[Expr[X]] + val t = tpe.seal.asInstanceOf[quoted.Type[X]] + transform(expr)(given qctx, t).unseal + } + } + + def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree + + def transformCaseDef(tree: CaseDef, tpe: Type)(given ctx: Context): CaseDef = + CaseDef.copy(tree)(tree.pattern, tree.guard.map(x => transformTerm(x, defn.BooleanType)), transformTerm(tree.rhs, tpe)) + + def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = { + TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs)) + } + + def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] = + trees mapConserve (transformStatement(_)) + + def transformDefinitions(trees: List[Definition])(given ctx: Context): List[Definition] = + trees mapConserve (transformDefinition(_)) + + def transformTerms(trees: List[Term], tpes: List[Type])(given ctx: Context): List[Term] = + var tpes2 = tpes // TODO use proper zipConserve + trees mapConserve { x => + val tpe :: tail = tpes2 + tpes2 = tail + transformTerm(x, tpe) + } + + def transformTerms(trees: List[Term], tpe: Type)(given ctx: Context): List[Term] = + trees.mapConserve(x => transformTerm(x, tpe)) + + def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] = + trees mapConserve (transformTypeTree(_)) + + def transformCaseDefs(trees: List[CaseDef], tpe: Type)(given ctx: Context): List[CaseDef] = + trees mapConserve (x => transformCaseDef(x, tpe)) + + def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] = + trees mapConserve (transformTypeCaseDef(_)) + + } + new MapChildren().transformTermChildren(e.unseal, tpe.unseal.tpe).seal.cast[T] // Cast will only fail if this implementation has a bug + } + +} diff --git a/library/src/scala/quoted/Expr.scala b/library/src/scala/quoted/Expr.scala index ed9a35c5bea1..8d37819b107f 100644 --- a/library/src/scala/quoted/Expr.scala +++ b/library/src/scala/quoted/Expr.scala @@ -19,6 +19,17 @@ package quoted { */ final def getValue[U >: T](given qctx: QuoteContext, valueOf: ValueOfExpr[U]): Option[U] = valueOf(this) + /** Pattern matches `this` against `that`. Effectively performing a deep equality check. + * It does the equivalent of + * ``` + * this match + * case '{...} => true // where the contens of the pattern are the contents of `that` + * case _ => false + * ``` + */ + final def matches(that: Expr[Any])(given qctx: QuoteContext): Boolean = + !scala.internal.quoted.Expr.unapply[Unit, Unit](this)(given that, false, qctx).isEmpty + } object Expr { diff --git a/tests/run-macros/expr-map-1.check b/tests/run-macros/expr-map-1.check new file mode 100644 index 000000000000..0f16bb01c517 --- /dev/null +++ b/tests/run-macros/expr-map-1.check @@ -0,0 +1,27 @@ +oof +oofoof +ylppa +kcolb +kcolb +neht +esle +lav +vals +fed +defs +fed +rab +yrt +yllanif +hctac +elihw +wen +depyt +depyt +grAdeman +qual +adbmal +ravsgra +hctam +def +ooF wen diff --git a/tests/run-macros/expr-map-1/Macro_1.scala b/tests/run-macros/expr-map-1/Macro_1.scala new file mode 100644 index 000000000000..872c072a1b3d --- /dev/null +++ b/tests/run-macros/expr-map-1/Macro_1.scala @@ -0,0 +1,18 @@ +import scala.quoted._ +import scala.quoted.matching._ + +inline def rewrite[T](x: => Any): Any = ${ stringRewriter('x) } + +private def stringRewriter(e: Expr[Any])(given QuoteContext): Expr[Any] = + StringRewriter.transform(e) + +private object StringRewriter extends util.ExprMap { + + def transform[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = e match + case Const(s: String) => + Expr(s.reverse) match + case '{ $x: T } => x + case _ => e // e had a singlton String type + case _ => transformChildren(e) + +} diff --git a/tests/run-macros/expr-map-1/Test_2.scala b/tests/run-macros/expr-map-1/Test_2.scala new file mode 100644 index 000000000000..70aada14219a --- /dev/null +++ b/tests/run-macros/expr-map-1/Test_2.scala @@ -0,0 +1,130 @@ +object Test { + + def main(args: Array[String]): Unit = { + println(rewrite("foo")) + println(rewrite("foo" + "foo")) + + rewrite { + println("apply") + } + + rewrite { + println("block") + println("block") + } + + val b: Boolean = true + rewrite { + if b then println("then") + else println("else") + } + + rewrite { + if !b then println("then") + else println("else") + } + + rewrite { + val s: String = "val" + println(s) + } + + rewrite { + val s: "vals" = "vals" + println(s) // prints "foo" not "oof" + } + + rewrite { + def s: String = "def" + println(s) + } + + rewrite { + def s: "defs" = "defs" + println(s) // prints "foo" not "oof" + } + + rewrite { + def s(x: String): String = x + println(s("def")) + } + + rewrite { + var s: String = "var" + s = "bar" + println(s) + } + + rewrite { + try println("try") + finally println("finally") + } + + rewrite { + try throw new Exception() + catch case x: Exception => println("catch") + } + + rewrite { + var x = true + while (x) { + println("while") + x = false + } + } + + rewrite { + val t = new Tuple1("new") + println(t._1) + } + + rewrite { + println("typed": String) + println("typed": Any) + } + + rewrite { + val f = new Foo(foo = "namedArg") + println(f.foo) + } + + rewrite { + println("qual".reverse) + } + + rewrite { + val f = () => "lambda" + println(f()) + } + + rewrite { + def f(args: String*): String = args.mkString + println(f("var", "args")) + } + + rewrite { + "match" match { + case "match" => println("match") + case x => println("x") + } + } + + // FIXME should print fed + rewrite { + def s: String = return "def" + println(s) + } + + rewrite { + class Foo { + println("new Foo") + } + new Foo + } + + + } + +} + +class Foo(val foo: String) diff --git a/tests/run-macros/expr-map-2.check b/tests/run-macros/expr-map-2.check new file mode 100644 index 000000000000..cbaaa5a883be --- /dev/null +++ b/tests/run-macros/expr-map-2.check @@ -0,0 +1,3 @@ +Foo(2) +4 +4 diff --git a/tests/run-macros/expr-map-2/Macro_1.scala b/tests/run-macros/expr-map-2/Macro_1.scala new file mode 100644 index 000000000000..ede05f1f9f87 --- /dev/null +++ b/tests/run-macros/expr-map-2/Macro_1.scala @@ -0,0 +1,19 @@ +import scala.quoted._ +import scala.quoted.matching._ + +inline def rewrite[T](x: => Any): Any = ${ stringRewriter('x) } + +private def stringRewriter(e: Expr[Any])(given QuoteContext): Expr[Any] = + StringRewriter.transform(e) + +private object StringRewriter extends util.ExprMap { + + def transform[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = e match + case '{ ($x: Foo).x } => + '{ new Foo(4).x } match case '{ $e: T } => e + case _ => + transformChildren(e) + +} + +case class Foo(x: Int) diff --git a/tests/run-macros/expr-map-2/Test_2.scala b/tests/run-macros/expr-map-2/Test_2.scala new file mode 100644 index 000000000000..7790ec34cf9f --- /dev/null +++ b/tests/run-macros/expr-map-2/Test_2.scala @@ -0,0 +1,13 @@ +object Test { + + def main(args: Array[String]): Unit = { + println(rewrite(new Foo(2))) + println(rewrite(new Foo(2).x)) + + rewrite { + val foo = new Foo(2) + println(foo.x) + } + + } +} diff --git a/tests/run-macros/flops-rewrite-2/Macro_1.scala b/tests/run-macros/flops-rewrite-2/Macro_1.scala index 97ee354c8611..9b6124ad499a 100644 --- a/tests/run-macros/flops-rewrite-2/Macro_1.scala +++ b/tests/run-macros/flops-rewrite-2/Macro_1.scala @@ -35,7 +35,7 @@ private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = { fixPoint = true ) - val x2 = rewriter.rewrite(x) + val x2 = rewriter.transform(x) '{ println(${Expr(x.show)}) @@ -63,39 +63,13 @@ private object Rewriter { new Rewriter(preTransform, postTransform, fixPoint) } -private class Rewriter(preTransform: List[Transformation[_]] = Nil, postTransform: List[Transformation[_]] = Nil, fixPoint: Boolean) { - def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = { +private class Rewriter(preTransform: List[Transformation[_]] = Nil, postTransform: List[Transformation[_]] = Nil, fixPoint: Boolean) extends util.ExprMap { + def transform[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = { val e2 = preTransform.foldLeft(e)((ei, transform) => transform(ei)) - val e3 = rewriteChildren(e2) + val e3 = transformChildren(e2) val e4 = postTransform.foldLeft(e3)((ei, transform) => transform(ei)) - if fixPoint && e4 != e then rewrite(e4) + if fixPoint && !e4.matches(e) then transform(e4) else e4 } - def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = { - import qctx.tasty.{_, given} - class MapChildren extends TreeMap { - override def transformTerm(tree: Term)(given ctx: Context): Term = tree match { - case _: Closure => - tree - case _: Inlined | _: Select => - transformChildrenTerm(tree) - case _ => - tree.tpe.widen match { - case _: MethodType | _: PolyType => - transformChildrenTerm(tree) - case _ => - tree.seal match { - case '{ $x: $t } => rewrite(x).unseal - } - } - } - def transformChildrenTerm(tree: Term)(given ctx: Context): Term = - super.transformTerm(tree) - } - (new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug - } - } - - diff --git a/tests/run-macros/flops-rewrite-3/Macro_1.scala b/tests/run-macros/flops-rewrite-3/Macro_1.scala index 55543e0a7c8a..241abe5ccb21 100644 --- a/tests/run-macros/flops-rewrite-3/Macro_1.scala +++ b/tests/run-macros/flops-rewrite-3/Macro_1.scala @@ -33,7 +33,7 @@ private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = { } ) - val x2 = rewriter.rewrite(x) + val x2 = rewriter.transform(x) '{ println(${Expr(x.show)}) @@ -92,7 +92,7 @@ private object Rewriter { def apply(): Rewriter = new Rewriter(Nil, Nil, false) } -private class Rewriter private (preTransform: List[Transformation] = Nil, postTransform: List[Transformation] = Nil, fixPoint: Boolean) { +private class Rewriter private (preTransform: List[Transformation] = Nil, postTransform: List[Transformation] = Nil, fixPoint: Boolean) extends util.ExprMap { def withFixPoint: Rewriter = new Rewriter(preTransform, postTransform, fixPoint = true) @@ -101,37 +101,11 @@ private class Rewriter private (preTransform: List[Transformation] = Nil, postTr def withPost(transform: Transformation): Rewriter = new Rewriter(preTransform, transform :: postTransform, fixPoint) - def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = { + def transform[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = { val e2 = preTransform.foldLeft(e)((ei, transform) => transform(ei)) - val e3 = rewriteChildren(e2) + val e3 = transformChildren(e2) val e4 = postTransform.foldLeft(e3)((ei, transform) => transform(ei)) - if fixPoint && e4 != e then rewrite(e4) else e4 - } - - def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = { - import qctx.tasty.{_, given} - class MapChildren extends TreeMap { - override def transformTerm(tree: Term)(given ctx: Context): Term = tree match { - case _: Closure => - tree - case _: Inlined | _: Select => - transformChildrenTerm(tree) - case _ => - tree.tpe.widen match { - case _: MethodType | _: PolyType => - transformChildrenTerm(tree) - case _ => - tree.seal match { - case '{ $x: $t } => rewrite(x).unseal - } - } - } - def transformChildrenTerm(tree: Term)(given ctx: Context): Term = - super.transformTerm(tree) - } - (new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug + if fixPoint && !e4.matches(e) then transform(e4) else e4 } } - - diff --git a/tests/run-macros/flops-rewrite/Macro_1.scala b/tests/run-macros/flops-rewrite/Macro_1.scala index d78df3dc4cbf..65c33bfd2033 100644 --- a/tests/run-macros/flops-rewrite/Macro_1.scala +++ b/tests/run-macros/flops-rewrite/Macro_1.scala @@ -13,7 +13,7 @@ private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = { } ) - val x2 = rewriter.rewrite(x) + val x2 = rewriter.transform(x) '{ println(${Expr(x.show)}) @@ -28,12 +28,12 @@ private object Rewriter { new Rewriter(preTransform, postTransform, fixPoint) } -private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr[Any] => Expr[Any], fixPoint: Boolean) { - def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = { +private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr[Any] => Expr[Any], fixPoint: Boolean) extends util.ExprMap { + def transform[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = { val e2 = checkedTransform(e, preTransform) - val e3 = rewriteChildren(e2) + val e3 = transformChildren(e2) val e4 = checkedTransform(e3, postTransform) - if fixPoint && e4 != e then rewrite(e4) + if fixPoint && !e4.matches(e) then transform(e4) else e4 } @@ -54,30 +54,4 @@ private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr } } - def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = { - import qctx.tasty.{_, given} - class MapChildren extends TreeMap { - override def transformTerm(tree: Term)(given ctx: Context): Term = tree match { - case _: Closure => - tree - case _: Inlined | _: Select => - transformChildrenTerm(tree) - case _ => - tree.tpe.widen match { - case _: MethodType | _: PolyType => - transformChildrenTerm(tree) - case _ => - tree.seal match { - case '{ $x: $t } => rewrite(x).unseal - } - } - } - def transformChildrenTerm(tree: Term)(given ctx: Context): Term = - super.transformTerm(tree) - } - (new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug - } - } - - diff --git a/tests/run-staging/expr-matches.scala b/tests/run-staging/expr-matches.scala new file mode 100644 index 000000000000..2f0255fbadfb --- /dev/null +++ b/tests/run-staging/expr-matches.scala @@ -0,0 +1,13 @@ +import scala.quoted._ +import scala.quoted.staging._ + + +object Test { + given Toolbox = Toolbox.make(getClass.getClassLoader) + def main(args: Array[String]): Unit = withQuoteContext { + assert('{1} matches '{1}) + assert('{println("foo")} matches '{println("foo")}) + assert('{println("foo")} matches '{println(${Expr("foo")})}) + assert('{println(Some("foo"))} matches '{println(${ val a = '{Some("foo")}; a})}) + } +}