Skip to content

Commit 8d209a2

Browse files
committed
Add rangeunrolling and array lifter
1 parent baf8605 commit 8d209a2

File tree

2 files changed

+69
-7
lines changed

2 files changed

+69
-7
lines changed

tests/run-with-compiler/quote-unrolled-foreach.check

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,34 @@
6060
f.apply(arr.apply(i.+(1)))
6161
f.apply(arr.apply(i.+(2)))
6262
f.apply(arr.apply(i.+(3)))
63-
()
6463
i = i.+(4)
6564
}
6665
})
66+
67+
{
68+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](4)
69+
array.update(0, 1)
70+
array.update(1, 2)
71+
array.update(2, 3)
72+
array.update(3, 4)
73+
(array: scala.Array[scala.Int])
74+
}
75+
76+
{
77+
val arr1: scala.Array[scala.Int] = {
78+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](4)
79+
array.update(0, 1)
80+
array.update(1, 3)
81+
array.update(2, 4)
82+
array.update(3, 5)
83+
(array: scala.Array[scala.Int])
84+
}
85+
val size: scala.Int = arr1.length
86+
var i: scala.Int = 0
87+
while (i.<(size)) {
88+
val element: scala.Int = arr1.apply(i)
89+
90+
((x: scala.Int) => scala.Predef.println(x)).apply(element)
91+
i = i.+(1)
92+
}
93+
}

tests/run-with-compiler/quote-unrolled-foreach.scala

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ object Test {
2727

2828
val code4 = '{ (arr: Array[Int], f: Int => Unit) => ~foreach4('(arr), '(f), 4) }
2929
println(code4.show)
30+
println()
31+
32+
val liftedArray = Array(1, 2, 3, 4).toExpr
33+
println(liftedArray.show)
34+
println()
35+
36+
37+
def printAll(arr: Array[Int]) = '{
38+
val arr1 = ~arr.toExpr
39+
~foreach1('(arr1), '(x => println(x)))
40+
}
41+
42+
println(printAll(Array(1, 3, 4, 5)).show)
43+
3044
}
3145

3246
def foreach1(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
@@ -81,19 +95,40 @@ object Test {
8195
}
8296
}
8397

98+
def foreach3_2(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
99+
val size = (~arrRef).length
100+
var i = 0
101+
if (size % 3 != 0) throw new Exception("...")// for simplicity of the implementation
102+
while (i < size) {
103+
(~f)((~arrRef)(i))
104+
(~f)((~arrRef)(i + 1))
105+
(~f)((~arrRef)(i + 2))
106+
i += 3
107+
}
108+
}
109+
84110
def foreach4(arrRef: Expr[Array[Int]], f: Expr[Int => Unit], unrollSize: Int): Expr[Unit] = '{
85111
val size = (~arrRef).length
86112
var i = 0
87113
if (size % ~unrollSize.toExpr != 0) throw new Exception("...") // for simplicity of the implementation
88114
while (i < size) {
89-
~{
90-
@tailrec def loop(j: Int, acc: Expr[Unit]): Expr[Unit] =
91-
if (j >= 0) loop(j - 1, '{ (~f)((~arrRef)(i + ~j.toExpr)); ~acc })
92-
else acc
93-
loop(unrollSize - 1, '())
94-
}
115+
~foreachInRange(0, unrollSize)(j => '{ (~f)((~arrRef)(i + ~j.toExpr)) })
95116
i += ~unrollSize.toExpr
96117
}
97118
}
98119

120+
implicit object ArrayIntIsLiftable extends Liftable[Array[Int]] {
121+
override def toExpr(x: Array[Int]): Expr[Array[Int]] = '{
122+
val array = new Array[Int](~x.length.toExpr)
123+
~foreachInRange(0, x.length)(i => '{ array(~i.toExpr) = ~x(i).toExpr})
124+
array
125+
}
126+
}
127+
128+
def foreachInRange(start: Int, end: Int)(f: Int => Expr[Unit]): Expr[Unit] = {
129+
@tailrec def unroll(i: Int, acc: Expr[Unit]): Expr[Unit] =
130+
if (i < end) unroll(i + 1, '{ ~acc; ~f(i) }) else acc
131+
if (start < end) unroll(start + 1, f(start)) else '()
132+
}
133+
99134
}

0 commit comments

Comments
 (0)