diff --git a/community-build/community-projects/utest b/community-build/community-projects/utest index e0e59628c321..9fa499c1f0e6 160000 --- a/community-build/community-projects/utest +++ b/community-build/community-projects/utest @@ -1 +1 @@ -Subproject commit e0e59628c321e213a098feb94a7d4b258266c422 +Subproject commit 9fa499c1f0e6ef8ee7fbdb916fc245f754bb27ed diff --git a/docs/docs/reference/metaprogramming/tasty-reflect.md b/docs/docs/reference/metaprogramming/tasty-reflect.md index 5ab40ba057d9..250430950bb3 100644 --- a/docs/docs/reference/metaprogramming/tasty-reflect.md +++ b/docs/docs/reference/metaprogramming/tasty-reflect.md @@ -119,7 +119,7 @@ def macroImpl()(qctx: QuoteContext): Expr[Unit] = { ### Tree Utilities -`scala.tasty.reflect.TreeUtils` contains three facilities for tree traversal and +`scala.tasty.reflect` contains three facilities for tree traversal and transformations. `TreeAccumulator` ties the knot of a traversal. By calling `foldOver(x, tree))` @@ -144,7 +144,7 @@ but without returning any value. Finally a `TreeMap` performs a transformation. #### Let -`scala.tasty.reflect.utils.TreeUtils` also offers a method `let` that allows us +`scala.tasty.Reflection` also offers a method `let` that allows us to bind the `rhs` to a `val` and use it in `body`. Additionally, `lets` binds the given `terms` to names and use them in the `body`. Their type definitions are shown below: diff --git a/library/src/scala/internal/quoted/Matcher.scala b/library/src/scala/internal/quoted/Matcher.scala index b8860f150c8c..21f351a51e71 100644 --- a/library/src/scala/internal/quoted/Matcher.scala +++ b/library/src/scala/internal/quoted/Matcher.scala @@ -317,7 +317,7 @@ private[quoted] object Matcher { if freePatternVars(term).isEmpty then Some(term) else None /** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */ - def freePatternVars(term: Term)(given qctx: Context, env: Env): Set[Symbol] = + def freePatternVars(term: Term)(given ctx: Context, env: Env): Set[Symbol] = val accumulator = new TreeAccumulator[Set[Symbol]] { def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] = tree match diff --git a/library/src/scala/tasty/Reflection.scala b/library/src/scala/tasty/Reflection.scala index 8d81c81138aa..70e8b1fd5436 100644 --- a/library/src/scala/tasty/Reflection.scala +++ b/library/src/scala/tasty/Reflection.scala @@ -2731,266 +2731,23 @@ class Reflection(private[scala] val internal: CompilerInterface) { self => // UTILS // /////////////// - abstract class TreeAccumulator[X] { - - // Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node. - def foldTree(x: X, tree: Tree)(given ctx: Context): X - - def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X = trees.foldLeft(x)(foldTree) - - def foldOverTree(x: X, tree: Tree)(given ctx: Context): X = { - def localCtx(definition: Definition): Context = definition.symbol.localContext - tree match { - case Ident(_) => - x - case Select(qualifier, _) => - foldTree(x, qualifier) - case This(qual) => - x - case Super(qual, _) => - foldTree(x, qual) - case Apply(fun, args) => - foldTrees(foldTree(x, fun), args) - case TypeApply(fun, args) => - foldTrees(foldTree(x, fun), args) - case Literal(const) => - x - case New(tpt) => - foldTree(x, tpt) - case Typed(expr, tpt) => - foldTree(foldTree(x, expr), tpt) - case NamedArg(_, arg) => - foldTree(x, arg) - case Assign(lhs, rhs) => - foldTree(foldTree(x, lhs), rhs) - case Block(stats, expr) => - foldTree(foldTrees(x, stats), expr) - case If(cond, thenp, elsep) => - foldTree(foldTree(foldTree(x, cond), thenp), elsep) - case While(cond, body) => - foldTree(foldTree(x, cond), body) - case Closure(meth, tpt) => - foldTree(x, meth) - case Match(selector, cases) => - foldTrees(foldTree(x, selector), cases) - case Return(expr) => - foldTree(x, expr) - case Try(block, handler, finalizer) => - foldTrees(foldTrees(foldTree(x, block), handler), finalizer) - case Repeated(elems, elemtpt) => - foldTrees(foldTree(x, elemtpt), elems) - case Inlined(call, bindings, expansion) => - foldTree(foldTrees(x, bindings), expansion) - case vdef @ ValDef(_, tpt, rhs) => - val ctx = localCtx(vdef) - given Context = ctx - foldTrees(foldTree(x, tpt), rhs) - case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) => - val ctx = localCtx(ddef) - given Context = ctx - foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs) - case tdef @ TypeDef(_, rhs) => - val ctx = localCtx(tdef) - given Context = ctx - foldTree(x, rhs) - case cdef @ ClassDef(_, constr, parents, derived, self, body) => - val ctx = localCtx(cdef) - given Context = ctx - foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body) - case Import(expr, _) => - foldTree(x, expr) - case clause @ PackageClause(pid, stats) => - foldTrees(foldTree(x, pid), stats)(given clause.symbol.localContext) - case Inferred() => x - case TypeIdent(_) => x - case TypeSelect(qualifier, _) => foldTree(x, qualifier) - case Projection(qualifier, _) => foldTree(x, qualifier) - case Singleton(ref) => foldTree(x, ref) - case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements) - case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args) - case ByName(result) => foldTree(x, result) - case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot) - case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg) - case TypeBind(_, tbt) => foldTree(x, tbt) - case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt) - case MatchTypeTree(boundopt, selector, cases) => - foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases) - case WildcardTypeTree() => x - case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi) - case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body) - case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body) - case Bind(_, body) => foldTree(x, body) - case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns) - case Alternatives(patterns) => foldTrees(x, patterns) - } - } + /** TASTy Reflect tree accumulator */ + trait TreeAccumulator[X] extends reflect.TreeAccumulator[X] { + val reflect: self.type = self } - abstract class TreeTraverser extends TreeAccumulator[Unit] { - - def traverseTree(tree: Tree)(given ctx: Context): Unit = traverseTreeChildren(tree) - - def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit = traverseTree(tree) - - protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit = foldOverTree((), tree) - + /** TASTy Reflect tree traverser */ + trait TreeTraverser extends reflect.TreeTraverser { + val reflect: self.type = self } - abstract class TreeMap { self => - - def transformTree(tree: Tree)(given ctx: Context): Tree = { - tree match { - case tree: PackageClause => - PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(given tree.symbol.localContext)) - case tree: Import => - Import.copy(tree)(transformTerm(tree.expr), tree.selectors) - case tree: Statement => - transformStatement(tree) - case tree: TypeTree => transformTypeTree(tree) - case tree: TypeBoundsTree => tree // TODO traverse tree - case tree: WildcardTypeTree => tree // TODO traverse tree - case tree: CaseDef => - transformCaseDef(tree) - case tree: TypeCaseDef => - transformTypeCaseDef(tree) - case pattern: Bind => - Bind.copy(pattern)(pattern.name, pattern.pattern) - case pattern: Unapply => - Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns)) - case pattern: Alternatives => - Alternatives.copy(pattern)(transformTrees(pattern.patterns)) - } - } - - def transformStatement(tree: Statement)(given ctx: Context): Statement = { - def localCtx(definition: Definition): Context = definition.symbol.localContext - tree match { - case tree: Term => - transformTerm(tree) - case tree: ValDef => - val ctx = localCtx(tree) - given Context = ctx - val tpt1 = transformTypeTree(tree.tpt) - val rhs1 = tree.rhs.map(x => transformTerm(x)) - ValDef.copy(tree)(tree.name, tpt1, rhs1) - case tree: DefDef => - val ctx = localCtx(tree) - given Context = ctx - DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x))) - case tree: TypeDef => - val ctx = localCtx(tree) - given Context = ctx - TypeDef.copy(tree)(tree.name, transformTree(tree.rhs)) - case tree: ClassDef => - ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body) - case tree: Import => - Import.copy(tree)(transformTerm(tree.expr), tree.selectors) - } - } - - def transformTerm(tree: Term)(given ctx: Context): Term = { - tree match { - case Ident(name) => - tree - case Select(qualifier, name) => - Select.copy(tree)(transformTerm(qualifier), name) - case This(qual) => - tree - case Super(qual, mix) => - Super.copy(tree)(transformTerm(qual), mix) - case Apply(fun, args) => - Apply.copy(tree)(transformTerm(fun), transformTerms(args)) - case TypeApply(fun, args) => - TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args)) - case Literal(const) => - tree - case New(tpt) => - New.copy(tree)(transformTypeTree(tpt)) - case Typed(expr, tpt) => - Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt)) - case tree: NamedArg => - NamedArg.copy(tree)(tree.name, transformTerm(tree.value)) - case Assign(lhs, rhs) => - Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs)) - case Block(stats, expr) => - Block.copy(tree)(transformStats(stats), transformTerm(expr)) - case If(cond, thenp, elsep) => - If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep)) - case Closure(meth, tpt) => - Closure.copy(tree)(transformTerm(meth), tpt) - case Match(selector, cases) => - Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases)) - case Return(expr) => - Return.copy(tree)(transformTerm(expr)) - case While(cond, body) => - While.copy(tree)(transformTerm(cond), transformTerm(body)) - case Try(block, cases, finalizer) => - Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x))) - case Repeated(elems, elemtpt) => - Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt)) - case Inlined(call, bindings, expansion) => - Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/) - } - } - - def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree match { - case Inferred() => tree - case tree: TypeIdent => tree - case tree: TypeSelect => - TypeSelect.copy(tree)(tree.qualifier, tree.name) - case tree: Projection => - Projection.copy(tree)(tree.qualifier, tree.name) - case tree: Annotated => - Annotated.copy(tree)(tree.arg, tree.annotation) - case tree: Singleton => - Singleton.copy(tree)(transformTerm(tree.ref)) - case tree: Refined => - Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]]) - case tree: Applied => - Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args)) - case tree: MatchTypeTree => - MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases)) - case tree: ByName => - ByName.copy(tree)(transformTypeTree(tree.result)) - case tree: LambdaTypeTree => - LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))(given tree.symbol.localContext) - case tree: TypeBind => - TypeBind.copy(tree)(tree.name, tree.body) - case tree: TypeBlock => - TypeBlock.copy(tree)(tree.aliases, tree.tpt) - } - - def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef = { - CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs)) - } - - def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = { - TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs)) - } - - def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] = - trees mapConserve (transformStatement(_)) - - def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] = - trees mapConserve (transformTree(_)) - - def transformTerms(trees: List[Term])(given ctx: Context): List[Term] = - trees mapConserve (transformTerm(_)) - - def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] = - trees mapConserve (transformTypeTree(_)) - - def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] = - trees mapConserve (transformCaseDef(_)) - - def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] = - trees mapConserve (transformTypeCaseDef(_)) - - def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] = - transformTrees(trees).asInstanceOf[List[Tr]] - + /** TASTy Reflect tree map */ + trait TreeMap extends reflect.TreeMap { + val reflect: self.type = self } + // TODO extract from Reflection + /** Bind the `rhs` to a `val` and use it in `body` */ def let(rhs: Term)(body: Ident => Term)(given ctx: Context): Term = { import scala.quoted.QuoteContext diff --git a/library/src/scala/tasty/reflect/TreeAccumulator.scala b/library/src/scala/tasty/reflect/TreeAccumulator.scala new file mode 100644 index 000000000000..f1a3b0f6576a --- /dev/null +++ b/library/src/scala/tasty/reflect/TreeAccumulator.scala @@ -0,0 +1,111 @@ +package scala.tasty +package reflect + +/** TASTy Reflect tree accumulator. + * + * Usage: + * ``` + * class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R) + * extends scala.tasty.reflect.TreeAccumulator[X] { + * import reflect.{given, _} + * def foldTree(x: X, tree: Tree)(given ctx: Context): X = ... + * } + * ``` + */ +trait TreeAccumulator[X] { + + val reflect: Reflection + import reflect.{given, _} + + // Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node. + def foldTree(x: X, tree: Tree)(given ctx: Context): X + + def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X = trees.foldLeft(x)(foldTree) + + def foldOverTree(x: X, tree: Tree)(given ctx: Context): X = { + def localCtx(definition: Definition): Context = definition.symbol.localContext + tree match { + case Ident(_) => + x + case Select(qualifier, _) => + foldTree(x, qualifier) + case This(qual) => + x + case Super(qual, _) => + foldTree(x, qual) + case Apply(fun, args) => + foldTrees(foldTree(x, fun), args) + case TypeApply(fun, args) => + foldTrees(foldTree(x, fun), args) + case Literal(const) => + x + case New(tpt) => + foldTree(x, tpt) + case Typed(expr, tpt) => + foldTree(foldTree(x, expr), tpt) + case NamedArg(_, arg) => + foldTree(x, arg) + case Assign(lhs, rhs) => + foldTree(foldTree(x, lhs), rhs) + case Block(stats, expr) => + foldTree(foldTrees(x, stats), expr) + case If(cond, thenp, elsep) => + foldTree(foldTree(foldTree(x, cond), thenp), elsep) + case While(cond, body) => + foldTree(foldTree(x, cond), body) + case Closure(meth, tpt) => + foldTree(x, meth) + case Match(selector, cases) => + foldTrees(foldTree(x, selector), cases) + case Return(expr) => + foldTree(x, expr) + case Try(block, handler, finalizer) => + foldTrees(foldTrees(foldTree(x, block), handler), finalizer) + case Repeated(elems, elemtpt) => + foldTrees(foldTree(x, elemtpt), elems) + case Inlined(call, bindings, expansion) => + foldTree(foldTrees(x, bindings), expansion) + case vdef @ ValDef(_, tpt, rhs) => + val ctx = localCtx(vdef) + given Context = ctx + foldTrees(foldTree(x, tpt), rhs) + case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) => + val ctx = localCtx(ddef) + given Context = ctx + foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs) + case tdef @ TypeDef(_, rhs) => + val ctx = localCtx(tdef) + given Context = ctx + foldTree(x, rhs) + case cdef @ ClassDef(_, constr, parents, derived, self, body) => + val ctx = localCtx(cdef) + given Context = ctx + foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body) + case Import(expr, _) => + foldTree(x, expr) + case clause @ PackageClause(pid, stats) => + foldTrees(foldTree(x, pid), stats)(given clause.symbol.localContext) + case Inferred() => x + case TypeIdent(_) => x + case TypeSelect(qualifier, _) => foldTree(x, qualifier) + case Projection(qualifier, _) => foldTree(x, qualifier) + case Singleton(ref) => foldTree(x, ref) + case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements) + case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args) + case ByName(result) => foldTree(x, result) + case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot) + case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg) + case TypeBind(_, tbt) => foldTree(x, tbt) + case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt) + case MatchTypeTree(boundopt, selector, cases) => + foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases) + case WildcardTypeTree() => x + case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi) + case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body) + case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body) + case Bind(_, body) => foldTree(x, body) + case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns) + case Alternatives(patterns) => foldTrees(x, patterns) + } + } +} diff --git a/library/src/scala/tasty/reflect/TreeMap.scala b/library/src/scala/tasty/reflect/TreeMap.scala new file mode 100644 index 000000000000..aa926de83c96 --- /dev/null +++ b/library/src/scala/tasty/reflect/TreeMap.scala @@ -0,0 +1,171 @@ +package scala.tasty +package reflect + +/** TASTy Reflect tree map. + * + * Usage: + * ``` + * class MyTreeMap[R <: scala.tasty.Reflection & Singleton](val reflect: R) + * extends scala.tasty.reflect.TreeMap { + * import reflect.{given, _} + * override def transformTree(tree: Tree)(using ctx: Context): Tree = ... + * } + * ``` + */ +trait TreeMap { + + val reflect: Reflection + import reflect.{given, _} + + def transformTree(tree: Tree)(given ctx: Context): Tree = { + tree match { + case tree: PackageClause => + PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(given tree.symbol.localContext)) + case tree: Import => + Import.copy(tree)(transformTerm(tree.expr), tree.selectors) + case tree: Statement => + transformStatement(tree) + case tree: TypeTree => transformTypeTree(tree) + case tree: TypeBoundsTree => tree // TODO traverse tree + case tree: WildcardTypeTree => tree // TODO traverse tree + case tree: CaseDef => + transformCaseDef(tree) + case tree: TypeCaseDef => + transformTypeCaseDef(tree) + case pattern: Bind => + Bind.copy(pattern)(pattern.name, pattern.pattern) + case pattern: Unapply => + Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns)) + case pattern: Alternatives => + Alternatives.copy(pattern)(transformTrees(pattern.patterns)) + } + } + + def transformStatement(tree: Statement)(given ctx: Context): Statement = { + def localCtx(definition: Definition): Context = definition.symbol.localContext + tree match { + case tree: Term => + transformTerm(tree) + case tree: ValDef => + val ctx = localCtx(tree) + given Context = ctx + val tpt1 = transformTypeTree(tree.tpt) + val rhs1 = tree.rhs.map(x => transformTerm(x)) + ValDef.copy(tree)(tree.name, tpt1, rhs1) + case tree: DefDef => + val ctx = localCtx(tree) + given Context = ctx + DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x))) + case tree: TypeDef => + val ctx = localCtx(tree) + given Context = ctx + TypeDef.copy(tree)(tree.name, transformTree(tree.rhs)) + case tree: ClassDef => + ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body) + case tree: Import => + Import.copy(tree)(transformTerm(tree.expr), tree.selectors) + } + } + + def transformTerm(tree: Term)(given ctx: Context): Term = { + tree match { + case Ident(name) => + tree + case Select(qualifier, name) => + Select.copy(tree)(transformTerm(qualifier), name) + case This(qual) => + tree + case Super(qual, mix) => + Super.copy(tree)(transformTerm(qual), mix) + case Apply(fun, args) => + Apply.copy(tree)(transformTerm(fun), transformTerms(args)) + case TypeApply(fun, args) => + TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args)) + case Literal(const) => + tree + case New(tpt) => + New.copy(tree)(transformTypeTree(tpt)) + case Typed(expr, tpt) => + Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt)) + case tree: NamedArg => + NamedArg.copy(tree)(tree.name, transformTerm(tree.value)) + case Assign(lhs, rhs) => + Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs)) + case Block(stats, expr) => + Block.copy(tree)(transformStats(stats), transformTerm(expr)) + case If(cond, thenp, elsep) => + If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep)) + case Closure(meth, tpt) => + Closure.copy(tree)(transformTerm(meth), tpt) + case Match(selector, cases) => + Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases)) + case Return(expr) => + Return.copy(tree)(transformTerm(expr)) + case While(cond, body) => + While.copy(tree)(transformTerm(cond), transformTerm(body)) + case Try(block, cases, finalizer) => + Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x))) + case Repeated(elems, elemtpt) => + Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt)) + case Inlined(call, bindings, expansion) => + Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/) + } + } + + def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree match { + case Inferred() => tree + case tree: TypeIdent => tree + case tree: TypeSelect => + TypeSelect.copy(tree)(tree.qualifier, tree.name) + case tree: Projection => + Projection.copy(tree)(tree.qualifier, tree.name) + case tree: Annotated => + Annotated.copy(tree)(tree.arg, tree.annotation) + case tree: Singleton => + Singleton.copy(tree)(transformTerm(tree.ref)) + case tree: Refined => + Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]]) + case tree: Applied => + Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args)) + case tree: MatchTypeTree => + MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases)) + case tree: ByName => + ByName.copy(tree)(transformTypeTree(tree.result)) + case tree: LambdaTypeTree => + LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))(given tree.symbol.localContext) + case tree: TypeBind => + TypeBind.copy(tree)(tree.name, tree.body) + case tree: TypeBlock => + TypeBlock.copy(tree)(tree.aliases, tree.tpt) + } + + def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef = { + CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs)) + } + + def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = { + TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs)) + } + + def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] = + trees mapConserve (transformStatement(_)) + + def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] = + trees mapConserve (transformTree(_)) + + def transformTerms(trees: List[Term])(given ctx: Context): List[Term] = + trees mapConserve (transformTerm(_)) + + def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] = + trees mapConserve (transformTypeTree(_)) + + def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] = + trees mapConserve (transformCaseDef(_)) + + def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] = + trees mapConserve (transformTypeCaseDef(_)) + + def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] = + transformTrees(trees).asInstanceOf[List[Tr]] + +} diff --git a/library/src/scala/tasty/reflect/TreeTraverser.scala b/library/src/scala/tasty/reflect/TreeTraverser.scala new file mode 100644 index 000000000000..d5604ffafd8e --- /dev/null +++ b/library/src/scala/tasty/reflect/TreeTraverser.scala @@ -0,0 +1,25 @@ +package scala.tasty +package reflect + +/** TASTy Reflect tree traverser. + * + * Usage: + * ``` + * class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R) + * extends scala.tasty.reflect.TreeTraverser { + * import reflect.{given, _} + * override def traverseTree(tree: Tree)(using ctx: Context): Unit = ... + * } + * ``` + */ +trait TreeTraverser extends TreeAccumulator[Unit] { + + import reflect._ + + def traverseTree(tree: Tree)(given ctx: Context): Unit = traverseTreeChildren(tree) + + def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit = traverseTree(tree) + + protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit = foldOverTree((), tree) + +} diff --git a/tests/run-custom-args/Yretain-trees/tasty-extractors-owners/quoted_1.scala b/tests/run-custom-args/Yretain-trees/tasty-extractors-owners/quoted_1.scala index a9dae76a6511..6951e5e71e90 100644 --- a/tests/run-custom-args/Yretain-trees/tasty-extractors-owners/quoted_1.scala +++ b/tests/run-custom-args/Yretain-trees/tasty-extractors-owners/quoted_1.scala @@ -11,28 +11,31 @@ object Macros { val buff = new StringBuilder - val output = new TreeTraverser { - override def traverseTree(tree: Tree)(implicit ctx: Context): Unit = { - tree match { - case tree @ DefDef(name, _, _, _, _) => - buff.append(name) - buff.append("\n") - buff.append(tree.symbol.owner.tree.showExtractors) - buff.append("\n\n") - case tree @ ValDef(name, _, _) => - buff.append(name) - buff.append("\n") - buff.append(tree.symbol.owner.tree.showExtractors) - buff.append("\n\n") - case _ => - } - traverseTreeChildren(tree) - } - } + val output = new MyTraverser(qctx.tasty)(buff) val tree = x.unseal output.traverseTree(tree) '{print(${buff.result()})} } + class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)(buff: StringBuilder) extends scala.tasty.reflect.TreeTraverser { + import reflect.{given, _} + override def traverseTree(tree: Tree)(implicit ctx: Context): Unit = { + tree match { + case tree @ DefDef(name, _, _, _, _) => + buff.append(name) + buff.append("\n") + buff.append(tree.symbol.owner.tree.showExtractors) + buff.append("\n\n") + case tree @ ValDef(name, _, _) => + buff.append(name) + buff.append("\n") + buff.append(tree.symbol.owner.tree.showExtractors) + buff.append("\n\n") + case _ => + } + traverseTreeChildren(tree) + } + } + }