Skip to content

Commit 0328028

Browse files
committed
Implement zip (for linear streams)
1 parent a2e4c49 commit 0328028

File tree

2 files changed

+57
-5
lines changed

2 files changed

+57
-5
lines changed

tests/run-with-compiler-custom-args/staged-streams_1.check

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@
88

99
3
1010

11-
7
11+
7
12+
13+
12

tests/run-with-compiler-custom-args/staged-streams_1.scala

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ object Test {
8282
Stream(mapRaw[Expr[A], Expr[B]](a => k => '{ ~k(f(a)) }, stream))
8383
}
8484

85-
private def mapRaw[A, B](f: (A => (B => Expr[Unit]) => Expr[Unit]), s: StagedStream[A]): StagedStream[B] = {
86-
s match {
85+
private def mapRaw[A, B](f: (A => (B => Expr[Unit]) => Expr[Unit]), stream: StagedStream[A]): StagedStream[B] = {
86+
stream match {
8787
case Linear(producer) => {
8888
val prod = new Producer[B] {
8989

@@ -200,7 +200,7 @@ object Test {
200200
}
201201
}
202202

203-
def takeRaw[A](n: Expr[Int], stream: StagedStream[A]): StagedStream[A] = {
203+
private def takeRaw[A](n: Expr[Int], stream: StagedStream[A]): StagedStream[A] = {
204204
stream match {
205205
case Linear(producer) => {
206206
mapRaw[(Var[Int], A), A]((t: (Var[Int], A)) => k => '{
@@ -219,7 +219,50 @@ object Test {
219219
}
220220
}
221221

222-
def take(n: Expr[Int]): Stream[A] = Stream(takeRaw[Expr[A]](n, stream))
222+
def take(n: Expr[Int]): Stream[A] = Stream(takeRaw[Expr[A]](n, stream))
223+
224+
private def zipRaw[A, B](stream1: StagedStream[A], stream2: StagedStream[B]): StagedStream[(A, B)] = {
225+
(stream1, stream2) match {
226+
227+
case (Linear(producer1), Linear(producer2)) =>
228+
Linear(zip_producer(producer1, producer2))
229+
230+
case (Linear(producer1), Nested(producer2, nestf2)) => ???
231+
232+
case (Nested(producer1, nestf1), Linear(producer2)) => ???
233+
234+
case (Nested(producer1, nestf1), Nested(producer2, nestf2)) => ???
235+
}
236+
}
237+
238+
private def zip_producer[A, B](producer1: Producer[A], producer2: Producer[B]) = {
239+
new Producer[(A, B)] {
240+
type St = (producer1.St, producer2.St)
241+
242+
val card: Cardinality = Many
243+
244+
def init(k: St => Expr[Unit]): Expr[Unit] = {
245+
producer1.init(s1 => '{ ~producer2.init(s2 => '{ ~k((s1, s2)) })})
246+
}
247+
248+
def step(st: St, k: ((A, B)) => Expr[Unit]): Expr[Unit] = {
249+
val (s1, s2) = st
250+
producer1.step(s1, el1 => '{ ~producer2.step(s2, el2 => '{ ~k((el1, el2)) })})
251+
}
252+
253+
def hasNext(st: St): Expr[Boolean] = {
254+
val (s1, s2) = st
255+
'{ ~producer1.hasNext(s1) && ~producer2.hasNext(s2) }
256+
}
257+
}
258+
}
259+
260+
def zip[B : Type, C : Type](f: (Expr[A] => Expr[B] => Expr[C]), stream2: Stream[B]): Stream[C] = {
261+
262+
val Stream(stream_b) = stream2
263+
264+
Stream(mapRaw[(Expr[A], Expr[B]), Expr[C]]((t => k => '{ ~k(f(t._1)(t._2)) }), zipRaw[Expr[A], Expr[B]](stream, stream_b)))
265+
}
223266
}
224267

225268
object Stream {
@@ -288,6 +331,11 @@ object Test {
288331
.take('{5})
289332
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
290333

334+
def test7() = Stream
335+
.of('{Array(1, 2, 3)})
336+
.zip(((a : Expr[Int]) => (b : Expr[Int]) => '{ ~a + ~b }), Stream.of('{Array(1, 2, 3)}))
337+
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
338+
291339
def main(args: Array[String]): Unit = {
292340
println(test1().run)
293341
println
@@ -300,6 +348,8 @@ object Test {
300348
println(test5().run)
301349
println
302350
println(test6().run)
351+
println
352+
println(test7().run)
303353
}
304354
}
305355

0 commit comments

Comments
 (0)