Skip to content

Commit 1fe789a

Browse files
authored
Merge pull request #176 from retronym/topic/extensions
Improve generated code and flexibility
2 parents c09aa77 + 2e38116 commit 1fe789a

File tree

3 files changed

+116
-26
lines changed

3 files changed

+116
-26
lines changed

src/main/scala/scala/async/internal/ExprBuilder.scala

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
*/
44
package scala.async.internal
55

6+
import scala.collection.mutable
67
import scala.collection.mutable.ListBuffer
78
import language.existentials
89

@@ -117,16 +118,22 @@ trait ExprBuilder {
117118
* <mkResumeApply>
118119
* }
119120
*/
120-
def ifIsFailureTree[T: WeakTypeTag](tryReference: => Tree) =
121-
If(futureSystemOps.tryyIsFailure(c.Expr[futureSystem.Tryy[T]](tryReference)).tree,
122-
Block(toList(futureSystemOps.completeProm[T](
123-
c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
124-
c.Expr[futureSystem.Tryy[T]](
125-
TypeApply(Select(tryReference, newTermName("asInstanceOf")),
126-
List(TypeTree(futureSystemOps.tryType[T]))))).tree),
127-
Return(literalUnit)),
128-
Block(List(tryGetTree(tryReference)), mkStateTree(nextState, symLookup))
129-
)
121+
def ifIsFailureTree[T: WeakTypeTag](tryReference: => Tree) = {
122+
val getAndUpdateState = Block(List(tryGetTree(tryReference)), mkStateTree(nextState, symLookup))
123+
if (asyncBase.futureSystem.emitTryCatch) {
124+
If(futureSystemOps.tryyIsFailure(c.Expr[futureSystem.Tryy[T]](tryReference)).tree,
125+
Block(toList(futureSystemOps.completeProm[T](
126+
c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
127+
c.Expr[futureSystem.Tryy[T]](
128+
TypeApply(Select(tryReference, newTermName("asInstanceOf")),
129+
List(TypeTree(futureSystemOps.tryType[T]))))).tree),
130+
Return(literalUnit)),
131+
getAndUpdateState
132+
)
133+
} else {
134+
getAndUpdateState
135+
}
136+
}
130137

131138
override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = {
132139
Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree[T](Ident(symLookup.applyTrParam)))))
@@ -401,9 +408,10 @@ trait ExprBuilder {
401408
val stateMemberSymbol = symLookup.stateMachineMember(name.state)
402409
val stateMemberRef = symLookup.memberRef(name.state)
403410
val body = Match(stateMemberRef, mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ++ List(CaseDef(Ident(nme.WILDCARD), EmptyTree, Throw(Apply(Select(New(Ident(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), List())))))
411+
val body1 = eliminateDeadStates(body)
404412

405-
Try(
406-
body,
413+
maybeTry(
414+
body1,
407415
List(
408416
CaseDef(
409417
Bind(name.t, Typed(Ident(nme.WILDCARD), Ident(defn.ThrowableClass))),
@@ -417,8 +425,67 @@ trait ExprBuilder {
417425
If(Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), then, Throw(Ident(name.t)))
418426
then
419427
})), EmptyTree)
428+
}
420429

421-
//body
430+
// Identify dead states: `case <id> => { state = nextId; (); (); ... }, eliminated, and compact state ids to
431+
// enable emission of a tableswitch.
432+
private def eliminateDeadStates(m: Match): Tree = {
433+
object DeadState {
434+
private val liveStates = mutable.AnyRefMap[Integer, Integer]()
435+
private val deadStates = mutable.AnyRefMap[Integer, Integer]()
436+
private var compactedStateId = 1
437+
for (CaseDef(Literal(Constant(stateId: Integer)), EmptyTree, body) <- m.cases) {
438+
body match {
439+
case _ if (stateId == 0) => liveStates(stateId) = stateId
440+
case Block(Assign(_, Literal(Constant(nextState: Integer))) :: rest, expr) if (expr :: rest).forall(t => isLiteralUnit(t)) =>
441+
deadStates(stateId) = nextState
442+
case _ =>
443+
liveStates(stateId) = compactedStateId
444+
compactedStateId += 1
445+
}
446+
}
447+
if (deadStates.nonEmpty)
448+
AsyncUtils.vprintln(s"${deadStates.size} dead states eliminated")
449+
def isDead(i: Integer) = deadStates.contains(i)
450+
def translatedStateId(i: Integer, tree: Tree): Integer = {
451+
def chaseDead(i: Integer): Integer = {
452+
val replacement = deadStates.getOrNull(i)
453+
if (replacement == null) i
454+
else chaseDead(replacement)
455+
}
456+
457+
val live = chaseDead(i)
458+
liveStates.get(live) match {
459+
case Some(x) => x
460+
case None => sys.error(s"$live, $liveStates \n$deadStates\n$m\n\n====\n$tree")
461+
}
462+
}
463+
}
464+
val stateMemberSymbol = symLookup.stateMachineMember(name.state)
465+
// - remove CaseDef-s for dead states
466+
// - rewrite state transitions to dead states to instead transition to the
467+
// non-dead successor.
468+
val elimDeadStateTransform = new Transformer {
469+
override def transform(tree: Tree): Tree = tree match {
470+
case as @ Assign(lhs, Literal(Constant(i: Integer))) if lhs.symbol == stateMemberSymbol =>
471+
val replacement = DeadState.translatedStateId(i, as)
472+
treeCopy.Assign(tree, lhs, Literal(Constant(replacement)))
473+
case _: Match | _: CaseDef | _: Block | _: If =>
474+
super.transform(tree)
475+
case _ => tree
476+
}
477+
}
478+
val cases1 = m.cases.flatMap {
479+
case cd @ CaseDef(Literal(Constant(i: Integer)), EmptyTree, rhs) =>
480+
if (DeadState.isDead(i)) Nil
481+
else {
482+
val replacement = DeadState.translatedStateId(i, cd)
483+
val rhs1 = elimDeadStateTransform.transform(rhs)
484+
treeCopy.CaseDef(cd, Literal(Constant(replacement)), EmptyTree, rhs1) :: Nil
485+
}
486+
case x => x :: Nil
487+
}
488+
treeCopy.Match(m, m.selector, cases1)
422489
}
423490

424491
def forever(t: Tree): Tree = {

src/main/scala/scala/async/internal/FutureSystem.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ trait FutureSystem {
7474
}
7575

7676
def mkOps(c0: Context): Ops { val c: c0.type }
77+
78+
def freshenAllNames: Boolean = false
79+
def emitTryCatch: Boolean = true
80+
def resultFieldName: String = "result"
7781
}
7882

7983
object ScalaConcurrentFutureSystem extends FutureSystem {

src/main/scala/scala/async/internal/TransformUtils.scala

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,47 @@ private[async] trait TransformUtils {
1717
import c.internal._
1818
import decorators._
1919

20+
private object baseNames {
21+
22+
val matchRes = "matchres"
23+
val ifRes = "ifres"
24+
val bindSuffix = "$bind"
25+
val completed = newTermName("completed")
26+
27+
val state = newTermName("state")
28+
val result = newTermName(self.futureSystem.resultFieldName)
29+
val execContext = newTermName("execContext")
30+
val tr = newTermName("tr")
31+
val t = newTermName("throwable")
32+
}
33+
2034
object name {
21-
val resume = newTermName("resume")
22-
val apply = newTermName("apply")
23-
val matchRes = "matchres"
24-
val ifRes = "ifres"
25-
val await = "await"
26-
val bindSuffix = "$bind"
27-
val completed = newTermName("completed")
28-
29-
val state = newTermName("state")
30-
val result = newTermName("result")
31-
val execContext = newTermName("execContext")
35+
def matchRes = maybeFresh(baseNames.matchRes)
36+
def ifRes = maybeFresh(baseNames.ifRes)
37+
def bindSuffix = maybeFresh(baseNames.bindSuffix)
38+
def completed = maybeFresh(baseNames.completed)
39+
40+
val state = maybeFresh(baseNames.state)
41+
val result = baseNames.result
42+
val execContext = maybeFresh(baseNames.execContext)
43+
val tr = maybeFresh(baseNames.tr)
44+
val t = maybeFresh(baseNames.t)
45+
46+
val await = "await"
47+
val resume = newTermName("resume")
48+
val apply = newTermName("apply")
3249
val stateMachine = newTermName(fresh("stateMachine"))
3350
val stateMachineT = stateMachine.toTypeName
34-
val tr = newTermName("tr")
35-
val t = newTermName("throwable")
3651

52+
def maybeFresh(name: TermName): TermName = if (self.asyncBase.futureSystem.freshenAllNames) fresh(name) else name
53+
def maybeFresh(name: String): String = if (self.asyncBase.futureSystem.freshenAllNames) fresh(name) else name
3754
def fresh(name: TermName): TermName = c.freshName(name)
3855

3956
def fresh(name: String): String = c.freshName(name)
4057
}
4158

59+
def maybeTry(block: Tree, catches: List[CaseDef], finalizer: Tree) = if (asyncBase.futureSystem.emitTryCatch) Try(block, catches, finalizer) else block
60+
4261
def isAsync(fun: Tree) =
4362
fun.symbol == defn.Async_async
4463

0 commit comments

Comments
 (0)