Skip to content

Commit bf01a4d

Browse files
committed
Extract Tree utils from Reflection
1 parent 957a80f commit bf01a4d

File tree

12 files changed

+352
-289
lines changed

12 files changed

+352
-289
lines changed

docs/docs/reference/metaprogramming/tasty-reflect.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def macroImpl()(qctx: QuoteContext): Expr[Unit] = {
119119

120120
### Tree Utilities
121121

122-
`scala.tasty.reflect.TreeUtils` contains three facilities for tree traversal and
122+
`scala.tasty.reflect` contains three facilities for tree traversal and
123123
transformations.
124124

125125
`TreeAccumulator` ties the knot of a traversal. By calling `foldOver(x, tree))`
@@ -144,7 +144,7 @@ but without returning any value. Finally a `TreeMap` performs a transformation.
144144

145145
#### Let
146146

147-
`scala.tasty.reflect.utils.TreeUtils` also offers a method `let` that allows us
147+
`scala.tasty.Reflection` also offers a method `let` that allows us
148148
to bind the `rhs` to a `val` and use it in `body`. Additionally, `lets` binds
149149
the given `terms` to names and use them in the `body`. Their type definitions
150150
are shown below:

library/src/scala/internal/quoted/Matcher.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ private[quoted] object Matcher {
156156
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole =>
157157
def bodyFn(lambdaArgs: List[Tree]): Tree = {
158158
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap
159-
new TreeMap {
159+
new scala.tasty.reflect.TreeMap {
160+
val reflect: qctx.tasty.type = qctx.tasty
161+
import reflect.{given, _}
160162
override def transformTerm(tree: Term)(given ctx: Context): Term =
161163
tree match
162164
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
@@ -317,8 +319,10 @@ private[quoted] object Matcher {
317319
if freePatternVars(term).isEmpty then Some(term) else None
318320

319321
/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
320-
def freePatternVars(term: Term)(given qctx: Context, env: Env): Set[Symbol] =
321-
val accumulator = new TreeAccumulator[Set[Symbol]] {
322+
def freePatternVars(term: Term)(given ctx: Context, env: Env): Set[Symbol] =
323+
val accumulator = new scala.tasty.reflect.TreeAccumulator[Set[Symbol]] {
324+
val reflect: qctx.tasty.type = qctx.tasty
325+
import reflect.{given, _}
322326
def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] =
323327
tree match
324328
case tree: Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)

library/src/scala/quoted/unsafe/UnsafeExpr.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ object UnsafeExpr {
6565
private def bodyFn[t](given qctx: QuoteContext)(e: qctx.tasty.Term, params: List[qctx.tasty.ValDef], args: List[qctx.tasty.Term]): qctx.tasty.Term = {
6666
import qctx.tasty.{given, _}
6767
val map = params.map(_.symbol).zip(args).toMap
68-
new TreeMap {
68+
new scala.tasty.reflect.TreeMap {
69+
val reflect: qctx.tasty.type = qctx.tasty
6970
override def transformTerm(tree: Term)(given ctx: Context): Term =
7071
super.transformTerm(tree) match
7172
case tree: Ident => map.getOrElse(tree.symbol, tree)

library/src/scala/tasty/Reflection.scala

Lines changed: 1 addition & 259 deletions
Original file line numberDiff line numberDiff line change
@@ -2729,265 +2729,7 @@ class Reflection(private[scala] val internal: CompilerInterface) { self =>
27292729
// UTILS //
27302730
///////////////
27312731

2732-
abstract class TreeAccumulator[X] {
2733-
2734-
// Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
2735-
def foldTree(x: X, tree: Tree)(given ctx: Context): X
2736-
2737-
def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X = trees.foldLeft(x)(foldTree)
2738-
2739-
def foldOverTree(x: X, tree: Tree)(given ctx: Context): X = {
2740-
def localCtx(definition: Definition): Context = definition.symbol.localContext
2741-
tree match {
2742-
case Ident(_) =>
2743-
x
2744-
case Select(qualifier, _) =>
2745-
foldTree(x, qualifier)
2746-
case This(qual) =>
2747-
x
2748-
case Super(qual, _) =>
2749-
foldTree(x, qual)
2750-
case Apply(fun, args) =>
2751-
foldTrees(foldTree(x, fun), args)
2752-
case TypeApply(fun, args) =>
2753-
foldTrees(foldTree(x, fun), args)
2754-
case Literal(const) =>
2755-
x
2756-
case New(tpt) =>
2757-
foldTree(x, tpt)
2758-
case Typed(expr, tpt) =>
2759-
foldTree(foldTree(x, expr), tpt)
2760-
case NamedArg(_, arg) =>
2761-
foldTree(x, arg)
2762-
case Assign(lhs, rhs) =>
2763-
foldTree(foldTree(x, lhs), rhs)
2764-
case Block(stats, expr) =>
2765-
foldTree(foldTrees(x, stats), expr)
2766-
case If(cond, thenp, elsep) =>
2767-
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
2768-
case While(cond, body) =>
2769-
foldTree(foldTree(x, cond), body)
2770-
case Closure(meth, tpt) =>
2771-
foldTree(x, meth)
2772-
case Match(selector, cases) =>
2773-
foldTrees(foldTree(x, selector), cases)
2774-
case Return(expr) =>
2775-
foldTree(x, expr)
2776-
case Try(block, handler, finalizer) =>
2777-
foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
2778-
case Repeated(elems, elemtpt) =>
2779-
foldTrees(foldTree(x, elemtpt), elems)
2780-
case Inlined(call, bindings, expansion) =>
2781-
foldTree(foldTrees(x, bindings), expansion)
2782-
case vdef @ ValDef(_, tpt, rhs) =>
2783-
val ctx = localCtx(vdef)
2784-
given Context = ctx
2785-
foldTrees(foldTree(x, tpt), rhs)
2786-
case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) =>
2787-
val ctx = localCtx(ddef)
2788-
given Context = ctx
2789-
foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
2790-
case tdef @ TypeDef(_, rhs) =>
2791-
val ctx = localCtx(tdef)
2792-
given Context = ctx
2793-
foldTree(x, rhs)
2794-
case cdef @ ClassDef(_, constr, parents, derived, self, body) =>
2795-
val ctx = localCtx(cdef)
2796-
given Context = ctx
2797-
foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
2798-
case Import(expr, _) =>
2799-
foldTree(x, expr)
2800-
case clause @ PackageClause(pid, stats) =>
2801-
foldTrees(foldTree(x, pid), stats)(given clause.symbol.localContext)
2802-
case Inferred() => x
2803-
case TypeIdent(_) => x
2804-
case TypeSelect(qualifier, _) => foldTree(x, qualifier)
2805-
case Projection(qualifier, _) => foldTree(x, qualifier)
2806-
case Singleton(ref) => foldTree(x, ref)
2807-
case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
2808-
case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args)
2809-
case ByName(result) => foldTree(x, result)
2810-
case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot)
2811-
case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
2812-
case TypeBind(_, tbt) => foldTree(x, tbt)
2813-
case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
2814-
case MatchTypeTree(boundopt, selector, cases) =>
2815-
foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
2816-
case WildcardTypeTree() => x
2817-
case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi)
2818-
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
2819-
case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body)
2820-
case Bind(_, body) => foldTree(x, body)
2821-
case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
2822-
case Alternatives(patterns) => foldTrees(x, patterns)
2823-
}
2824-
}
2825-
}
2826-
2827-
abstract class TreeTraverser extends TreeAccumulator[Unit] {
2828-
2829-
def traverseTree(tree: Tree)(given ctx: Context): Unit = traverseTreeChildren(tree)
2830-
2831-
def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit = traverseTree(tree)
2832-
2833-
protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit = foldOverTree((), tree)
2834-
2835-
}
2836-
2837-
abstract class TreeMap { self =>
2838-
2839-
def transformTree(tree: Tree)(given ctx: Context): Tree = {
2840-
tree match {
2841-
case tree: PackageClause =>
2842-
PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(given tree.symbol.localContext))
2843-
case tree: Import =>
2844-
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
2845-
case tree: Statement =>
2846-
transformStatement(tree)
2847-
case tree: TypeTree => transformTypeTree(tree)
2848-
case tree: TypeBoundsTree => tree // TODO traverse tree
2849-
case tree: WildcardTypeTree => tree // TODO traverse tree
2850-
case tree: CaseDef =>
2851-
transformCaseDef(tree)
2852-
case tree: TypeCaseDef =>
2853-
transformTypeCaseDef(tree)
2854-
case pattern: Bind =>
2855-
Bind.copy(pattern)(pattern.name, pattern.pattern)
2856-
case pattern: Unapply =>
2857-
Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
2858-
case pattern: Alternatives =>
2859-
Alternatives.copy(pattern)(transformTrees(pattern.patterns))
2860-
}
2861-
}
2862-
2863-
def transformStatement(tree: Statement)(given ctx: Context): Statement = {
2864-
def localCtx(definition: Definition): Context = definition.symbol.localContext
2865-
tree match {
2866-
case tree: Term =>
2867-
transformTerm(tree)
2868-
case tree: ValDef =>
2869-
val ctx = localCtx(tree)
2870-
given Context = ctx
2871-
val tpt1 = transformTypeTree(tree.tpt)
2872-
val rhs1 = tree.rhs.map(x => transformTerm(x))
2873-
ValDef.copy(tree)(tree.name, tpt1, rhs1)
2874-
case tree: DefDef =>
2875-
val ctx = localCtx(tree)
2876-
given Context = ctx
2877-
DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
2878-
case tree: TypeDef =>
2879-
val ctx = localCtx(tree)
2880-
given Context = ctx
2881-
TypeDef.copy(tree)(tree.name, transformTree(tree.rhs))
2882-
case tree: ClassDef =>
2883-
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
2884-
case tree: Import =>
2885-
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
2886-
}
2887-
}
2888-
2889-
def transformTerm(tree: Term)(given ctx: Context): Term = {
2890-
tree match {
2891-
case Ident(name) =>
2892-
tree
2893-
case Select(qualifier, name) =>
2894-
Select.copy(tree)(transformTerm(qualifier), name)
2895-
case This(qual) =>
2896-
tree
2897-
case Super(qual, mix) =>
2898-
Super.copy(tree)(transformTerm(qual), mix)
2899-
case Apply(fun, args) =>
2900-
Apply.copy(tree)(transformTerm(fun), transformTerms(args))
2901-
case TypeApply(fun, args) =>
2902-
TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args))
2903-
case Literal(const) =>
2904-
tree
2905-
case New(tpt) =>
2906-
New.copy(tree)(transformTypeTree(tpt))
2907-
case Typed(expr, tpt) =>
2908-
Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt))
2909-
case tree: NamedArg =>
2910-
NamedArg.copy(tree)(tree.name, transformTerm(tree.value))
2911-
case Assign(lhs, rhs) =>
2912-
Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs))
2913-
case Block(stats, expr) =>
2914-
Block.copy(tree)(transformStats(stats), transformTerm(expr))
2915-
case If(cond, thenp, elsep) =>
2916-
If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
2917-
case Closure(meth, tpt) =>
2918-
Closure.copy(tree)(transformTerm(meth), tpt)
2919-
case Match(selector, cases) =>
2920-
Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases))
2921-
case Return(expr) =>
2922-
Return.copy(tree)(transformTerm(expr))
2923-
case While(cond, body) =>
2924-
While.copy(tree)(transformTerm(cond), transformTerm(body))
2925-
case Try(block, cases, finalizer) =>
2926-
Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
2927-
case Repeated(elems, elemtpt) =>
2928-
Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
2929-
case Inlined(call, bindings, expansion) =>
2930-
Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/)
2931-
}
2932-
}
2933-
2934-
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree match {
2935-
case Inferred() => tree
2936-
case tree: TypeIdent => tree
2937-
case tree: TypeSelect =>
2938-
TypeSelect.copy(tree)(tree.qualifier, tree.name)
2939-
case tree: Projection =>
2940-
Projection.copy(tree)(tree.qualifier, tree.name)
2941-
case tree: Annotated =>
2942-
Annotated.copy(tree)(tree.arg, tree.annotation)
2943-
case tree: Singleton =>
2944-
Singleton.copy(tree)(transformTerm(tree.ref))
2945-
case tree: Refined =>
2946-
Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]])
2947-
case tree: Applied =>
2948-
Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
2949-
case tree: MatchTypeTree =>
2950-
MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
2951-
case tree: ByName =>
2952-
ByName.copy(tree)(transformTypeTree(tree.result))
2953-
case tree: LambdaTypeTree =>
2954-
LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))(given tree.symbol.localContext)
2955-
case tree: TypeBind =>
2956-
TypeBind.copy(tree)(tree.name, tree.body)
2957-
case tree: TypeBlock =>
2958-
TypeBlock.copy(tree)(tree.aliases, tree.tpt)
2959-
}
2960-
2961-
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef = {
2962-
CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
2963-
}
2964-
2965-
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = {
2966-
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
2967-
}
2968-
2969-
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
2970-
trees mapConserve (transformStatement(_))
2971-
2972-
def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] =
2973-
trees mapConserve (transformTree(_))
2974-
2975-
def transformTerms(trees: List[Term])(given ctx: Context): List[Term] =
2976-
trees mapConserve (transformTerm(_))
2977-
2978-
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
2979-
trees mapConserve (transformTypeTree(_))
2980-
2981-
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
2982-
trees mapConserve (transformCaseDef(_))
2983-
2984-
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
2985-
trees mapConserve (transformTypeCaseDef(_))
2986-
2987-
def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] =
2988-
transformTrees(trees).asInstanceOf[List[Tr]]
2989-
2990-
}
2732+
// TODO extract from Reflection
29912733

29922734
/** Bind the `rhs` to a `val` and use it in `body` */
29932735
def let(rhs: Term)(body: Ident => Term)(given ctx: Context): Term = {

0 commit comments

Comments
 (0)