Skip to content

Commit aa5a1f6

Browse files
committed
Add rewrite prototype and couple of fixes
1 parent 22e64e2 commit aa5a1f6

File tree

9 files changed

+115
-3
lines changed

9 files changed

+115
-3
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
204204
def DefDef_apply(symbol: Symbol, rhsFn: List[Type] => List[List[Term]] => Option[Term])(given Context): DefDef =
205205
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => rhsFn(tparams)(vparamss).getOrElse(tpd.EmptyTree)))
206206

207-
def DefDef_copy(original: DefDef)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given Context): DefDef =
207+
def DefDef_copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given Context): DefDef =
208208
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, rhs.getOrElse(tpd.EmptyTree))
209209

210210
type ValDef = tpd.ValDef

library/src/scala/quoted/Expr.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ package internal {
190190
* May contain references to code defined outside this TastyTreeExpr instance.
191191
*/
192192
final class TastyTreeExpr[Tree](val tree: Tree, val scopeId: Int) extends Expr[Any] {
193+
override def equals(that: Any): Boolean = that match {
194+
case that: TastyTreeExpr[_] => tree == that.tree && scopeId == that.scopeId
195+
case _ => false
196+
}
197+
override def hashCode: Int = tree.hashCode
193198
override def toString: String = s"Expr(<tasty tree>)"
194199
}
195200

library/src/scala/quoted/Type.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ package internal {
7171

7272
/** An Type backed by a tree */
7373
final class TreeType[Tree](val typeTree: Tree, val scopeId: Int) extends scala.quoted.Type[Any] {
74+
override def equals(that: Any): Boolean = that match {
75+
case that: TreeType[_] => typeTree == that.typeTree && scopeId == that.scopeId
76+
case _ => false
77+
}
78+
override def hashCode: Int = typeTree.hashCode
7479
override def toString: String = s"Type(<tasty tree>)"
7580
}
7681

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ trait CompilerInterface {
262262
def DefDef_rhs(self: DefDef)(given ctx: Context): Option[Term]
263263

264264
def DefDef_apply(symbol: Symbol, rhsFn: List[Type] => List[List[Term]] => Option[Term])(given ctx: Context): DefDef
265-
def DefDef_copy(original: DefDef)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given ctx: Context): DefDef
265+
def DefDef_copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given ctx: Context): DefDef
266266

267267
/** Tree representing a value definition in the source code This inclues `val`, `lazy val`, `var`, `object` and parameter definitions. */
268268
type ValDef <: Definition

library/src/scala/tasty/reflect/SourceCodePrinter.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,9 @@ class SourceCodePrinter[R <: Reflection & Singleton](val tasty: R)(syntaxHighlig
514514
case IsTypeTree(tpt) =>
515515
printTypeTree(tpt)
516516

517+
case Closure(meth, _) =>
518+
printTree(meth)
519+
517520
case _ =>
518521
throw new MatchError(tree.showExtractors)
519522

library/src/scala/tasty/reflect/TreeOps.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ trait TreeOps extends Core {
9797
object DefDef {
9898
def apply(symbol: Symbol, rhsFn: List[Type] => List[List[Term]] => Option[Term])(given ctx: Context): DefDef =
9999
internal.DefDef_apply(symbol, rhsFn)
100-
def copy(original: DefDef)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given ctx: Context): DefDef =
100+
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term])(given ctx: Context): DefDef =
101101
internal.DefDef_copy(original)(name, typeParams, paramss, tpt, rhs)
102102
def unapply(tree: Tree)(given ctx: Context): Option[(String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term])] =
103103
internal.matchDefDef(tree).map(x => (x.name, x.typeParams, x.paramss, x.returnTpt, x.rhs))

tests/run-macros/flops-rewrite.check

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x))
2+
scala.Nil
3+
4+
scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)).++[scala.Nothing](scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)))
5+
scala.Nil
6+
7+
scala.Nil.map[scala.Nothing](((x: scala.Nothing) => x)).++[scala.Int](scala.List.apply[scala.Int]((3: scala.<repeated>[scala.Int]))).++[scala.Int](scala.Nil)
8+
scala.List.apply[scala.Int]((3: scala.<repeated>[scala.Int]))
9+
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import scala.quoted._
2+
3+
inline def rewrite[T](x: => T): T = ${ rewriteMacro('x) }
4+
5+
private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = {
6+
val x2 = Rewriter(
7+
postTransform = {
8+
case '{ Nil.map[$t]($f) } => '{ Nil }
9+
case '{ Nil.filter($f) } => '{ Nil }
10+
case '{ Nil.++[$t]($xs) } => xs
11+
case '{ ($xs: List[$t]).++(Nil) } => xs
12+
case x => x
13+
}
14+
).rewrite(x)
15+
16+
'{
17+
println(${Expr(x.show)})
18+
println(${Expr(x2.show)})
19+
println()
20+
$x2
21+
}
22+
}
23+
24+
private object Rewriter {
25+
def apply(preTransform: Expr[Any] => Expr[Any] = identity, postTransform: Expr[Any] => Expr[Any] = identity, fixPoint: Boolean = false): Rewriter =
26+
new Rewriter(preTransform, postTransform, fixPoint)
27+
}
28+
29+
private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr[Any] => Expr[Any], fixPoint: Boolean) {
30+
def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
31+
val e2 = checkedTransform(e, preTransform)
32+
val e3 = rewriteChildren(e2)
33+
val e4 = checkedTransform(e3, postTransform)
34+
if fixPoint && e4 != e then rewrite(e4)
35+
else e4
36+
}
37+
38+
private def checkedTransform[T: Type](e: Expr[T], transform: Expr[T] => Expr[Any])(given QuoteContext): Expr[T] = {
39+
transform(e) match {
40+
case '{ $x: T } => x
41+
case '{ $x: $t } => throw new Exception(
42+
s"""Transformed
43+
|${e.show}
44+
|into
45+
|${x.show}
46+
|
47+
|Expected type to be
48+
|${summon[Type[T]].show}
49+
|but was
50+
|${t.show}
51+
""".stripMargin)
52+
}
53+
}
54+
55+
def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = {
56+
import qctx.tasty.{_, given}
57+
class MapChildren extends TreeMap {
58+
override def transformTerm(tree: Term)(given ctx: Context): Term = tree match {
59+
case IsClosure(_) =>
60+
tree
61+
case IsInlined(_) | IsSelect(_) =>
62+
transformChildrenTerm(tree)
63+
case _ =>
64+
tree.tpe match {
65+
case IsMethodType(_) | IsPolyType(_) =>
66+
transformChildrenTerm(tree)
67+
case _ =>
68+
tree.seal match {
69+
case '{ $x: $t } => rewrite(x).unseal
70+
}
71+
}
72+
}
73+
def transformChildrenTerm(tree: Term)(given ctx: Context): Term =
74+
super.transformTerm(tree)
75+
}
76+
(new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug
77+
}
78+
79+
}
80+
81+
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
rewrite(Nil.map(x => x))
5+
rewrite(Nil.map(x => x) ++ Nil.map(x => x))
6+
rewrite(Nil.map(x => x) ++ List(3) ++ Nil)
7+
}
8+
9+
}

0 commit comments

Comments
 (0)