Skip to content

Commit 7dfb647

Browse files
committed
wip
1 parent db76e1e commit 7dfb647

File tree

4 files changed

+212
-45
lines changed

4 files changed

+212
-45
lines changed

library/src-bootstrapped/scala/quoted/util/ExprMap.scala

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,47 +12,43 @@ trait ExprMap {
1212
import qctx.tasty.{_, given}
1313
class MapChildren() {
1414

15-
def transformTree(tree: Tree, tpe: Type)(given ctx: Context): Tree = {
15+
def transformStatement(tree: Statement)(given ctx: Context): Statement = {
16+
def localCtx(definition: Definition): Context = definition.symbol.localContext
1617
tree match {
17-
case IsStatement(tree) =>
18-
transformStatement(tree)
19-
case IsCaseDef(tree) =>
20-
transformCaseDef(tree)
21-
case _ => tree
18+
case tree: Term =>
19+
transformTerm(tree, defn.AnyType)
20+
case tree: Definition =>
21+
transformDefinition(tree)
22+
case tree: Import =>
23+
tree
2224
}
2325
}
2426

25-
def transformStatement(tree: Statement)(given ctx: Context): Statement = {
27+
def transformDefinition(tree: Definition)(given ctx: Context): Definition = {
2628
def localCtx(definition: Definition): Context = definition.symbol.localContext
2729
tree match {
28-
case IsTerm(tree) =>
29-
transformTerm(tree, defn.AnyType)
30-
case IsValDef(tree) =>
30+
case tree: ValDef =>
3131
implicit val ctx = localCtx(tree)
3232
val rhs1 = tree.rhs.map(x => transformTerm(x, tree.tpt.tpe))
3333
ValDef.copy(tree)(tree.name, tree.tpt, rhs1)
34-
case IsDefDef(tree) =>
34+
case tree: DefDef =>
3535
implicit val ctx = localCtx(tree)
3636
DefDef.copy(tree)(tree.name, tree.typeParams, tree.paramss, tree.returnTpt, tree.rhs.map(x => transformTerm(x, tree.returnTpt.tpe)))
37-
case IsTypeDef(tree) =>
37+
case tree: TypeDef =>
3838
tree
39-
case IsClassDef(tree) =>
39+
case tree: ClassDef =>
4040
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
41-
case IsImport(tree) =>
42-
tree
4341
}
4442
}
4543

4644
def transformTermChildren(tree: Term, tpe: Type)(given ctx: Context): Term = tree match {
4745
case Ident(name) =>
4846
tree
4947
case Select(qualifier, name) =>
50-
val IsType(qualTpe) = tree.symbol.owner.typeRef
51-
Select.copy(tree)(transformTerm(qualifier, qualTpe), name)
48+
Select.copy(tree)(transformTerm(qualifier, qualifier.tpe.widen), name)
5249
case This(qual) =>
5350
tree
5451
case Super(qual, mix) =>
55-
// Super.copy(tree)(transformTerm(qual, ???), mix)
5652
tree
5753
case tree @ Apply(fun, args) =>
5854
val MethodType(_, tpes, _) = fun.tpe.widen
@@ -66,52 +62,63 @@ trait ExprMap {
6662
case New(tpt) =>
6763
New.copy(tree)(transformTypeTree(tpt))
6864
case Typed(expr, tpt) =>
69-
Typed.copy(tree)(transformTerm(expr, tpt.tpe), transformTypeTree(tpt))
70-
case IsNamedArg(tree) =>
71-
NamedArg.copy(tree)(tree.name, transformTerm(tree.value, ???))
65+
val tp = tpt.tpe match
66+
// TODO improve code
67+
case AppliedType(TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "<repeated>"), List(IsType(tp0))) =>
68+
type T
69+
val a = tp0.seal.asInstanceOf[quoted.Type[T]]
70+
'[Seq[$a]].unseal.tpe
71+
case tp => tp
72+
Typed.copy(tree)(transformTerm(expr, tp), transformTypeTree(tpt))
73+
case tree: NamedArg =>
74+
NamedArg.copy(tree)(tree.name, transformTerm(tree.value, tpe))
7275
case Assign(lhs, rhs) =>
73-
Assign.copy(tree)(transformTerm(lhs, ???), transformTerm(rhs, ???))
76+
Assign.copy(tree)(lhs, transformTerm(rhs, lhs.tpe.widen))
7477
case Block(stats, expr) =>
7578
Block.copy(tree)(transformStats(stats), transformTerm(expr, tpe))
7679
case If(cond, thenp, elsep) =>
77-
If.copy(tree)(transformTerm(cond, ???), transformTerm(thenp, ???), transformTerm(elsep, ???))
78-
case IsClosure(_) =>
80+
If.copy(tree)(
81+
transformTerm(cond, defn.BooleanType),
82+
transformTerm(thenp, tpe),
83+
transformTerm(elsep, tpe))
84+
case _: Closure =>
7985
tree
8086
case Match(selector, cases) =>
81-
Match.copy(tree)(transformTerm(selector, ???), transformCaseDefs(cases))
87+
Match.copy(tree)(transformTerm(selector, ???), transformCaseDefs(cases, tpe))
8288
case Return(expr) =>
8389
Return.copy(tree)(transformTerm(expr, ???))
8490
case While(cond, body) =>
85-
While.copy(tree)(transformTerm(cond, ???), transformTerm(body, ???))
91+
While.copy(tree)(transformTerm(cond, defn.BooleanType), transformTerm(body, defn.AnyType))
8692
case Try(block, cases, finalizer) =>
87-
Try.copy(tree)(transformTerm(block, tpe), transformCaseDefs(cases), finalizer.map(x => transformTerm(x, defn.AnyType)))
93+
Try.copy(tree)(transformTerm(block, tpe), transformCaseDefs(cases, defn.AnyType), finalizer.map(x => transformTerm(x, defn.AnyType)))
8894
case Repeated(elems, elemtpt) =>
8995
Repeated.copy(tree)(transformTerms(elems, elemtpt.tpe), elemtpt)
9096
case Inlined(call, bindings, expansion) =>
91-
Inlined.copy(tree)(call, transformSubTrees(bindings, defn.AnyType), transformTerm(expansion, tpe)/*()call.symbol.localContext)*/)
97+
Inlined.copy(tree)(call, transformDefinitions(bindings), transformTerm(expansion, tpe)/*()call.symbol.localContext)*/)
9298
}
9399

94-
def transformTerm(tree: Term, tpe: Type)(given ctx: Context): Term = tree match {
95-
case IsClosure(_) =>
100+
def transformTerm(tree: Term, tpe: Type)(given ctx: Context): Term =
101+
tree match {
102+
case _: Closure =>
96103
tree
97-
case IsInlined(_) | IsSelect(_) =>
104+
case _: Inlined | _: Select =>
98105
transformTermChildren(tree, tpe)
99106
case _ =>
100107
tree.tpe.widen match {
101-
case IsMethodType(_) | IsPolyType(_) =>
108+
case _: MethodType | _: PolyType =>
102109
transformTermChildren(tree, tpe)
103110
case _ =>
104-
tree.seal match {
105-
case '{ $x: $t } => map(x).unseal
106-
case _ => ???
107-
}
111+
type X
112+
val expr = tree.seal.asInstanceOf[Expr[X]]
113+
val t = tpe.seal.asInstanceOf[quoted.Type[X]]
114+
map(expr)(given qctx, t).unseal
108115
}
109116
}
110117

111118
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree
112119

113-
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef =
114-
CaseDef.copy(tree)(tree.pattern, tree.guard.map(x => transformTerm(x, ???)), transformTerm(tree.rhs, ???))
120+
def transformCaseDef(tree: CaseDef, tpe: Type)(given ctx: Context): CaseDef =
121+
CaseDef.copy(tree)(tree.pattern, tree.guard.map(x => transformTerm(x, defn.BooleanType)), transformTerm(tree.rhs, tpe))
115122

116123
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = {
117124
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
@@ -120,8 +127,8 @@ trait ExprMap {
120127
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
121128
trees mapConserve (transformStatement(_))
122129

123-
def transformTrees(trees: List[Tree], tpe: Type)(given ctx: Context): List[Tree] =
124-
trees mapConserve (x => transformTree(x, tpe))
130+
def transformDefinitions(trees: List[Definition])(given ctx: Context): List[Definition] =
131+
trees mapConserve (transformDefinition(_))
125132

126133
def transformTerms(trees: List[Term], tpes: List[Type])(given ctx: Context): List[Term] =
127134
val a = trees.zip(tpes).map {case (x, tpe) => transformTerm(x, tpe) } // TODO zipConserve
@@ -133,15 +140,12 @@ trait ExprMap {
133140
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
134141
trees mapConserve (transformTypeTree(_))
135142

136-
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
137-
trees mapConserve (transformCaseDef(_))
143+
def transformCaseDefs(trees: List[CaseDef], tpe: Type)(given ctx: Context): List[CaseDef] =
144+
trees mapConserve (x => transformCaseDef(x, tpe))
138145

139146
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
140147
trees mapConserve (transformTypeCaseDef(_))
141148

142-
def transformSubTrees[Tr <: Tree](trees: List[Tr], tpe: Type)(given ctx: Context): List[Tr] =
143-
transformTrees(trees, tpe).asInstanceOf[List[Tr]]
144-
145149
}
146150
new MapChildren().transformTermChildren(e.unseal, tpe.unseal.tpe).seal.cast[T] // Cast will only fail if this implementation has a bug
147151
}

tests/run-macros/expr-map-1.check

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
oof
2+
oofoof
3+
ylppa
4+
kcolb
5+
kcolb
6+
neht
7+
esle
8+
lav
9+
vals
10+
fed
11+
defs
12+
fed
13+
rab
14+
yrt
15+
yllanif
16+
hctac
17+
elihw
18+
wen
19+
depyt
20+
depyt
21+
grAdeman
22+
qual
23+
adbmal
24+
ravsgra
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def rewrite[T](x: => Any): Any = ${ stringRewriter('x) }
5+
6+
private def stringRewriter(e: Expr[Any])(given QuoteContext): Expr[Any] =
7+
StringRewriter.map(e)
8+
9+
private object StringRewriter extends util.ExprMap {
10+
11+
def map[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = e match
12+
case Const(s: String) =>
13+
Expr(s.reverse) match
14+
case '{ $x: T } => x
15+
case _ => e // e had a singlton String type
16+
case _ => mapChildren(e)
17+
18+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
println(rewrite("foo"))
5+
println(rewrite("foo" + "foo"))
6+
7+
rewrite {
8+
println("apply")
9+
}
10+
11+
rewrite {
12+
println("block")
13+
println("block")
14+
}
15+
16+
val b: Boolean = true
17+
rewrite {
18+
if b then println("then")
19+
else println("else")
20+
}
21+
22+
rewrite {
23+
if !b then println("then")
24+
else println("else")
25+
}
26+
27+
rewrite {
28+
val s: String = "val"
29+
println(s)
30+
}
31+
32+
rewrite {
33+
val s: "vals" = "vals"
34+
println(s) // prints "foo" not "oof"
35+
}
36+
37+
rewrite {
38+
def s: String = "def"
39+
println(s)
40+
}
41+
42+
rewrite {
43+
def s: "defs" = "defs"
44+
println(s) // prints "foo" not "oof"
45+
}
46+
47+
rewrite {
48+
def s(x: String): String = x
49+
println(s("def"))
50+
}
51+
52+
rewrite {
53+
var s: String = "var"
54+
s = "bar"
55+
println(s)
56+
}
57+
58+
rewrite {
59+
try println("try")
60+
finally println("finally")
61+
}
62+
63+
rewrite {
64+
try throw new Exception()
65+
catch case x: Exception => println("catch")
66+
}
67+
68+
rewrite {
69+
var x = true
70+
while (x) {
71+
println("while")
72+
x = false
73+
}
74+
}
75+
76+
rewrite {
77+
val t = new Tuple1("new")
78+
println(t._1)
79+
}
80+
81+
rewrite {
82+
println("typed": String)
83+
println("typed": Any)
84+
}
85+
86+
rewrite {
87+
val f = new Foo(foo = "namedArg")
88+
println(f.foo)
89+
}
90+
91+
rewrite {
92+
println("qual".reverse)
93+
}
94+
95+
rewrite {
96+
val f = () => "lambda"
97+
println(f())
98+
}
99+
100+
rewrite {
101+
def f(args: String*): String = args.mkString
102+
println(f("var", "args"))
103+
}
104+
105+
// FIXME
106+
// rewrite {
107+
// def s: String = return "def"
108+
// println(s)
109+
// }
110+
111+
// rewrite {
112+
// "match" match {
113+
// case "match" => println("match")
114+
// case x => println("x")
115+
// }
116+
// }
117+
}
118+
119+
}
120+
121+
class Foo(val foo: String)

0 commit comments

Comments
 (0)