Skip to content

Commit 8a90e45

Browse files
committed
Add type safe rewrite prototype
1 parent fe80b80 commit 8a90e45

File tree

4 files changed

+134
-1
lines changed

4 files changed

+134
-1
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Macro_1$package.plus(1, 4)
2+
5
3+
4+
Macro_1$package.plus(0, a)
5+
a
6+
7+
Macro_1$package.plus(a, b)
8+
a.+(b)
9+
10+
Macro_1$package.plus(Macro_1$package.plus(a, 0), Macro_1$package.plus(0, b))
11+
0.+(a).+(b)
12+
13+
Macro_1$package.power(4, 5)
14+
1024
15+
16+
Macro_1$package.power(a, 5)
17+
a.*(a.*(a.*(a.*(1.*(a)))))
18+
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def rewrite[T](x: => T): T = ${ rewriteMacro('x) }
5+
6+
def plus(x: Int, y: Int): Int = x + y
7+
def times(x: Int, y: Int): Int = x * y
8+
def power(x: Int, y: Int): Int = if y == 0 then 1 else times(x, power(x, y - 1))
9+
10+
private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = {
11+
val rewriter = Rewriter(
12+
postTransform = List(
13+
Transformation[Int] {
14+
case '{ plus($x, $y) } =>
15+
(x, y) match {
16+
case (Const(0), _) => y
17+
case (Const(a), Const(b)) => Expr(a + b)
18+
case (_, Const(_)) => '{ $y + $x }
19+
case _ => '{ $x + $y }
20+
}
21+
case '{ times($x, $y) } =>
22+
(x, y) match {
23+
case (Const(0), _) => '{0}
24+
case (Const(1), _) => y
25+
case (Const(a), Const(b)) => Expr(a * b)
26+
case (_, Const(_)) => '{ $y * $x }
27+
case _ => '{ $x * $y }
28+
}
29+
case '{ power(${Const(x)}, ${Const(y)}) } =>
30+
Expr(power(x, y))
31+
case '{ power($x, ${Const(y)}) } =>
32+
if y == 0 then '{1}
33+
else '{ times($x, power($x, ${Expr(y-1)})) }
34+
}),
35+
fixPoint = true
36+
)
37+
38+
val x2 = rewriter.rewrite(x)
39+
40+
'{
41+
println(${Expr(x.show)})
42+
println(${Expr(x2.show)})
43+
println()
44+
$x2
45+
}
46+
}
47+
48+
object Transformation {
49+
def apply[T: Type](transform: PartialFunction[Expr[T], Expr[T]]) =
50+
new Transformation(transform)
51+
}
52+
class Transformation[T: Type](transform: PartialFunction[Expr[T], Expr[T]]) {
53+
def apply[U: Type](e: Expr[U])(given QuoteContext): Expr[U] = {
54+
e match {
55+
case '{ $e: T } => transform.applyOrElse(e, identity) match { case '{ $e2: U } => e2 }
56+
case e => e
57+
}
58+
}
59+
}
60+
61+
private object Rewriter {
62+
def apply(preTransform: List[Transformation[_]] = Nil, postTransform: List[Transformation[_]] = Nil, fixPoint: Boolean = false): Rewriter =
63+
new Rewriter(preTransform, postTransform, fixPoint)
64+
}
65+
66+
private class Rewriter(preTransform: List[Transformation[_]] = Nil, postTransform: List[Transformation[_]] = Nil, fixPoint: Boolean) {
67+
def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
68+
val e2 = preTransform.foldLeft(e)((ei, transform) => transform(ei))
69+
val e3 = rewriteChildren(e2)
70+
val e4 = postTransform.foldLeft(e3)((ei, transform) => transform(ei))
71+
if fixPoint && e4 != e then rewrite(e4)
72+
else e4
73+
}
74+
75+
def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = {
76+
import qctx.tasty.{_, given}
77+
class MapChildren extends TreeMap {
78+
override def transformTerm(tree: Term)(given ctx: Context): Term = tree match {
79+
case IsClosure(_) =>
80+
tree
81+
case IsInlined(_) | IsSelect(_) =>
82+
transformChildrenTerm(tree)
83+
case _ =>
84+
tree.tpe.widen match {
85+
case IsMethodType(_) | IsPolyType(_) =>
86+
transformChildrenTerm(tree)
87+
case _ =>
88+
tree.seal match {
89+
case '{ $x: $t } => rewrite(x).unseal
90+
}
91+
}
92+
}
93+
def transformChildrenTerm(tree: Term)(given ctx: Context): Term =
94+
super.transformTerm(tree)
95+
}
96+
(new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug
97+
}
98+
99+
}
100+
101+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
val a: Int = 5
5+
val b: Int = 6
6+
rewrite(plus(1, 4))
7+
rewrite(plus(0, a))
8+
rewrite(plus(a, b))
9+
rewrite(plus(plus(a, 0), plus(0, b)))
10+
rewrite(power(4, 5))
11+
rewrite(power(a, 5))
12+
}
13+
14+
}

tests/run-macros/flops-rewrite/Macro_1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private class Rewriter(preTransform: Expr[Any] => Expr[Any], postTransform: Expr
6363
case IsInlined(_) | IsSelect(_) =>
6464
transformChildrenTerm(tree)
6565
case _ =>
66-
tree.tpe match {
66+
tree.tpe.widen match {
6767
case IsMethodType(_) | IsPolyType(_) =>
6868
transformChildrenTerm(tree)
6969
case _ =>

0 commit comments

Comments
 (0)