Skip to content

Commit 3231eb7

Browse files
committed
Refactor and add comments
1 parent 0d47443 commit 3231eb7

File tree

1 file changed

+106
-24
lines changed

1 file changed

+106
-24
lines changed

tests/run-with-compiler-custom-args/staged-streams_1.scala

Lines changed: 106 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,17 @@ object Test {
4444

4545
def fold[W: Type](z: Expr[W], f: ((Expr[W], Expr[A]) => Expr[W])): Expr[W] = {
4646
Var(z) { s: Var[W] => '{
47-
4847
~fold_raw[Expr[A]]((a: Expr[A]) => '{
49-
~s.update(f(s.get, a))
48+
~s.update(f(s.get, a))
5049
}, stream)
5150

5251
~s.get
5352
}
5453
}
5554
}
5655

57-
private def fold_raw[A](consumer: A => Expr[Unit], s: StagedStream[A]): Expr[Unit] = {
58-
s match {
56+
private def fold_raw[A](consumer: A => Expr[Unit], stream: StagedStream[A]): Expr[Unit] = {
57+
stream match {
5958
case Linear(producer) => {
6059
producer.card match {
6160
case Many =>
@@ -72,16 +71,36 @@ object Test {
7271
})
7372
}
7473
}
75-
case nested: Nested[a, bt] => {
76-
fold_raw[bt](((e: bt) => fold_raw[a](consumer, nested.nestedf(e))), Linear(nested.producer))
74+
case nested: Nested[A, bt] => {
75+
fold_raw[bt](((e: bt) => fold_raw[A](consumer, nested.nestedf(e))), Linear(nested.producer))
7776
}
7877
}
7978
}
8079

80+
/** Builds a new stream by applying a function to all elements of this stream.
81+
*
82+
* @param f the function to apply to each quoted element.
83+
* @tparam B the element type of the returned stream
84+
* @return a new stream resulting from applying `mapRaw` and threading the element of the first stream downstream.
85+
*/
8186
def map[B : Type](f: (Expr[A] => Expr[B])): Stream[B] = {
8287
Stream(mapRaw[Expr[A], Expr[B]](a => k => '{ ~k(f(a)) }, stream))
8388
}
8489

90+
/** Handles generically the mapping of elements from one producer to another.
91+
* `mapRaw` can be potentially used threading quoted values from one stream to another. However
92+
* is can be also used by handling any kind of quoted value.
93+
*
94+
* e.g., `mapRaw[(Var[Int], A), A]` transforms a stream that declares a variable and holds a value in each
95+
* iteration step to a stream that is not aware of the aforementioned variable.
96+
*
97+
* @param f the function to apply at each step. f is of type `(A => (B => Expr[Unit])` where A is the type of
98+
* the incoming stream. When applied to an element, `f` returns the continuation for elements of `B`
99+
* @param stream that contains the stream we want to map.
100+
* @tparam A the type of the input stream
101+
* @tparam B the element type of the resulting stream
102+
* @return a new stream resulting from applying `f` in the `step` function of the input stream's producer.
103+
*/
85104
private def mapRaw[A, B](f: (A => (B => Expr[Unit]) => Expr[Unit]), stream: StagedStream[A]): StagedStream[B] = {
86105
stream match {
87106
case Linear(producer) => {
@@ -105,16 +124,30 @@ object Test {
105124

106125
Linear(prod)
107126
}
108-
case nested: Nested[a, bt] => {
127+
case nested: Nested[A, bt] => {
109128
Nested(nested.producer, (a: bt) => mapRaw[A, B](f, nested.nestedf(a)))
110129
}
111130
}
112131
}
113132

133+
/** Flatmap */
114134
def flatMap[B : Type](f: (Expr[A] => Stream[B])): Stream[B] = {
115135
Stream(flatMapRaw[Expr[A], Expr[B]]((a => { val Stream (nested) = f(a); nested }), stream))
116136
}
117137

138+
/** Returns a new stream that applies a function `f` to each element of the input stream.
139+
* If the input stream is simply linear then its packed with the function `f`.
140+
* If the input stream is nested then a new one is created by using its producer and then passing the `f`
141+
* recursively to build the `nestedf` of the returned stream.
142+
*
143+
* Note: always returns a nested stream.
144+
*
145+
* @param f the function of `flatMap``
146+
* @param stream the input stream
147+
* @tparam A the type of the input stream
148+
* @tparam B the element type of the resulting stream
149+
* @return a new stream resulting from registering `f`
150+
*/
118151
private def flatMapRaw[A, B](f: (A => StagedStream[B]), stream: StagedStream[A]): StagedStream[B] = {
119152
stream match {
120153
case Linear(producer) => Nested(producer, f)
@@ -123,9 +156,19 @@ object Test {
123156
}
124157
}
125158

126-
def filter(f: (Expr[A] => Expr[Boolean])): Stream[A] = {
159+
/** Selects all elements of this stream which satisfy a predicate.
160+
*
161+
* Note: this is merely a special case of `flatMap` as the resulting stream in each step may return 0 or 1
162+
* element.
163+
*
164+
* @param f the predicate used to test elements.
165+
* @return a new stream consisting of all elements of the input stream that do satisfy the given
166+
* predicate `pred`.
167+
*/
168+
def filter(pred: (Expr[A] => Expr[Boolean])): Stream[A] = {
127169
val filterStream = (a: Expr[A]) =>
128170
new Producer[Expr[A]] {
171+
129172
type St = Expr[A]
130173
val card = AtMost1
131174

@@ -136,13 +179,22 @@ object Test {
136179
k(st)
137180

138181
def hasNext(st: St): Expr[Boolean] =
139-
f(st)
182+
pred(st)
140183
}
141184

142185
Stream(flatMapRaw[Expr[A], Expr[A]]((a => { Linear(filterStream(a)) }), stream))
143186
}
144187

145-
private def moreTermination[A](f: Expr[Boolean] => Expr[Boolean], stream: StagedStream[A]): StagedStream[A] = {
188+
/** Adds a new termination condition to a producer of cardinality `Many`.
189+
*
190+
* @param condition the termination condition as a function accepting the existing condition (the result
191+
* of the `hasNext` from the passed `stream`'s producer.
192+
* @param stream that contains the producer we want to enhance.
193+
* @tparam A the type of the stream's elements.
194+
* @return the stream with the new producer. If the passed stream was linear, the new termination is added
195+
* otherwise the new termination is propagated to all nested ones, recursively.
196+
*/
197+
private def addTerminationCondition[A](condition: Expr[Boolean] => Expr[Boolean], stream: StagedStream[A]): StagedStream[A] = {
146198
def addToProducer[A](f: Expr[Boolean] => Expr[Boolean], producer: Producer[A]): Producer[A] = {
147199
producer.card match {
148200
case Many =>
@@ -164,12 +216,20 @@ object Test {
164216
}
165217

166218
stream match {
167-
case Linear(producer) => Linear(addToProducer(f, producer))
219+
case Linear(producer) => Linear(addToProducer(condition, producer))
168220
case nested: Nested[a, bt] =>
169-
Nested(addToProducer(f, nested.producer), (a: bt) => moreTermination(f, nested.nestedf(a)))
221+
Nested(addToProducer(condition, nested.producer), (a: bt) => addTerminationCondition(condition, nested.nestedf(a)))
170222
}
171223
}
172224

225+
/** Adds a new counter variable by enhancing a producer's state with a variable of type `Int`.
226+
* The counter is initialized in `init`, propageted in `step` and checked in the `hasNext` of the *current* stream.
227+
*
228+
* @param n is the initial value of the counter
229+
* @param producer the producer that we want to enhance
230+
* @tparam A the type of the producer's elements.
231+
* @return the enhanced producer
232+
*/
173233
private def addCounter[A](n: Expr[Int], producer: Producer[A]): Producer[(Var[Int], A)] = {
174234
new Producer[(Var[Int], A)] {
175235
type St = (Var[Int], producer.St)
@@ -184,41 +244,58 @@ object Test {
184244
}
185245

186246
def step(st: St, k: (((Var[Int], A)) => Expr[Unit])): Expr[Unit] = {
187-
val (counter, nst) = st
188-
producer.step(nst, el => '{
247+
val (counter, currentState) = st
248+
producer.step(currentState, el => '{
189249
~k((counter, el))
190250
})
191251
}
192252

193253
def hasNext(st: St): Expr[Boolean] = {
194-
val (counter, nst) = st
254+
val (counter, currentState) = st
195255
producer.card match {
196-
case Many => '{ ~counter.get > 0 && ~producer.hasNext(nst) }
197-
case AtMost1 => '{ ~producer.hasNext(nst) }
256+
case Many => '{ ~counter.get > 0 && ~producer.hasNext(currentState) }
257+
case AtMost1 => '{ ~producer.hasNext(currentState) }
198258
}
199259
}
200260
}
201261
}
202262

263+
/** The nested stream receives the same variable reference; thus all streams decrement the same global count.
264+
*
265+
* @param n code of the variable to be threaded to the downstream.
266+
* @param stream the upstream to enhance.
267+
* @tparam A the type of the producer's elements.
268+
* @return a linear or nested stream aware of the variable reference to decrement.
269+
*/
203270
private def takeRaw[A](n: Expr[Int], stream: StagedStream[A]): StagedStream[A] = {
204271
stream match {
205-
case Linear(producer) => {
272+
case linear: Linear[A] => {
273+
val enhancedProducer: Producer[(Var[Int], A)] = addCounter[A](n, linear.producer)
274+
val enhancedStream: Linear[(Var[Int], A)] = Linear(enhancedProducer)
275+
276+
// Map an enhanced stream to a stream that produces the elements. Before
277+
// invoking the continuation for the element, "use" the variable accordingly.
206278
mapRaw[(Var[Int], A), A]((t: (Var[Int], A)) => k => '{
207279
~t._1.update('{~t._1.get - 1})
208280
~k(t._2)
209-
}, Linear(addCounter(n, producer)))
281+
}, enhancedStream)
210282
}
211-
case nested: Nested[a, bt] => {
212-
Nested(addCounter(n, nested.producer), (t: (Var[Int], bt)) => {
283+
case nested: Nested[A, bt] => {
284+
val enhancedProducer: Producer[(Var[Int], bt)] = addCounter[bt](n, nested.producer)
285+
286+
Nested(enhancedProducer, (t: (Var[Int], bt)) => {
287+
// Before invoking the continuation for the element, "use" the variable accordingly.
288+
// In contrast to the linear case, the variable is initialized in the originating stream.
213289
mapRaw[A, A]((el => k => '{
214290
~t._1.update('{~t._1.get - 1})
215291
~k(el)
216-
}), moreTermination(b => '{ ~t._1.get > 0 && ~b}, nested.nestedf(t._2)))
292+
}), addTerminationCondition(b => '{ ~t._1.get > 0 && ~b}, nested.nestedf(t._2)))
217293
})
218294
}
219295
}
220296
}
221297

298+
/** A stream containing the first `n` elements of this stream. */
222299
def take(n: Expr[Int]): Stream[A] = Stream(takeRaw[Expr[A]](n, stream))
223300

224301
private def zipRaw[A, B](stream1: StagedStream[A], stream2: StagedStream[B]): StagedStream[(A, B)] = {
@@ -233,10 +310,15 @@ object Test {
233310
case (Nested(producer1, nestf1), Linear(producer2)) =>
234311
mapRaw[(B, A), (A, B)]((t => k => '{ ~k((t._2, t._1)) }), pushLinear[B, _, A](producer2, producer1, nestf1))
235312

236-
case (Nested(producer1, nestf1), Nested(producer2, nestf2)) => ???
313+
case (Nested(producer1, nestf1), Nested(producer2, nestf2)) =>
314+
zipRaw(makeLinear(stream1), stream2)
237315
}
238316
}
239317

318+
private def makeLinear[A](stream: StagedStream[A]): StagedStream[A] = {
319+
???
320+
}
321+
240322
private def pushLinear[A, B, C](producer: Producer[A], nestedProducer: Producer[B], nestedf: (B => StagedStream[C])): StagedStream[(A, C)] = {
241323
val newProducer = new Producer[(Var[Boolean], producer.St, B)] {
242324

@@ -267,7 +349,7 @@ object Test {
267349
mapRaw[C, (A, C)]((c => k => '{
268350
~producer.step(s1, a => '{ ~k((a, c)) })
269351
~flag.update(producer.hasNext(s1))
270-
}), moreTermination((b_flag: Expr[Boolean]) => '{ ~flag.get && ~b_flag }, nestedf(b)))
352+
}), addTerminationCondition((b_flag: Expr[Boolean]) => '{ ~flag.get && ~b_flag }, nestedf(b)))
271353
})
272354
}
273355

0 commit comments

Comments
 (0)