Skip to content

Commit a524b16

Browse files
committed
Avoid some boxing of state ids during transform
1 parent 1fe789a commit a524b16

File tree

3 files changed

+118
-49
lines changed

3 files changed

+118
-49
lines changed

src/main/scala/scala/async/internal/ExprBuilder.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
*/
44
package scala.async.internal
55

6+
import java.util.function.IntUnaryOperator
7+
68
import scala.collection.mutable
79
import scala.collection.mutable.ListBuffer
810
import language.existentials
@@ -23,7 +25,7 @@ trait ExprBuilder {
2325
trait AsyncState {
2426
def state: Int
2527

26-
def nextStates: List[Int]
28+
def nextStates: Array[Int]
2729

2830
def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef
2931

@@ -55,8 +57,8 @@ trait ExprBuilder {
5557
final class SimpleAsyncState(var stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup)
5658
extends AsyncState {
5759

58-
def nextStates: List[Int] =
59-
List(nextState)
60+
val nextStates: Array[Int] =
61+
Array(nextState)
6062

6163
def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = {
6264
mkHandlerCase(state, treesThenStats(mkStateTree(nextState, symLookup) :: Nil))
@@ -69,7 +71,7 @@ trait ExprBuilder {
6971
/** A sequence of statements with a conditional transition to the next state, which will represent
7072
* a branch of an `if` or a `match`.
7173
*/
72-
final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState {
74+
final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: Array[Int]) extends AsyncState {
7375
override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef =
7476
mkHandlerCase(state, stats)
7577

@@ -84,8 +86,8 @@ trait ExprBuilder {
8486
val awaitable: Awaitable, symLookup: SymLookup)
8587
extends AsyncState {
8688

87-
def nextStates: List[Int] =
88-
List(nextState)
89+
val nextStates: Array[Int] =
90+
Array(nextState)
8991

9092
override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = {
9193
val fun = This(tpnme.EMPTY)
@@ -191,7 +193,7 @@ trait ExprBuilder {
191193
def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
192194
def mkBranch(state: Int) = mkStateTree(state, symLookup)
193195
this += If(condTree, mkBranch(thenState), mkBranch(elseState))
194-
new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState))
196+
new AsyncStateWithoutAwait(stats.toList, state, Array(thenState, elseState))
195197
}
196198

197199
/**
@@ -204,7 +206,7 @@ trait ExprBuilder {
204206
* @param caseStates starting state of the right-hand side of the each case
205207
* @return an `AsyncState` representing the match expression
206208
*/
207-
def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = {
209+
def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: Array[Int], symLookup: SymLookup): AsyncState = {
208210
// 1. build list of changed cases
209211
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
210212
case CaseDef(pat, guard, rhs) =>
@@ -218,7 +220,7 @@ trait ExprBuilder {
218220

219221
def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = {
220222
this += mkStateTree(startLabelState, symLookup)
221-
new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState))
223+
new AsyncStateWithoutAwait(stats.toList, state, Array(startLabelState))
222224
}
223225

224226
override def toString: String = {
@@ -299,7 +301,10 @@ trait ExprBuilder {
299301
case Match(scrutinee, cases) if containsAwait(stat) =>
300302
checkForUnsupportedAwait(scrutinee)
301303

302-
val caseStates = cases.map(_ => nextState())
304+
val caseStates = new Array[Int](cases.length)
305+
java.util.Arrays.setAll(caseStates, new IntUnaryOperator {
306+
override def applyAsInt(operand: Int): Int = nextState()
307+
})
303308
val afterMatchState = nextState()
304309

305310
asyncStates +=

src/main/scala/scala/async/internal/LiveVariables.scala

Lines changed: 76 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
package scala.async.internal
22

3+
import java.util
4+
import java.util.function.{IntConsumer, IntPredicate}
5+
6+
import scala.collection.immutable.IntMap
7+
38
trait LiveVariables {
49
self: AsyncMacro =>
510
import c.universe._
@@ -17,19 +22,22 @@ trait LiveVariables {
1722
def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = {
1823
// live variables analysis:
1924
// the result map indicates in which states a given field should be nulled out
20-
val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables)
25+
val liveVarsMap: Map[Tree, StateSet] = liveVars(asyncStates, liftables)
2126

2227
var assignsOf = Map[Int, List[Tree]]()
2328

24-
for ((fld, where) <- liveVarsMap; state <- where)
25-
assignsOf get state match {
26-
case None =>
27-
assignsOf += (state -> List(fld))
28-
case Some(trees) if !trees.exists(_.symbol == fld.symbol) =>
29-
assignsOf += (state -> (fld +: trees))
30-
case _ =>
31-
/* do nothing */
32-
}
29+
for ((fld, where) <- liveVarsMap) {
30+
where.foreach { new IntConsumer { def accept(state: Int): Unit = {
31+
assignsOf get state match {
32+
case None =>
33+
assignsOf += (state -> List(fld))
34+
case Some(trees) if !trees.exists(_.symbol == fld.symbol) =>
35+
assignsOf += (state -> (fld +: trees))
36+
case _ =>
37+
// do nothing
38+
}
39+
}}}
40+
}
3341

3442
assignsOf
3543
}
@@ -46,9 +54,9 @@ trait LiveVariables {
4654
* @param liftables the lifted fields
4755
* @return a map which indicates for a given field (the key) the states in which it should be nulled out
4856
*/
49-
def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = {
57+
def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, StateSet] = {
5058
val liftedSyms: Set[Symbol] = // include only vars
51-
liftables.filter {
59+
liftables.iterator.filter {
5260
case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE)
5361
case _ => false
5462
}.map(_.symbol).toSet
@@ -122,20 +130,30 @@ trait LiveVariables {
122130
* A state `i` is contained in the list that is the value to which
123131
* key `j` maps iff control can flow from state `j` to state `i`.
124132
*/
125-
val cfg: Map[Int, List[Int]] = asyncStates.map(as => as.state -> as.nextStates).toMap
133+
val cfg: Map[Int, Array[Int]] = {
134+
var res = IntMap.empty[Array[Int]]
135+
136+
for (as <- asyncStates) res = res.updated(as.state, as.nextStates)
137+
res
138+
}
126139

127140
/** Tests if `state1` is a predecessor of `state2`.
128141
*/
129142
def isPred(state1: Int, state2: Int): Boolean = {
130-
val seen = scala.collection.mutable.HashSet[Int]()
143+
val seen = new StateSet()
131144

132145
def isPred0(state1: Int, state2: Int): Boolean =
133146
if(state1 == state2) false
134-
else if (seen(state1)) false // breaks cycles in the CFG
147+
else if (seen.contains(state1)) false // breaks cycles in the CFG
135148
else cfg get state1 match {
136149
case Some(nextStates) =>
137150
seen += state1
138-
nextStates.contains(state2) || nextStates.exists(isPred0(_, state2))
151+
var i = 0
152+
while (i < nextStates.length) {
153+
if (nextStates(i) == state2 || isPred0(nextStates(i), state2)) return true
154+
i += 1
155+
}
156+
false
139157
case None =>
140158
false
141159
}
@@ -164,8 +182,8 @@ trait LiveVariables {
164182
* 7. repeat if something has changed
165183
*/
166184

167-
var LVentry = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]()
168-
var LVexit = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]()
185+
var LVentry = IntMap[Set[Symbol]]() withDefaultValue Set[Symbol]()
186+
var LVexit = IntMap[Set[Symbol]]() withDefaultValue Set[Symbol]()
169187

170188
// All fields are declared to be dead at the exit of the final async state, except for the ones
171189
// that cannot be nulled out at all (those in noNull), because they have been captured by a nested def.
@@ -174,6 +192,14 @@ trait LiveVariables {
174192
var currStates = List(finalState) // start at final state
175193
var captured: Set[Symbol] = Set()
176194

195+
def contains(as: Array[Int], a: Int): Boolean = {
196+
var i = 0
197+
while (i < as.length) {
198+
if (as(i) == a) return true
199+
i += 1
200+
}
201+
false
202+
}
177203
while (!currStates.isEmpty) {
178204
var entryChanged: List[AsyncState] = Nil
179205

@@ -183,19 +209,19 @@ trait LiveVariables {
183209
captured ++= referenced.captured
184210
val LVentryNew = LVexit(cs.state) ++ referenced.used
185211
if (!LVentryNew.sameElements(LVentryOld)) {
186-
LVentry = LVentry + (cs.state -> LVentryNew)
212+
LVentry = LVentry.updated(cs.state, LVentryNew)
187213
entryChanged ::= cs
188214
}
189215
}
190216

191-
val pred = entryChanged.flatMap(cs => asyncStates.filter(_.nextStates.contains(cs.state)))
217+
val pred = entryChanged.flatMap(cs => asyncStates.filter(state => contains(state.nextStates, cs.state)))
192218
var exitChanged: List[AsyncState] = Nil
193219

194220
for (p <- pred) {
195221
val LVexitOld = LVexit(p.state)
196222
val LVexitNew = p.nextStates.flatMap(succ => LVentry(succ)).toSet
197223
if (!LVexitNew.sameElements(LVexitOld)) {
198-
LVexit = LVexit + (p.state -> LVexitNew)
224+
LVexit = LVexit.updated(p.state, LVexitNew)
199225
exitChanged ::= p
200226
}
201227
}
@@ -210,53 +236,64 @@ trait LiveVariables {
210236
}
211237
}
212238

213-
def lastUsagesOf(field: Tree, at: AsyncState): Set[Int] = {
239+
def lastUsagesOf(field: Tree, at: AsyncState): StateSet = {
214240
val avoid = scala.collection.mutable.HashSet[AsyncState]()
215241

216-
def lastUsagesOf0(field: Tree, at: AsyncState): Set[Int] = {
217-
if (avoid(at)) Set()
242+
val result = new StateSet
243+
def lastUsagesOf0(field: Tree, at: AsyncState): Unit = {
244+
if (avoid(at)) ()
218245
else if (captured(field.symbol)) {
219-
Set()
246+
()
220247
}
221248
else LVentry get at.state match {
222249
case Some(fields) if fields.contains(field.symbol) =>
223-
Set(at.state)
250+
result += at.state
224251
case _ =>
225252
avoid += at
226-
val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet
227-
preds.flatMap(p => lastUsagesOf0(field, p))
253+
for (state <- asyncStates) {
254+
if (contains(state.nextStates, at.state)) {
255+
lastUsagesOf0(field, state)
256+
}
257+
}
228258
}
229259
}
230260

231261
lastUsagesOf0(field, at)
262+
result
232263
}
233264

234-
val lastUsages: Map[Tree, Set[Int]] =
235-
liftables.map(fld => fld -> lastUsagesOf(fld, finalState)).toMap
265+
val lastUsages: Map[Tree, StateSet] =
266+
liftables.iterator.map(fld => fld -> lastUsagesOf(fld, finalState)).toMap
236267

237268
if(AsyncUtils.verbose) {
238269
for ((fld, lastStates) <- lastUsages)
239-
AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.mkString(", ")}")
270+
AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.iterator.mkString(", ")}")
240271
}
241272

242-
val nullOutAt: Map[Tree, Set[Int]] =
273+
val nullOutAt: Map[Tree, StateSet] =
243274
for ((fld, lastStates) <- lastUsages) yield {
244-
val killAt = lastStates.flatMap { s =>
245-
if (s == finalState.state) Set()
246-
else {
275+
var result = new StateSet
276+
lastStates.foreach(new IntConsumer { def accept(s: Int): Unit = {
277+
if (s != finalState.state) {
247278
val lastAsyncState = asyncStates.find(_.state == s).get
248279
val succNums = lastAsyncState.nextStates
249280
// all successor states that are not indirect predecessors
250281
// filter out successor states where the field is live at the entry
251-
succNums.filter(num => !isPred(num, s)).filterNot(num => LVentry(num).contains(fld.symbol))
282+
var i = 0
283+
while (i < succNums.length) {
284+
val num = succNums(i)
285+
if (!isPred(num, s) && !LVentry(num).contains(fld.symbol))
286+
result += num
287+
i += 1
288+
}
252289
}
253-
}
254-
(fld, killAt)
290+
}})
291+
(fld, result)
255292
}
256293

257294
if(AsyncUtils.verbose) {
258295
for ((fld, killAt) <- nullOutAt)
259-
AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.mkString(", ")}")
296+
AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.iterator.mkString(", ")}")
260297
}
261298

262299
nullOutAt
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright (C) 2018 Lightbend Inc. <http://www.lightbend.com>
3+
*/
4+
package scala.async.internal
5+
6+
import java.util
7+
import java.util.function.{Consumer, IntConsumer}
8+
9+
import scala.collection.JavaConverters.{asScalaIteratorConverter, iterableAsScalaIterableConverter}
10+
import scala.collection.mutable
11+
12+
// Set for StateIds, which are either small positive integers or -symbolID.
13+
final class StateSet {
14+
private var bitSet = new java.util.BitSet()
15+
private var caseSet = new util.HashSet[Integer]()
16+
def +=(stateId: Int): Unit = if (stateId > 0) bitSet.set(stateId) else caseSet.add(stateId)
17+
def contains(stateId: Int): Boolean = if (stateId > 0 && stateId < 1024) bitSet.get(stateId) else caseSet.contains(stateId)
18+
def iterator: Iterator[Integer] = {
19+
bitSet.stream().iterator().asScala ++ caseSet.asScala.iterator
20+
}
21+
def foreach(f: IntConsumer): Unit = {
22+
bitSet.stream().forEach(f)
23+
caseSet.stream().forEach(new Consumer[Integer] {
24+
override def accept(value: Integer): Unit = f.accept(value)
25+
})
26+
}
27+
}

0 commit comments

Comments
 (0)