Skip to content

Commit 6558b34

Browse files
committed
Add a unified prototype
1 parent 8a90e45 commit 6558b34

File tree

3 files changed

+169
-0
lines changed

3 files changed

+169
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Macro_1$package.plus(1, 4)
2+
5
3+
4+
Macro_1$package.plus(0, a)
5+
a
6+
7+
Macro_1$package.plus(a, b)
8+
a.+(b)
9+
10+
Macro_1$package.plus(Macro_1$package.plus(a, 0), Macro_1$package.plus(0, b))
11+
0.+(a).+(b)
12+
13+
Macro_1$package.power(4, 5)
14+
1024
15+
16+
Macro_1$package.power(a, 5)
17+
a.*(a.*(a.*(a.*(1.*(a)))))
18+
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def rewrite[T](x: => T): T = ${ rewriteMacro('x) }
5+
6+
def plus(x: Int, y: Int): Int = x + y
7+
def times(x: Int, y: Int): Int = x * y
8+
def power(x: Int, y: Int): Int = if y == 0 then 1 else times(x, power(x, y - 1))
9+
10+
private def rewriteMacro[T: Type](x: Expr[T])(given QuoteContext): Expr[T] = {
11+
val rewriter = Rewriter().withFixPoint.withPost(
12+
Transformation.safe[Int] {
13+
case '{ plus($x, $y) } =>
14+
(x, y) match {
15+
case (Const(0), _) => y
16+
case (Const(a), Const(b)) => Expr(a + b)
17+
case (_, Const(_)) => '{ $y + $x }
18+
case _ => '{ $x + $y }
19+
}
20+
case '{ times($x, $y) } =>
21+
(x, y) match {
22+
case (Const(0), _) => '{0}
23+
case (Const(1), _) => y
24+
case (Const(a), Const(b)) => Expr(a * b)
25+
case (_, Const(_)) => '{ $y * $x }
26+
case _ => '{ $x * $y }
27+
}
28+
case '{ power(${Const(x)}, ${Const(y)}) } =>
29+
Expr(power(x, y))
30+
case '{ power($x, ${Const(y)}) } =>
31+
if y == 0 then '{1}
32+
else '{ times($x, power($x, ${Expr(y-1)})) }
33+
}
34+
)
35+
36+
val x2 = rewriter.rewrite(x)
37+
38+
'{
39+
println(${Expr(x.show)})
40+
println(${Expr(x2.show)})
41+
println()
42+
$x2
43+
}
44+
}
45+
46+
object Transformation {
47+
/** A restrictive transformer that is guaranteed to generate type correct code */
48+
def safe[T: Type](transform: PartialFunction[Expr[T], Expr[T]]): Transformation =
49+
new SafeTransformation(transform)
50+
51+
/** A general purpose transformer that may fail while transforming.
52+
* It will check the type of the returned Expr and will throw if the type does not conform to the expected type.
53+
*/
54+
def checked(transform: PartialFunction[Expr[Any], Expr[Any]]): Transformation =
55+
new CheckedTransformation(transform)
56+
}
57+
58+
class CheckedTransformation(transform: PartialFunction[Expr[Any], Expr[Any]]) extends Transformation {
59+
def apply[T: Type](e: Expr[T])(given QuoteContext): Expr[T] = {
60+
transform.applyOrElse(e, identity) match {
61+
case '{ $e2: T } => e2
62+
case '{ $e2: $t } =>
63+
throw new Exception(
64+
s"""Transformed
65+
|${e.show}
66+
|into
67+
|${e2.show}
68+
|
69+
|Expected type to be
70+
|${summon[Type[T]].show}
71+
|but was
72+
|${t.show}
73+
""".stripMargin)
74+
}
75+
}
76+
}
77+
78+
class SafeTransformation[U: Type](transform: PartialFunction[Expr[U], Expr[U]]) extends Transformation {
79+
def apply[T: Type](e: Expr[T])(given QuoteContext): Expr[T] = {
80+
e match {
81+
case '{ $e: U } => transform.applyOrElse(e, identity) match { case '{ $e2: T } => e2 }
82+
case e => e
83+
}
84+
}
85+
}
86+
87+
abstract class Transformation {
88+
def apply[T: Type](e: Expr[T])(given QuoteContext): Expr[T]
89+
}
90+
91+
private object Rewriter {
92+
def apply(): Rewriter = new Rewriter(Nil, Nil, false)
93+
}
94+
95+
private class Rewriter private (preTransform: List[Transformation] = Nil, postTransform: List[Transformation] = Nil, fixPoint: Boolean) {
96+
97+
def withFixPoint: Rewriter =
98+
new Rewriter(preTransform, postTransform, fixPoint = true)
99+
def withPre(transform: Transformation): Rewriter =
100+
new Rewriter(transform :: preTransform, postTransform, fixPoint)
101+
def withPost(transform: Transformation): Rewriter =
102+
new Rewriter(preTransform, transform :: postTransform, fixPoint)
103+
104+
def rewrite[T](e: Expr[T])(given QuoteContext, Type[T]): Expr[T] = {
105+
val e2 = preTransform.foldLeft(e)((ei, transform) => transform(ei))
106+
val e3 = rewriteChildren(e2)
107+
val e4 = postTransform.foldLeft(e3)((ei, transform) => transform(ei))
108+
if fixPoint && e4 != e then rewrite(e4) else e4
109+
}
110+
111+
def rewriteChildren[T: Type](e: Expr[T])(given qctx: QuoteContext): Expr[T] = {
112+
import qctx.tasty.{_, given}
113+
class MapChildren extends TreeMap {
114+
override def transformTerm(tree: Term)(given ctx: Context): Term = tree match {
115+
case IsClosure(_) =>
116+
tree
117+
case IsInlined(_) | IsSelect(_) =>
118+
transformChildrenTerm(tree)
119+
case _ =>
120+
tree.tpe.widen match {
121+
case IsMethodType(_) | IsPolyType(_) =>
122+
transformChildrenTerm(tree)
123+
case _ =>
124+
tree.seal match {
125+
case '{ $x: $t } => rewrite(x).unseal
126+
}
127+
}
128+
}
129+
def transformChildrenTerm(tree: Term)(given ctx: Context): Term =
130+
super.transformTerm(tree)
131+
}
132+
(new MapChildren).transformChildrenTerm(e.unseal).seal.cast[T] // Cast will only fail if this implementation has a bug
133+
}
134+
135+
}
136+
137+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
3+
def main(args: Array[String]): Unit = {
4+
val a: Int = 5
5+
val b: Int = 6
6+
rewrite(plus(1, 4))
7+
rewrite(plus(0, a))
8+
rewrite(plus(a, b))
9+
rewrite(plus(plus(a, 0), plus(0, b)))
10+
rewrite(power(4, 5))
11+
rewrite(power(a, 5))
12+
}
13+
14+
}

0 commit comments

Comments
 (0)