@@ -27,6 +27,20 @@ object Test {
27
27
28
28
val code4 = ' { (arr : Array [Int ], f : Int => Unit ) => ~ foreach4('(arr), ' (f), 4 ) }
29
29
println(code4.show)
30
+ println()
31
+
32
+ val liftedArray = Array (1 , 2 , 3 , 4 ).toExpr
33
+ println(liftedArray.show)
34
+ println()
35
+
36
+
37
+ def printAll (arr : Array [Int ]) = ' {
38
+ val arr1 = ~ arr.toExpr
39
+ ~ foreach1('(arr1), ' (x => println(x)))
40
+ }
41
+
42
+ println(printAll(Array (1 , 3 , 4 , 5 )).show)
43
+
30
44
}
31
45
32
46
def foreach1 (arrRef : Expr [Array [Int ]], f : Expr [Int => Unit ]): Expr [Unit ] = ' {
@@ -81,19 +95,40 @@ object Test {
81
95
}
82
96
}
83
97
98
+ def foreach3_2 (arrRef : Expr [Array [Int ]], f : Expr [Int => Unit ]): Expr [Unit ] = ' {
99
+ val size = (~ arrRef).length
100
+ var i = 0
101
+ if (size % 3 != 0 ) throw new Exception (" ..." )// for simplicity of the implementation
102
+ while (i < size) {
103
+ (~ f)((~ arrRef)(i))
104
+ (~ f)((~ arrRef)(i + 1 ))
105
+ (~ f)((~ arrRef)(i + 2 ))
106
+ i += 3
107
+ }
108
+ }
109
+
84
110
def foreach4 (arrRef : Expr [Array [Int ]], f : Expr [Int => Unit ], unrollSize : Int ): Expr [Unit ] = ' {
85
111
val size = (~ arrRef).length
86
112
var i = 0
87
113
if (size % ~ unrollSize.toExpr != 0 ) throw new Exception (" ..." ) // for simplicity of the implementation
88
114
while (i < size) {
89
- ~ {
90
- @ tailrec def loop (j : Int , acc : Expr [Unit ]): Expr [Unit ] =
91
- if (j >= 0 ) loop(j - 1 , ' { (~ f)((~ arrRef)(i + ~ j.toExpr)); ~ acc })
92
- else acc
93
- loop(unrollSize - 1 , '())
94
- }
115
+ ~ foreachInRange(0 , unrollSize)(j => ' { (~ f)((~ arrRef)(i + ~ j.toExpr)) })
95
116
i += ~ unrollSize.toExpr
96
117
}
97
118
}
98
119
120
+ implicit object ArrayIntIsLiftable extends Liftable [Array [Int ]] {
121
+ override def toExpr (x : Array [Int ]): Expr [Array [Int ]] = ' {
122
+ val array = new Array [Int ](~ x.length.toExpr)
123
+ ~ foreachInRange(0 , x.length)(i => ' { array(~ i.toExpr) = ~ x(i).toExpr})
124
+ array
125
+ }
126
+ }
127
+
128
+ def foreachInRange (start : Int , end : Int )(f : Int => Expr [Unit ]): Expr [Unit ] = {
129
+ @ tailrec def unroll (i : Int , acc : Expr [Unit ]): Expr [Unit ] =
130
+ if (i < end) unroll(i + 1 , ' { ~ acc; ~ f(i) }) else acc
131
+ if (start < end) unroll(start + 1 , f(start)) else '()
132
+ }
133
+
99
134
}
0 commit comments