Skip to content

Commit 126588e

Browse files
committed
Add quoted.util.ExprMap
1 parent 89c2c36 commit 126588e

File tree

10 files changed

+362
-93
lines changed

10 files changed

+362
-93
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,13 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
17131713

17141714
def Symbol_noSymbol(given ctx: Context): Symbol = core.Symbols.NoSymbol
17151715

1716+
def Symbol_typeRef(symbol: Symbol)(given ctx: Context): TypeOrBounds = symbol.typeRef
1717+
1718+
def Symbol_termRef(symbol: Symbol)(given ctx: Context): TypeOrBounds = symbol.termRef
1719+
1720+
def Symbol_info(symbol: Symbol)(given ctx: Context): TypeOrBounds = symbol.info
1721+
1722+
17161723
//
17171724
// FLAGS
17181725
//
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package scala.quoted.util
2+
3+
import scala.quoted._
4+
5+
trait ExprMap {
6+
7+
/** Map an expression `e` with a type `tpe` */
8+
def map[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T]
9+
10+
/** Map subexpressions an expression `e` with a type `tpe` */
11+
def mapChildren[T](e: Expr[T])(given qctx: QuoteContext, tpe: Type[T]): Expr[T] = {
12+
import qctx.tasty.{_, given}
13+
class MapChildren() {
14+
15+
def transformStatement(tree: Statement)(given ctx: Context): Statement = {
16+
def localCtx(definition: Definition): Context = definition.symbol.localContext
17+
tree match {
18+
case tree: Term =>
19+
transformTerm(tree, defn.AnyType)
20+
case tree: Definition =>
21+
transformDefinition(tree)
22+
case tree: Import =>
23+
tree
24+
}
25+
}
26+
27+
def transformDefinition(tree: Definition)(given ctx: Context): Definition = {
28+
def localCtx(definition: Definition): Context = definition.symbol.localContext
29+
tree match {
30+
case tree: ValDef =>
31+
implicit val ctx = localCtx(tree)
32+
val rhs1 = tree.rhs.map(x => transformTerm(x, tree.tpt.tpe))
33+
ValDef.copy(tree)(tree.name, tree.tpt, rhs1)
34+
case tree: DefDef =>
35+
implicit val ctx = localCtx(tree)
36+
DefDef.copy(tree)(tree.name, tree.typeParams, tree.paramss, tree.returnTpt, tree.rhs.map(x => transformTerm(x, tree.returnTpt.tpe)))
37+
case tree: TypeDef =>
38+
tree
39+
case tree: ClassDef =>
40+
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
41+
}
42+
}
43+
44+
def transformTermChildren(tree: Term, tpe: Type)(given ctx: Context): Term = tree match {
45+
case Ident(name) =>
46+
tree
47+
case Select(qualifier, name) =>
48+
Select.copy(tree)(transformTerm(qualifier, qualifier.tpe.widen), name)
49+
case This(qual) =>
50+
tree
51+
case Super(qual, mix) =>
52+
tree
53+
case tree @ Apply(fun, args) =>
54+
val MethodType(_, tpes, _) = fun.tpe.widen
55+
Apply.copy(tree)(transformTerm(fun, defn.AnyType), transformTerms(args, tpes))
56+
case TypeApply(fun, args) =>
57+
TypeApply.copy(tree)(transformTerm(fun, defn.AnyType), args)
58+
case _: Literal =>
59+
tree
60+
case New(tpt) =>
61+
New.copy(tree)(transformTypeTree(tpt))
62+
case Typed(expr, tpt) =>
63+
val tp = tpt.tpe match
64+
// TODO improve code
65+
case AppliedType(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "<repeated>"), List(tp0: Type)) =>
66+
type T
67+
val a = tp0.seal.asInstanceOf[quoted.Type[T]]
68+
'[Seq[$a]].unseal.tpe
69+
case tp => tp
70+
Typed.copy(tree)(transformTerm(expr, tp), transformTypeTree(tpt))
71+
case tree: NamedArg =>
72+
NamedArg.copy(tree)(tree.name, transformTerm(tree.value, tpe))
73+
case Assign(lhs, rhs) =>
74+
Assign.copy(tree)(lhs, transformTerm(rhs, lhs.tpe.widen))
75+
case Block(stats, expr) =>
76+
Block.copy(tree)(transformStats(stats), transformTerm(expr, tpe))
77+
case If(cond, thenp, elsep) =>
78+
If.copy(tree)(
79+
transformTerm(cond, defn.BooleanType),
80+
transformTerm(thenp, tpe),
81+
transformTerm(elsep, tpe))
82+
case _: Closure =>
83+
tree
84+
case Match(selector, cases) =>
85+
Match.copy(tree)(transformTerm(selector, selector.tpe), transformCaseDefs(cases, tpe))
86+
case Return(expr) =>
87+
// FIXME
88+
// ctx.owner seems to be set to the wrong symbol
89+
// Return.copy(tree)(transformTerm(expr, expr.tpe))
90+
tree
91+
case While(cond, body) =>
92+
While.copy(tree)(transformTerm(cond, defn.BooleanType), transformTerm(body, defn.AnyType))
93+
case Try(block, cases, finalizer) =>
94+
Try.copy(tree)(transformTerm(block, tpe), transformCaseDefs(cases, defn.AnyType), finalizer.map(x => transformTerm(x, defn.AnyType)))
95+
case Repeated(elems, elemtpt) =>
96+
Repeated.copy(tree)(transformTerms(elems, elemtpt.tpe), elemtpt)
97+
case Inlined(call, bindings, expansion) =>
98+
Inlined.copy(tree)(call, transformDefinitions(bindings), transformTerm(expansion, tpe)/*()call.symbol.localContext)*/)
99+
}
100+
101+
def transformTerm(tree: Term, tpe: Type)(given ctx: Context): Term =
102+
tree match {
103+
case _: Closure =>
104+
tree
105+
case _: Inlined | _: Select =>
106+
transformTermChildren(tree, tpe)
107+
case _ =>
108+
tree.tpe.widen match {
109+
case _: MethodType | _: PolyType =>
110+
transformTermChildren(tree, tpe)
111+
case _ =>
112+
type X
113+
val expr = tree.seal.asInstanceOf[Expr[X]]
114+
val t = tpe.seal.asInstanceOf[quoted.Type[X]]
115+
map(expr)(given qctx, t).unseal
116+
}
117+
}
118+
119+
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree
120+
121+
def transformCaseDef(tree: CaseDef, tpe: Type)(given ctx: Context): CaseDef =
122+
CaseDef.copy(tree)(tree.pattern, tree.guard.map(x => transformTerm(x, defn.BooleanType)), transformTerm(tree.rhs, tpe))
123+
124+
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = {
125+
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
126+
}
127+
128+
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
129+
trees mapConserve (transformStatement(_))
130+
131+
def transformDefinitions(trees: List[Definition])(given ctx: Context): List[Definition] =
132+
trees mapConserve (transformDefinition(_))
133+
134+
def transformTerms(trees: List[Term], tpes: List[Type])(given ctx: Context): List[Term] =
135+
var tpes2 = tpes // TODO use proper zipConserve
136+
trees mapConserve { x =>
137+
val tpe :: tail = tpes2
138+
tpes2 = tail
139+
transformTerm(x, tpe)
140+
}
141+
142+
def transformTerms(trees: List[Term], tpe: Type)(given ctx: Context): List[Term] =
143+
trees.mapConserve(x => transformTerm(x, tpe))
144+
145+
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
146+
trees mapConserve (transformTypeTree(_))
147+
148+
def transformCaseDefs(trees: List[CaseDef], tpe: Type)(given ctx: Context): List[CaseDef] =
149+
trees mapConserve (x => transformCaseDef(x, tpe))
150+
151+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
152+
trees mapConserve (transformTypeCaseDef(_))
153+
154+
}
155+
new MapChildren().transformTermChildren(e.unseal, tpe.unseal.tpe).seal.cast[T] // Cast will only fail if this implementation has a bug
156+
}
157+
158+
}

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,12 @@ trait CompilerInterface {
12551255

12561256
def Symbol_noSymbol(given ctx: Context): Symbol
12571257

1258+
def Symbol_typeRef(symbol: Symbol)(given ctx: Context): TypeOrBounds
1259+
1260+
def Symbol_termRef(symbol: Symbol)(given ctx: Context): TypeOrBounds
1261+
1262+
def Symbol_info(symbol: Symbol)(given ctx: Context): TypeOrBounds
1263+
12581264
//
12591265
// FLAGS
12601266
//

library/src/scala/tasty/reflect/SymbolOps.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,16 @@ trait SymbolOps extends Core { selfSymbolOps: FlagsOps =>
143143
/** The symbol of the companion module */
144144
def companionModule(given ctx: Context): Symbol =
145145
internal.Symbol_companionModule(self)
146+
147+
def typeRef(given ctx: Context): TypeOrBounds =
148+
internal.Symbol_typeRef(self)
149+
150+
def termRef(given ctx: Context): TypeOrBounds =
151+
internal.Symbol_termRef(self)
152+
153+
def info(given ctx: Context): TypeOrBounds =
154+
internal.Symbol_info(self)
155+
146156
}
147157

148158
}

tests/run-macros/expr-map-1.check

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
oof
2+
oofoof
3+
ylppa
4+
kcolb
5+
kcolb
6+
neht
7+
esle
8+
lav
9+
vals
10+
fed
11+
defs
12+
fed
13+
rab
14+
yrt
15+
yllanif
16+
hctac
17+
elihw
18+
wen
19+
depyt
20+
depyt
21+
grAdeman
22+
qual
23+
adbmal
24+
ravsgra
25+
hctam
26+
def
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def rewrite[T](x: => Any): Any = ${ stringRewriter('x) }
5+
6+
private def stringRewriter(e: Expr[Any])(given QuoteContext): Expr[Any] =
7+
StringRewriter.map(e)
8+
9+
private object StringRewriter extends util.ExprMap {
10+
11+
def map[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = e match
12+
case Const(s: String) =>
13+
Expr(s.reverse) match
14+
case '{ $x: T } => x
15+
case _ => e // e had a singlton String type
16+
case _ => mapChildren(e)
17+
18+
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
println(rewrite("foo"))
5+
println(rewrite("foo" + "foo"))
6+
7+
rewrite {
8+
println("apply")
9+
}
10+
11+
rewrite {
12+
println("block")
13+
println("block")
14+
}
15+
16+
val b: Boolean = true
17+
rewrite {
18+
if b then println("then")
19+
else println("else")
20+
}
21+
22+
rewrite {
23+
if !b then println("then")
24+
else println("else")
25+
}
26+
27+
rewrite {
28+
val s: String = "val"
29+
println(s)
30+
}
31+
32+
rewrite {
33+
val s: "vals" = "vals"
34+
println(s) // prints "foo" not "oof"
35+
}
36+
37+
rewrite {
38+
def s: String = "def"
39+
println(s)
40+
}
41+
42+
rewrite {
43+
def s: "defs" = "defs"
44+
println(s) // prints "foo" not "oof"
45+
}
46+
47+
rewrite {
48+
def s(x: String): String = x
49+
println(s("def"))
50+
}
51+
52+
rewrite {
53+
var s: String = "var"
54+
s = "bar"
55+
println(s)
56+
}
57+
58+
rewrite {
59+
try println("try")
60+
finally println("finally")
61+
}
62+
63+
rewrite {
64+
try throw new Exception()
65+
catch case x: Exception => println("catch")
66+
}
67+
68+
rewrite {
69+
var x = true
70+
while (x) {
71+
println("while")
72+
x = false
73+
}
74+
}
75+
76+
rewrite {
77+
val t = new Tuple1("new")
78+
println(t._1)
79+
}
80+
81+
rewrite {
82+
println("typed": String)
83+
println("typed": Any)
84+
}
85+
86+
rewrite {
87+
val f = new Foo(foo = "namedArg")
88+
println(f.foo)
89+
}
90+
91+
rewrite {
92+
println("qual".reverse)
93+
}
94+
95+
rewrite {
96+
val f = () => "lambda"
97+
println(f())
98+
}
99+
100+
rewrite {
101+
def f(args: String*): String = args.mkString
102+
println(f("var", "args"))
103+
}
104+
105+
rewrite {
106+
"match" match {
107+
case "match" => println("match")
108+
case x => println("x")
109+
}
110+
}
111+
112+
// FIXME should print fed
113+
rewrite {
114+
def s: String = return "def"
115+
println(s)
116+
}
117+
118+
}
119+
120+
}
121+
122+
class Foo(val foo: String)

0 commit comments

Comments
 (0)