Skip to content

Commit 4237152

Browse files
authored
Merge pull request #7619 from dotty-staging/add-patmat-power-example
Add simple power test for patmat
2 parents a39f1c5 + 37540c7 commit 4237152

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
32.0
2+
512.0
3+
2.0
4+
1.0
5+
64.0
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
object Macros {
5+
6+
def power_s(x: Expr[Double], n: Int)(given QuoteContext): Expr[Double] =
7+
if (n == 0) '{1.0}
8+
else if (n % 2 == 1) '{ $x * ${power_s(x, n - 1)} }
9+
else '{ val y = $x * $x; ${power_s('y, n / 2)} }
10+
11+
inline def power(x: Double, inline n: Int): Double =
12+
${power_s('x, n)}
13+
14+
def power2(x: Double, y: Double): Double = if y == 0.0 then 1.0 else x * power2(x, y - 1.0)
15+
16+
inline def rewrite(expr: => Double): Double = ${rewrite('expr)}
17+
18+
// simple, 1-level, non-recursive rewriter for exponents
19+
def rewrite(expr: Expr[Double])(given QuoteContext): Expr[Double] = {
20+
val res = expr match {
21+
// product rule
22+
case '{ power2($a, $x) * power2($b, $y)} if a.matches(b) => '{ power2($a, $x + $y) }
23+
// rules of 1
24+
case '{ power2($a, 1)} => a
25+
case '{ power2(1, $a)} => '{ 1.0 }
26+
// rule of 0
27+
case '{ power2($a, 0)} => '{ 1.0 }
28+
// power rule
29+
case '{ power2(power2($a, $x), $y)} => '{ power2($a, $x * $y ) }
30+
case _ => expr
31+
}
32+
println(res.show)
33+
res
34+
}
35+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import Macros._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
val x = 2
7+
println(power(x, 5))
8+
9+
println(rewrite{ power2(2.0, 5.0) * power2(2.0, 4.0) })
10+
11+
println(rewrite{ power2(2.0, 1.0) })
12+
13+
println(rewrite{ power2(1.0, 1000) })
14+
15+
println(rewrite{ power2(power2(2.0, 2.0), 3.0) })
16+
}
17+
}

0 commit comments

Comments
 (0)