@@ -88,7 +88,6 @@ object Test {
88
88
val prod = new Producer [B ] {
89
89
90
90
type St = producer.St
91
-
92
91
val card = producer.card
93
92
94
93
def init (k : St => Expr [Unit ]): Expr [Unit ] = {
@@ -143,78 +142,84 @@ object Test {
143
142
Stream (flatMapRaw[Expr [A ], Expr [A ]]((a => { Linear (filterStream(a)) }), stream))
144
143
}
145
144
146
- // def moreTermination[A](f: Rep[Boolean] => Rep[Boolean], stream: StagedStream[A]): StagedStream[A] = {
147
- // def addToProducer[A](f: Rep[Boolean] => Rep[Boolean], producer: Producer[A]): Producer[A] = {
148
- // producer.card match {
149
- // case Many =>
150
- // new Producer[A] {
151
- // type St = producer.St
152
-
153
- // val card = producer.card
154
- // def init(k: St => Rep[Unit]): Rep[Unit] =
155
- // producer.init(k)
156
- // def step(st: St, k: (A => Rep[Unit])): Rep[Unit] =
157
- // producer.step(st, el => k(el))
158
- // def hasNext(st: St): Rep[Boolean] =
159
- // f(producer.hasNext(st))
160
- // }
161
- // case AtMost1 => producer
162
- // }
163
- // }
164
- // stream match {
165
- // case Linear(producer) => Linear(addToProducer(f, producer))
166
- // case Nested(producer, nestedf) =>
167
- // Nested(addToProducer(f, producer), (a: Id[_]) => moreTermination(f, nestedf(a)))
168
- // }
169
- // }
170
-
171
- // private def addCounter[A](n: Expr[Int], producer: Producer[A]): Producer[(Expr[Int], A)] =
172
- // new Producer[(Var[Int], A)] {
173
- // type St = (Var[Int], producer.St)
174
-
175
- // val card = producer.card
176
- // def init(k: St => Rep[Unit]): Rep[Unit] = {
177
- // producer.init(st => {
178
- // var counter: Var[Int] = n
179
- // k(counter, st)
180
- // })
181
- // }
182
- // def step(st: St, k: (((Var[Int], A)) => Rep[Unit])): Rep[Unit] = {
183
- // val (counter, nst) = st
184
- // producer.step(nst, el => {
185
- // k((counter, el))
186
- // })
187
- // }
188
- // def hasNext(st: St): Rep[Boolean] = {
189
- // val (counter, nst) = st
190
- // producer.card match {
191
- // case Many => counter > 0 && producer.hasNext(nst)
192
- // case AtMost1 => producer.hasNext(nst)
193
- // }
194
- // }
195
- // }
196
-
197
- // def takeRaw[A](n: Rep[Int], stream: StagedStream[A]): StagedStream[A] = {
198
- // stream match {
199
- // case Linear(producer) => {
200
- // mapRaw[(Var[Int], A), A]((t => k => {
201
- // t._1 = t._1 - 1
202
- // k(t._2)
203
- // }), Linear(addCounter(n, producer)))
204
- // }
205
- // case Nested(producer, nestedf) => {
206
- // Nested(addCounter(n, producer), (t: (Var[Int], Id[_])) => {
207
- // mapRaw[A, A]((el => k => {
208
- // t._1 = t._1 - 1
209
- // k(el)
210
- // }), moreTermination(b => t._1 > 0 && b, nestedf(t._2)))
211
- // })
212
- // }
213
- // }
214
- // }
215
-
216
- // def take(n: Rep[Int]): Stream[A] = Stream(takeRaw(n, stream))
145
+ private def moreTermination [A ](f : Expr [Boolean ] => Expr [Boolean ], stream : StagedStream [A ]): StagedStream [A ] = {
146
+ def addToProducer [A ](f : Expr [Boolean ] => Expr [Boolean ], producer : Producer [A ]): Producer [A ] = {
147
+ producer.card match {
148
+ case Many =>
149
+ new Producer [A ] {
150
+ type St = producer.St
151
+ val card = producer.card
152
+
153
+ def init (k : St => Expr [Unit ]): Expr [Unit ] =
154
+ producer.init(k)
155
+
156
+ def step (st : St , k : (A => Expr [Unit ])): Expr [Unit ] =
157
+ producer.step(st, el => k(el))
158
+
159
+ def hasNext (st : St ): Expr [Boolean ] =
160
+ f(producer.hasNext(st))
161
+ }
162
+ case AtMost1 => producer
163
+ }
164
+ }
165
+
166
+ stream match {
167
+ case Linear (producer) => Linear (addToProducer(f, producer))
168
+ case nested : Nested [a, bt] =>
169
+ Nested (addToProducer(f, nested.producer), (a : bt) => moreTermination(f, nested.nestedf(a)))
170
+ }
171
+ }
172
+
173
+ private def addCounter [A ](n : Expr [Int ], producer : Producer [A ]): Producer [(Var [Int ], A )] = {
174
+ new Producer [(Var [Int ], A )] {
175
+ type St = (Var [Int ], producer.St )
176
+ val card = producer.card
177
+
178
+ def init (k : St => Expr [Unit ]): Expr [Unit ] = {
179
+ producer.init(st => {
180
+ Var (n) { counter =>
181
+ k(counter, st)
182
+ }
183
+ })
184
+ }
185
+
186
+ def step (st : St , k : (((Var [Int ], A )) => Expr [Unit ])): Expr [Unit ] = {
187
+ val (counter, nst) = st
188
+ producer.step(nst, el => ' {
189
+ ~ k((counter, el))
190
+ })
191
+ }
192
+
193
+ def hasNext (st : St ): Expr [Boolean ] = {
194
+ val (counter, nst) = st
195
+ producer.card match {
196
+ case Many => ' { ~ counter.get > 0 && ~ producer.hasNext(nst) }
197
+ case AtMost1 => ' { ~ producer.hasNext(nst) }
198
+ }
199
+ }
200
+ }
201
+ }
202
+
203
+ def takeRaw [A ](n : Expr [Int ], stream : StagedStream [A ]): StagedStream [A ] = {
204
+ stream match {
205
+ case Linear (producer) => {
206
+ mapRaw[(Var [Int ], A ), A ]((t : (Var [Int ], A )) => k => ' {
207
+ ~ t._1.update(' {~ t._1.get - 1 })
208
+ ~ k(t._2)
209
+ }, Linear (addCounter(n, producer)))
210
+ }
211
+ case nested : Nested [a, bt] => {
212
+ Nested (addCounter(n, nested.producer), (t : (Var [Int ], bt)) => {
213
+ mapRaw[A , A ]((el => k => ' {
214
+ ~ t._1.update(' {~ t._1.get - 1 })
215
+ ~ k(el)
216
+ }), moreTermination(b => ' { ~ t._1.get > 0 && ~ b}, nested.nestedf(t._2)))
217
+ })
218
+ }
219
+ }
220
+ }
217
221
222
+ def take (n : Expr [Int ]): Stream [A ] = Stream (takeRaw[Expr [A ]](n, stream))
218
223
}
219
224
220
225
object Stream {
@@ -272,6 +277,17 @@ object Test {
272
277
.filter((d : Expr [Int ]) => ' { ~ d % 2 == 0 })
273
278
.fold(' {0 }, ((a : Expr [Int ], b : Expr [Int ]) => ' { ~ a + ~ b }))
274
279
280
+ def test5 () = Stream
281
+ .of(' {Array (1 , 2 , 3 )})
282
+ .take(' {2 })
283
+ .fold(' {0 }, ((a : Expr [Int ], b : Expr [Int ]) => ' { ~ a + ~ b }))
284
+
285
+ def test6 () = Stream
286
+ .of(' {Array (1 , 1 , 1 )})
287
+ .flatMap((d : Expr [Int ]) => Stream .of(' {Array (1 , 2 , 3 )}).take(' {2 }))
288
+ .take(' {5 })
289
+ .fold(' {0 }, ((a : Expr [Int ], b : Expr [Int ]) => ' { ~ a + ~ b }))
290
+
275
291
def main (args : Array [String ]): Unit = {
276
292
println(test1().run)
277
293
println
@@ -280,6 +296,10 @@ object Test {
280
296
println(test3().run)
281
297
println
282
298
println(test4().run)
299
+ println
300
+ println(test5().run)
301
+ println
302
+ println(test6().run)
283
303
}
284
304
}
285
305
0 commit comments