Skip to content

Commit ad06ab0

Browse files
committed
Implement take
1 parent 7bfbb06 commit ad06ab0

File tree

2 files changed

+97
-73
lines changed

2 files changed

+97
-73
lines changed

tests/run-with-compiler/staged-streams_1.check

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,8 @@
44

55
36
66

7-
2
7+
2
8+
9+
3
10+
11+
7

tests/run-with-compiler/staged-streams_1.scala

Lines changed: 92 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ object Test {
8888
val prod = new Producer[B] {
8989

9090
type St = producer.St
91-
9291
val card = producer.card
9392

9493
def init(k: St => Expr[Unit]): Expr[Unit] = {
@@ -143,78 +142,84 @@ object Test {
143142
Stream(flatMapRaw[Expr[A], Expr[A]]((a => { Linear(filterStream(a)) }), stream))
144143
}
145144

146-
// def moreTermination[A](f: Rep[Boolean] => Rep[Boolean], stream: StagedStream[A]): StagedStream[A] = {
147-
// def addToProducer[A](f: Rep[Boolean] => Rep[Boolean], producer: Producer[A]): Producer[A] = {
148-
// producer.card match {
149-
// case Many =>
150-
// new Producer[A] {
151-
// type St = producer.St
152-
153-
// val card = producer.card
154-
// def init(k: St => Rep[Unit]): Rep[Unit] =
155-
// producer.init(k)
156-
// def step(st: St, k: (A => Rep[Unit])): Rep[Unit] =
157-
// producer.step(st, el => k(el))
158-
// def hasNext(st: St): Rep[Boolean] =
159-
// f(producer.hasNext(st))
160-
// }
161-
// case AtMost1 => producer
162-
// }
163-
// }
164-
// stream match {
165-
// case Linear(producer) => Linear(addToProducer(f, producer))
166-
// case Nested(producer, nestedf) =>
167-
// Nested(addToProducer(f, producer), (a: Id[_]) => moreTermination(f, nestedf(a)))
168-
// }
169-
// }
170-
171-
// private def addCounter[A](n: Expr[Int], producer: Producer[A]): Producer[(Expr[Int], A)] =
172-
// new Producer[(Var[Int], A)] {
173-
// type St = (Var[Int], producer.St)
174-
175-
// val card = producer.card
176-
// def init(k: St => Rep[Unit]): Rep[Unit] = {
177-
// producer.init(st => {
178-
// var counter: Var[Int] = n
179-
// k(counter, st)
180-
// })
181-
// }
182-
// def step(st: St, k: (((Var[Int], A)) => Rep[Unit])): Rep[Unit] = {
183-
// val (counter, nst) = st
184-
// producer.step(nst, el => {
185-
// k((counter, el))
186-
// })
187-
// }
188-
// def hasNext(st: St): Rep[Boolean] = {
189-
// val (counter, nst) = st
190-
// producer.card match {
191-
// case Many => counter > 0 && producer.hasNext(nst)
192-
// case AtMost1 => producer.hasNext(nst)
193-
// }
194-
// }
195-
// }
196-
197-
// def takeRaw[A](n: Rep[Int], stream: StagedStream[A]): StagedStream[A] = {
198-
// stream match {
199-
// case Linear(producer) => {
200-
// mapRaw[(Var[Int], A), A]((t => k => {
201-
// t._1 = t._1 - 1
202-
// k(t._2)
203-
// }), Linear(addCounter(n, producer)))
204-
// }
205-
// case Nested(producer, nestedf) => {
206-
// Nested(addCounter(n, producer), (t: (Var[Int], Id[_])) => {
207-
// mapRaw[A, A]((el => k => {
208-
// t._1 = t._1 - 1
209-
// k(el)
210-
// }), moreTermination(b => t._1 > 0 && b, nestedf(t._2)))
211-
// })
212-
// }
213-
// }
214-
// }
215-
216-
// def take(n: Rep[Int]): Stream[A] = Stream(takeRaw(n, stream))
145+
private def moreTermination[A](f: Expr[Boolean] => Expr[Boolean], stream: StagedStream[A]): StagedStream[A] = {
146+
def addToProducer[A](f: Expr[Boolean] => Expr[Boolean], producer: Producer[A]): Producer[A] = {
147+
producer.card match {
148+
case Many =>
149+
new Producer[A] {
150+
type St = producer.St
151+
val card = producer.card
152+
153+
def init(k: St => Expr[Unit]): Expr[Unit] =
154+
producer.init(k)
155+
156+
def step(st: St, k: (A => Expr[Unit])): Expr[Unit] =
157+
producer.step(st, el => k(el))
158+
159+
def hasNext(st: St): Expr[Boolean] =
160+
f(producer.hasNext(st))
161+
}
162+
case AtMost1 => producer
163+
}
164+
}
165+
166+
stream match {
167+
case Linear(producer) => Linear(addToProducer(f, producer))
168+
case nested: Nested[a, bt] =>
169+
Nested(addToProducer(f, nested.producer), (a: bt) => moreTermination(f, nested.nestedf(a)))
170+
}
171+
}
172+
173+
private def addCounter[A](n: Expr[Int], producer: Producer[A]): Producer[(Var[Int], A)] = {
174+
new Producer[(Var[Int], A)] {
175+
type St = (Var[Int], producer.St)
176+
val card = producer.card
177+
178+
def init(k: St => Expr[Unit]): Expr[Unit] = {
179+
producer.init(st => {
180+
Var(n) { counter =>
181+
k(counter, st)
182+
}
183+
})
184+
}
185+
186+
def step(st: St, k: (((Var[Int], A)) => Expr[Unit])): Expr[Unit] = {
187+
val (counter, nst) = st
188+
producer.step(nst, el => '{
189+
~k((counter, el))
190+
})
191+
}
192+
193+
def hasNext(st: St): Expr[Boolean] = {
194+
val (counter, nst) = st
195+
producer.card match {
196+
case Many => '{ ~counter.get > 0 && ~producer.hasNext(nst) }
197+
case AtMost1 => '{ ~producer.hasNext(nst) }
198+
}
199+
}
200+
}
201+
}
202+
203+
def takeRaw[A](n: Expr[Int], stream: StagedStream[A]): StagedStream[A] = {
204+
stream match {
205+
case Linear(producer) => {
206+
mapRaw[(Var[Int], A), A]((t: (Var[Int], A)) => k => '{
207+
~t._1.update('{~t._1.get - 1})
208+
~k(t._2)
209+
}, Linear(addCounter(n, producer)))
210+
}
211+
case nested: Nested[a, bt] => {
212+
Nested(addCounter(n, nested.producer), (t: (Var[Int], bt)) => {
213+
mapRaw[A, A]((el => k => '{
214+
~t._1.update('{~t._1.get - 1})
215+
~k(el)
216+
}), moreTermination(b => '{ ~t._1.get > 0 && ~b}, nested.nestedf(t._2)))
217+
})
218+
}
219+
}
220+
}
217221

222+
def take(n: Expr[Int]): Stream[A] = Stream(takeRaw[Expr[A]](n, stream))
218223
}
219224

220225
object Stream {
@@ -272,6 +277,17 @@ object Test {
272277
.filter((d: Expr[Int]) => '{ ~d % 2 == 0 })
273278
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
274279

280+
def test5() = Stream
281+
.of('{Array(1, 2, 3)})
282+
.take('{2})
283+
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
284+
285+
def test6() = Stream
286+
.of('{Array(1, 1, 1)})
287+
.flatMap((d: Expr[Int]) => Stream.of('{Array(1, 2, 3)}).take('{2}))
288+
.take('{5})
289+
.fold('{0}, ((a: Expr[Int], b : Expr[Int]) => '{ ~a + ~b }))
290+
275291
def main(args: Array[String]): Unit = {
276292
println(test1().run)
277293
println
@@ -280,6 +296,10 @@ object Test {
280296
println(test3().run)
281297
println
282298
println(test4().run)
299+
println
300+
println(test5().run)
301+
println
302+
println(test6().run)
283303
}
284304
}
285305

0 commit comments

Comments
 (0)