Skip to content

Commit db76e1e

Browse files
committed
WIP Add quoted.util.ExprMap
1 parent 5e47774 commit db76e1e

File tree

7 files changed

+187
-93
lines changed

7 files changed

+187
-93
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,13 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
17131713

17141714
def Symbol_noSymbol(given ctx: Context): Symbol = core.Symbols.NoSymbol
17151715

1716+
def Symbol_typeRef(symbol: Symbol)(given ctx: Context): TypeOrBounds = symbol.typeRef
1717+
1718+
def Symbol_termRef(symbol: Symbol)(given ctx: Context): TypeOrBounds = symbol.termRef
1719+
1720+
def Symbol_info(symbol: Symbol)(given ctx: Context): TypeOrBounds = symbol.info
1721+
1722+
17161723
//
17171724
// FLAGS
17181725
//
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package scala.quoted.util
2+
3+
import scala.quoted._
4+
5+
trait ExprMap {
6+
7+
/** Map an expression `e` with a type `tpe` */
8+
def map[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T]
9+
10+
/** Map subexpressions an expression `e` with a type `tpe` */
11+
def mapChildren[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T] = {
12+
import qctx.tasty.{_, given}
13+
class MapChildren() {
14+
15+
def transformTree(tree: Tree, tpe: Type)(given ctx: Context): Tree = {
16+
tree match {
17+
case IsStatement(tree) =>
18+
transformStatement(tree)
19+
case IsCaseDef(tree) =>
20+
transformCaseDef(tree)
21+
case _ => tree
22+
}
23+
}
24+
25+
def transformStatement(tree: Statement)(given ctx: Context): Statement = {
26+
def localCtx(definition: Definition): Context = definition.symbol.localContext
27+
tree match {
28+
case IsTerm(tree) =>
29+
transformTerm(tree, defn.AnyType)
30+
case IsValDef(tree) =>
31+
implicit val ctx = localCtx(tree)
32+
val rhs1 = tree.rhs.map(x => transformTerm(x, tree.tpt.tpe))
33+
ValDef.copy(tree)(tree.name, tree.tpt, rhs1)
34+
case IsDefDef(tree) =>
35+
implicit val ctx = localCtx(tree)
36+
DefDef.copy(tree)(tree.name, tree.typeParams, tree.paramss, tree.returnTpt, tree.rhs.map(x => transformTerm(x, tree.returnTpt.tpe)))
37+
case IsTypeDef(tree) =>
38+
tree
39+
case IsClassDef(tree) =>
40+
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
41+
case IsImport(tree) =>
42+
tree
43+
}
44+
}
45+
46+
def transformTermChildren(tree: Term, tpe: Type)(given ctx: Context): Term = tree match {
47+
case Ident(name) =>
48+
tree
49+
case Select(qualifier, name) =>
50+
val IsType(qualTpe) = tree.symbol.owner.typeRef
51+
Select.copy(tree)(transformTerm(qualifier, qualTpe), name)
52+
case This(qual) =>
53+
tree
54+
case Super(qual, mix) =>
55+
// Super.copy(tree)(transformTerm(qual, ???), mix)
56+
tree
57+
case tree @ Apply(fun, args) =>
58+
val MethodType(_, tpes, _) = fun.tpe.widen
59+
val a = defn.AnyType // FIXME
60+
Apply.copy(tree)(transformTerm(fun, a), transformTerms(args, tpes))
61+
case TypeApply(fun, args) =>
62+
val a = defn.AnyType // FIXME
63+
TypeApply.copy(tree)(transformTerm(fun, a), args)
64+
case Literal(const) =>
65+
tree
66+
case New(tpt) =>
67+
New.copy(tree)(transformTypeTree(tpt))
68+
case Typed(expr, tpt) =>
69+
Typed.copy(tree)(transformTerm(expr, tpt.tpe), transformTypeTree(tpt))
70+
case IsNamedArg(tree) =>
71+
NamedArg.copy(tree)(tree.name, transformTerm(tree.value, ???))
72+
case Assign(lhs, rhs) =>
73+
Assign.copy(tree)(transformTerm(lhs, ???), transformTerm(rhs, ???))
74+
case Block(stats, expr) =>
75+
Block.copy(tree)(transformStats(stats), transformTerm(expr, tpe))
76+
case If(cond, thenp, elsep) =>
77+
If.copy(tree)(transformTerm(cond, ???), transformTerm(thenp, ???), transformTerm(elsep, ???))
78+
case IsClosure(_) =>
79+
tree
80+
case Match(selector, cases) =>
81+
Match.copy(tree)(transformTerm(selector, ???), transformCaseDefs(cases))
82+
case Return(expr) =>
83+
Return.copy(tree)(transformTerm(expr, ???))
84+
case While(cond, body) =>
85+
While.copy(tree)(transformTerm(cond, ???), transformTerm(body, ???))
86+
case Try(block, cases, finalizer) =>
87+
Try.copy(tree)(transformTerm(block, tpe), transformCaseDefs(cases), finalizer.map(x => transformTerm(x, defn.AnyType)))
88+
case Repeated(elems, elemtpt) =>
89+
Repeated.copy(tree)(transformTerms(elems, elemtpt.tpe), elemtpt)
90+
case Inlined(call, bindings, expansion) =>
91+
Inlined.copy(tree)(call, transformSubTrees(bindings, defn.AnyType), transformTerm(expansion, tpe)/*()call.symbol.localContext)*/)
92+
}
93+
94+
def transformTerm(tree: Term, tpe: Type)(given ctx: Context): Term = tree match {
95+
case IsClosure(_) =>
96+
tree
97+
case IsInlined(_) | IsSelect(_) =>
98+
transformTermChildren(tree, tpe)
99+
case _ =>
100+
tree.tpe.widen match {
101+
case IsMethodType(_) | IsPolyType(_) =>
102+
transformTermChildren(tree, tpe)
103+
case _ =>
104+
tree.seal match {
105+
case '{ $x: $t } => map(x).unseal
106+
case _ => ???
107+
}
108+
}
109+
}
110+
111+
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree
112+
113+
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef =
114+
CaseDef.copy(tree)(tree.pattern, tree.guard.map(x => transformTerm(x, ???)), transformTerm(tree.rhs, ???))
115+
116+
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = {
117+
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
118+
}
119+
120+
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
121+
trees mapConserve (transformStatement(_))
122+
123+
def transformTrees(trees: List[Tree], tpe: Type)(given ctx: Context): List[Tree] =
124+
trees mapConserve (x => transformTree(x, tpe))
125+
126+
def transformTerms(trees: List[Term], tpes: List[Type])(given ctx: Context): List[Term] =
127+
val a = trees.zip(tpes).map {case (x, tpe) => transformTerm(x, tpe) } // TODO zipConserve
128+
if a == trees then trees else a
129+
130+
def transformTerms(trees: List[Term], tpe: Type)(given ctx: Context): List[Term] =
131+
trees.mapConserve(x => transformTerm(x, tpe))
132+
133+
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
134+
trees mapConserve (transformTypeTree(_))
135+
136+
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
137+
trees mapConserve (transformCaseDef(_))
138+
139+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
140+
trees mapConserve (transformTypeCaseDef(_))
141+
142+
def transformSubTrees[Tr <: Tree](trees: List[Tr], tpe: Type)(given ctx: Context): List[Tr] =
143+
transformTrees(trees, tpe).asInstanceOf[List[Tr]]
144+
145+
}
146+
new MapChildren().transformTermChildren(e.unseal, tpe.unseal.tpe).seal.cast[T] // Cast will only fail if this implementation has a bug
147+
}
148+
149+
}

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,12 @@ trait CompilerInterface {
12551255

12561256
def Symbol_noSymbol(given ctx: Context): Symbol
12571257

1258+
def Symbol_typeRef(symbol: Symbol)(given ctx: Context): TypeOrBounds
1259+
1260+
def Symbol_termRef(symbol: Symbol)(given ctx: Context): TypeOrBounds
1261+
1262+
def Symbol_info(symbol: Symbol)(given ctx: Context): TypeOrBounds
1263+
12581264
//
12591265
// FLAGS
12601266
//

library/src/scala/tasty/reflect/SymbolOps.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,16 @@ trait SymbolOps extends Core { selfSymbolOps: FlagsOps =>
143143
/** The symbol of the companion module */
144144
def companionModule(given ctx: Context): Symbol =
145145
internal.Symbol_companionModule(self)
146+
147+
def typeRef(given ctx: Context): TypeOrBounds =
148+
internal.Symbol_typeRef(self)
149+
150+
def termRef(given ctx: Context): TypeOrBounds =
151+
internal.Symbol_termRef(self)
152+
153+
def info(given ctx: Context): TypeOrBounds =
154+
internal.Symbol_info(self)
155+
146156
}
147157

148158
}

tests/run-macros/flops-rewrite-2/Macro_1.scala

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = {
3535
fixPoint = true
3636
)
3737

38-
val x2 = rewriter.rewrite(x)
38+
val x2 = rewriter.map(x)
3939

4040
'{
4141
println(${Expr(x.show)})
@@ -63,39 +63,13 @@ private object Rewriter {
6363
new Rewriter(preTransform, postTransform, fixPoint)
6464
}
6565

66-
private class Rewriter(preTransform: List[Transformation[_]] = Nil, postTransform: List[Transformation[_]] = Nil, fixPoint: Boolean) {
67-
def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
66+
private class Rewriter(preTransform: List[Transformation[_]] = Nil, postTransform: List[Transformation[_]] = Nil, fixPoint: Boolean) extends util.ExprMap {
67+
def map[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
6868
val e2 = preTransform.foldLeft(e)((ei, transform) => transform(ei))
69-
val e3 = rewriteChildren(e2)
69+
val e3 = mapChildren(e2)
7070
val e4 = postTransform.foldLeft(e3)((ei, transform) => transform(ei))
71-
if fixPoint && e4 != e then rewrite(e4)
71+
if fixPoint && e4 != e then map(e4)
7272
else e4
7373
}
7474

75-
def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = {
76-
import qctx.tasty.{_, given}
77-
class MapChildren extends TreeMap {
78-
override def transformTerm(tree: Term)(given ctx: Context): Term = tree match {
79-
case IsClosure(_) =>
80-
tree
81-
case IsInlined(_) | IsSelect(_) =>
82-
transformChildrenTerm(tree)
83-
case _ =>
84-
tree.tpe.widen match {
85-
case IsMethodType(_) | IsPolyType(_) =>
86-
transformChildrenTerm(tree)
87-
case _ =>
88-
tree.seal match {
89-
case '{ $x: $t } => rewrite(x).unseal
90-
}
91-
}
92-
}
93-
def transformChildrenTerm(tree: Term)(given ctx: Context): Term =
94-
super.transformTerm(tree)
95-
}
96-
(new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug
97-
}
98-
9975
}
100-
101-

tests/run-macros/flops-rewrite-3/Macro_1.scala

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = {
3333
}
3434
)
3535

36-
val x2 = rewriter.rewrite(x)
36+
val x2 = rewriter.map(x)
3737

3838
'{
3939
println(${Expr(x.show)})
@@ -92,7 +92,7 @@ private object Rewriter {
9292
def apply(): Rewriter = new Rewriter(Nil, Nil, false)
9393
}
9494

95-
private class Rewriter private (preTransform: List[Transformation] = Nil, postTransform: List[Transformation] = Nil, fixPoint: Boolean) {
95+
private class Rewriter private (preTransform: List[Transformation] = Nil, postTransform: List[Transformation] = Nil, fixPoint: Boolean) extends util.ExprMap {
9696

9797
def withFixPoint: Rewriter =
9898
new Rewriter(preTransform, postTransform, fixPoint = true)
@@ -101,37 +101,11 @@ private class Rewriter private (preTransform: List[Transformation] = Nil, postTr
101101
def withPost(transform: Transformation): Rewriter =
102102
new Rewriter(preTransform, transform :: postTransform, fixPoint)
103103

104-
def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
104+
def map[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
105105
val e2 = preTransform.foldLeft(e)((ei, transform) => transform(ei))
106-
val e3 = rewriteChildren(e2)
106+
val e3 = mapChildren(e2)
107107
val e4 = postTransform.foldLeft(e3)((ei, transform) => transform(ei))
108-
if fixPoint && e4 != e then rewrite(e4) else e4
109-
}
110-
111-
def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = {
112-
import qctx.tasty.{_, given}
113-
class MapChildren extends TreeMap {
114-
override def transformTerm(tree: Term)(given ctx: Context): Term = tree match {
115-
case IsClosure(_) =>
116-
tree
117-
case IsInlined(_) | IsSelect(_) =>
118-
transformChildrenTerm(tree)
119-
case _ =>
120-
tree.tpe.widen match {
121-
case IsMethodType(_) | IsPolyType(_) =>
122-
transformChildrenTerm(tree)
123-
case _ =>
124-
tree.seal match {
125-
case '{ $x: $t } => rewrite(x).unseal
126-
}
127-
}
128-
}
129-
def transformChildrenTerm(tree: Term)(given ctx: Context): Term =
130-
super.transformTerm(tree)
131-
}
132-
(new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug
108+
if fixPoint && e4 != e then map(e4) else e4
133109
}
134110

135111
}
136-
137-

tests/run-macros/flops-rewrite/Macro_1.scala

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = {
1313
}
1414
)
1515

16-
val x2 = rewriter.rewrite(x)
16+
val x2 = rewriter.map(x)
1717

1818
'{
1919
println(${Expr(x.show)})
@@ -28,12 +28,12 @@ private object Rewriter {
2828
new Rewriter(preTransform, postTransform, fixPoint)
2929
}
3030

31-
private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr[Any] => Expr[Any], fixPoint: Boolean) {
32-
def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
31+
private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr[Any] => Expr[Any], fixPoint: Boolean) extends util.ExprMap {
32+
def map[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
3333
val e2 = checkedTransform(e, preTransform)
34-
val e3 = rewriteChildren(e2)
34+
val e3 = mapChildren(e2)
3535
val e4 = checkedTransform(e3, postTransform)
36-
if fixPoint && e4 != e then rewrite(e4)
36+
if fixPoint && e4 != e then map(e4)
3737
else e4
3838
}
3939

@@ -54,30 +54,4 @@ private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr
5454
}
5555
}
5656

57-
def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = {
58-
import qctx.tasty.{_, given}
59-
class MapChildren extends TreeMap {
60-
override def transformTerm(tree: Term)(given ctx: Context): Term = tree match {
61-
case IsClosure(_) =>
62-
tree
63-
case IsInlined(_) | IsSelect(_) =>
64-
transformChildrenTerm(tree)
65-
case _ =>
66-
tree.tpe.widen match {
67-
case IsMethodType(_) | IsPolyType(_) =>
68-
transformChildrenTerm(tree)
69-
case _ =>
70-
tree.seal match {
71-
case '{ $x: $t } => rewrite(x).unseal
72-
}
73-
}
74-
}
75-
def transformChildrenTerm(tree: Term)(given ctx: Context): Term =
76-
super.transformTerm(tree)
77-
}
78-
(new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug
79-
}
80-
8157
}
82-
83-

0 commit comments

Comments
 (0)