Skip to content

Commit 9748f0d

Browse files
committed
Implement zip (linear/nested)
1 parent 0328028 commit 9748f0d

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

tests/run-with-compiler-custom-args/staged-streams_1.check

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@
1010

1111
7
1212

13-
12
13+
12
14+
15+
15

tests/run-with-compiler-custom-args/staged-streams_1.scala

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,18 +227,53 @@ object Test {
227227
case (Linear(producer1), Linear(producer2)) =>
228228
Linear(zip_producer(producer1, producer2))
229229

230-
case (Linear(producer1), Nested(producer2, nestf2)) => ???
230+
case (Linear(producer1), Nested(producer2, nestf2)) =>
231+
pushLinear(producer1, producer2, nestf2)
231232

232233
case (Nested(producer1, nestf1), Linear(producer2)) => ???
233234

234235
case (Nested(producer1, nestf1), Nested(producer2, nestf2)) => ???
235236
}
236237
}
237238

239+
private def pushLinear[A, B, C](producer: Producer[A], nestedProducer: Producer[B], nestedf: (B => StagedStream[C])): StagedStream[(A, C)] = {
240+
val newProducer = new Producer[(Var[Boolean], producer.St, B)] {
241+
242+
type St = (Var[Boolean], producer.St, nestedProducer.St)
243+
val card: Cardinality = Many
244+
245+
def init(k: St => Expr[Unit]): Expr[Unit] = {
246+
producer.init(s1 => '{ ~nestedProducer.init(s2 =>
247+
Var('{ ~producer.hasNext(s1) }) { term1r =>
248+
k((term1r, s1, s2))
249+
})})
250+
}
251+
252+
def step(st: St, k: ((Var[Boolean], producer.St, B)) => Expr[Unit]): Expr[Unit] = {
253+
val (flag, s1, s2) = st
254+
nestedProducer.step(s2, b => '{ ~k((flag, s1, b)) })
255+
}
256+
257+
def hasNext(st: St): Expr[Boolean] = {
258+
val (flag, s1, s2) = st
259+
'{ ~flag.get && ~nestedProducer.hasNext(s2) }
260+
}
261+
}
262+
263+
Nested(newProducer, (t: (Var[Boolean], producer.St, B)) => {
264+
val (flag, s1, b) = t
265+
266+
mapRaw[C, (A, C)]((c => k => '{
267+
~producer.step(s1, a => '{ ~k((a, c)) })
268+
~flag.update(producer.hasNext(s1))
269+
}), moreTermination((b_flag: Expr[Boolean]) => '{ ~flag.get && ~b_flag }, nestedf(b)))
270+
})
271+
}
272+
238273
private def zip_producer[A, B](producer1: Producer[A], producer2: Producer[B]) = {
239274
new Producer[(A, B)] {
240-
type St = (producer1.St, producer2.St)
241275

276+
type St = (producer1.St, producer2.St)
242277
val card: Cardinality = Many
243278

244279
def init(k: St => Expr[Unit]): Expr[Unit] = {
@@ -336,6 +371,11 @@ object Test {
336371
.zip(((a : Expr[Int]) => (b : Expr[Int]) => '{ ~a + ~b }), Stream.of('{Array(1, 2, 3)}))
337372
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
338373

374+
def test8() = Stream
375+
.of('{Array(1, 2, 3)})
376+
.zip(((a : Expr[Int]) => (b : Expr[Int]) => '{ ~a + ~b }), Stream.of('{Array(1, 2, 3)}).flatMap((d: Expr[Int]) => Stream.of('{Array(1, 2, 3)}).map((dp: Expr[Int]) => '{ ~d + ~dp })))
377+
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
378+
339379
def main(args: Array[String]): Unit = {
340380
println(test1().run)
341381
println
@@ -350,6 +390,8 @@ object Test {
350390
println(test6().run)
351391
println
352392
println(test7().run)
393+
println
394+
println(test8().run)
353395
}
354396
}
355397

0 commit comments

Comments
 (0)