@@ -82,8 +82,8 @@ object Test {
82
82
Stream (mapRaw[Expr [A ], Expr [B ]](a => k => ' { ~ k(f(a)) }, stream))
83
83
}
84
84
85
- private def mapRaw [A , B ](f : (A => (B => Expr [Unit ]) => Expr [Unit ]), s : StagedStream [A ]): StagedStream [B ] = {
86
- s match {
85
+ private def mapRaw [A , B ](f : (A => (B => Expr [Unit ]) => Expr [Unit ]), stream : StagedStream [A ]): StagedStream [B ] = {
86
+ stream match {
87
87
case Linear (producer) => {
88
88
val prod = new Producer [B ] {
89
89
@@ -200,7 +200,7 @@ object Test {
200
200
}
201
201
}
202
202
203
- def takeRaw [A ](n : Expr [Int ], stream : StagedStream [A ]): StagedStream [A ] = {
203
+ private def takeRaw [A ](n : Expr [Int ], stream : StagedStream [A ]): StagedStream [A ] = {
204
204
stream match {
205
205
case Linear (producer) => {
206
206
mapRaw[(Var [Int ], A ), A ]((t : (Var [Int ], A )) => k => ' {
@@ -219,7 +219,50 @@ object Test {
219
219
}
220
220
}
221
221
222
- def take (n : Expr [Int ]): Stream [A ] = Stream (takeRaw[Expr [A ]](n, stream))
222
+ def take (n : Expr [Int ]): Stream [A ] = Stream (takeRaw[Expr [A ]](n, stream))
223
+
224
+ private def zipRaw [A , B ](stream1 : StagedStream [A ], stream2 : StagedStream [B ]): StagedStream [(A , B )] = {
225
+ (stream1, stream2) match {
226
+
227
+ case (Linear (producer1), Linear (producer2)) =>
228
+ Linear (zip_producer(producer1, producer2))
229
+
230
+ case (Linear (producer1), Nested (producer2, nestf2)) => ???
231
+
232
+ case (Nested (producer1, nestf1), Linear (producer2)) => ???
233
+
234
+ case (Nested (producer1, nestf1), Nested (producer2, nestf2)) => ???
235
+ }
236
+ }
237
+
238
+ private def zip_producer [A , B ](producer1 : Producer [A ], producer2 : Producer [B ]) = {
239
+ new Producer [(A , B )] {
240
+ type St = (producer1.St , producer2.St )
241
+
242
+ val card : Cardinality = Many
243
+
244
+ def init (k : St => Expr [Unit ]): Expr [Unit ] = {
245
+ producer1.init(s1 => ' { ~ producer2.init(s2 => ' { ~ k((s1, s2)) })})
246
+ }
247
+
248
+ def step (st : St , k : ((A , B )) => Expr [Unit ]): Expr [Unit ] = {
249
+ val (s1, s2) = st
250
+ producer1.step(s1, el1 => ' { ~ producer2.step(s2, el2 => ' { ~ k((el1, el2)) })})
251
+ }
252
+
253
+ def hasNext (st : St ): Expr [Boolean ] = {
254
+ val (s1, s2) = st
255
+ ' { ~ producer1.hasNext(s1) && ~ producer2.hasNext(s2) }
256
+ }
257
+ }
258
+ }
259
+
260
+ def zip [B : Type , C : Type ](f : (Expr [A ] => Expr [B ] => Expr [C ]), stream2 : Stream [B ]): Stream [C ] = {
261
+
262
+ val Stream (stream_b) = stream2
263
+
264
+ Stream (mapRaw[(Expr [A ], Expr [B ]), Expr [C ]]((t => k => ' { ~ k(f(t._1)(t._2)) }), zipRaw[Expr [A ], Expr [B ]](stream, stream_b)))
265
+ }
223
266
}
224
267
225
268
object Stream {
@@ -288,6 +331,11 @@ object Test {
288
331
.take(' {5 })
289
332
.fold(' {0 }, ((a : Expr [Int ], b : Expr [Int ]) => ' { ~ a + ~ b }))
290
333
334
+ def test7 () = Stream
335
+ .of(' {Array (1 , 2 , 3 )})
336
+ .zip(((a : Expr [Int ]) => (b : Expr [Int ]) => ' { ~ a + ~ b }), Stream .of(' {Array (1 , 2 , 3 )}))
337
+ .fold(' {0 }, ((a : Expr [Int ], b : Expr [Int ]) => ' { ~ a + ~ b }))
338
+
291
339
def main (args : Array [String ]): Unit = {
292
340
println(test1().run)
293
341
println
@@ -300,6 +348,8 @@ object Test {
300
348
println(test5().run)
301
349
println
302
350
println(test6().run)
351
+ println
352
+ println(test7().run)
303
353
}
304
354
}
305
355
0 commit comments