diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala index bb63d565..fa230999 100644 --- a/src/main/scala/scala/async/internal/AnfTransform.scala +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -77,7 +77,7 @@ private[async] trait AnfTransform { stats :+ expr :+ api.typecheck(atPos(expr.pos)(Throw(Apply(Select(New(gen.mkAttributedRef(defn.IllegalStateExceptionClass)), nme.CONSTRUCTOR), Nil)))) expr match { case Apply(fun, args) if isAwait(fun) => - val valDef = defineVal(name.await, expr, tree.pos) + val valDef = defineVal(name.await(), expr, tree.pos) val ref = gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe) val ref1 = if (ref.tpe =:= definitions.UnitTpe) // https://github.com/scala/async/issues/74 @@ -109,7 +109,7 @@ private[async] trait AnfTransform { } else if (expr.tpe =:= definitions.NothingTpe) { statsExprThrow } else { - val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) + val varDef = defineVar(name.ifRes(), expr.tpe, tree.pos) def typedAssign(lhs: Tree) = api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol))))) @@ -140,7 +140,7 @@ private[async] trait AnfTransform { } else if (expr.tpe =:= definitions.NothingTpe) { statsExprThrow } else { - val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) + val varDef = defineVar(name.matchRes(), expr.tpe, tree.pos) def typedAssign(lhs: Tree) = api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol))))) val casesWithAssign = cases map { @@ -163,14 +163,14 @@ private[async] trait AnfTransform { } } - def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { - val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp)) + def defineVar(name: TermName, tp: Type, pos: Position): ValDef = { + val sym = api.currentOwner.newTermSymbol(name, pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp)) valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos) } } - def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { - val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe)) + def defineVal(name: TermName, lhs: Tree, pos: Position): ValDef = { + val sym = api.currentOwner.newTermSymbol(name, pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe)) internal.valDef(sym, internal.changeOwner(lhs, api.currentOwner, sym)).setType(NoType).setPos(pos) } @@ -212,7 +212,7 @@ private[async] trait AnfTransform { case Arg(expr, _, argName) => linearize.transformToList(expr) match { case stats :+ expr1 => - val valDef = defineVal(argName, expr1, expr1.pos) + val valDef = defineVal(name.freshen(argName), expr1, expr1.pos) require(valDef.tpe != null, valDef) val stats1 = stats :+ valDef (stats1, atPos(tree.pos.makeTransparent)(gen.stabilize(gen.mkAttributedIdent(valDef.symbol)))) @@ -279,8 +279,9 @@ private[async] trait AnfTransform { // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. val block = linearize.transformToBlock(body) val (valDefs, mappings) = (pat collect { - case b@Bind(name, _) => - val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos) + case b@Bind(bindName, _) => + val vd = defineVal(name.freshen(bindName.toTermName), gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos) + vd.symbol.updateAttachment(SyntheticBindVal) (vd, (b.symbol, vd.symbol)) }).unzip val (from, to) = mappings.unzip @@ -333,7 +334,7 @@ private[async] trait AnfTransform { // Otherwise, create the matchres var. We'll callers of the label def below. // Remember: we're iterating through the statement sequence in reverse, so we'll get // to the LabelDef and mutate `matchResults` before we'll get to its callers. - val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos) + val matchResult = linearize.defineVar(name.matchRes(), param.tpe, ld.pos) matchResults += matchResult caseDefToMatchResult(ld.symbol) = matchResult.symbol val rhs2 = ld.rhs.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil) @@ -408,3 +409,5 @@ private[async] trait AnfTransform { }).asInstanceOf[Block] } } + +object SyntheticBindVal diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala index 113e7a8f..2b9b68a9 100644 --- a/src/main/scala/scala/async/internal/AsyncMacro.scala +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -3,8 +3,18 @@ package scala.async.internal object AsyncMacro { def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = { import language.reflectiveCalls + + // Use an attachment on RootClass as a sneaky place for a per-Global cache + val att = c0.internal.attachments(c0.universe.rootMirror.RootClass) + val names = att.get[AsyncNames[_]].getOrElse { + val names = new AsyncNames[c0.universe.type](c0.universe) + att.update(names) + names + } + new AsyncMacro { self => val c: c0.type = c0 + val asyncNames: AsyncNames[c.universe.type] = names.asInstanceOf[AsyncNames[c.universe.type]] val body: c.Tree = body0 // This member is required by `AsyncTransform`: val asyncBase: AsyncBase = base @@ -23,6 +33,7 @@ private[async] trait AsyncMacro val c: scala.reflect.macros.Context val body: c.Tree var containsAwait: c.Tree => Boolean + val asyncNames: AsyncNames[c.universe.type] lazy val macroPos: c.universe.Position = c.macroApplication.pos.makeTransparent def atMacroPos(t: c.Tree): c.Tree = c.universe.atPos(macroPos)(t) diff --git a/src/main/scala/scala/async/internal/AsyncNames.scala b/src/main/scala/scala/async/internal/AsyncNames.scala new file mode 100644 index 00000000..cf551584 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncNames.scala @@ -0,0 +1,109 @@ +package scala.async.internal + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.api.Names + +/** + * A per-global cache of names needed by the Async macro. + */ +final class AsyncNames[U <: Names with Singleton](val u: U) { + self => + import u._ + + abstract class NameCache[N <: U#Name](base: String) { + val cached = new ArrayBuffer[N]() + protected def newName(s: String): N + def apply(i: Int): N = { + if (cached.isDefinedAt(i)) cached(i) + else { + assert(cached.length == i) + val name = newName(freshenString(base, i)) + cached += name + name + } + } + } + + final class TermNameCache(base: String) extends NameCache[U#TermName](base) { + override protected def newName(s: String): U#TermName = newTermName(s) + } + final class TypeNameCache(base: String) extends NameCache[U#TypeName](base) { + override protected def newName(s: String): U#TypeName = newTypeName(s) + } + private val matchRes: TermNameCache = new TermNameCache("match") + private val ifRes: TermNameCache = new TermNameCache("if") + private val await: TermNameCache = new TermNameCache("await") + + private val result = newTermName("result$async") + private val completed: TermName = newTermName("completed$async") + private val apply = newTermName("apply") + private val stateMachine = newTermName("stateMachine$async") + private val stateMachineT = stateMachine.toTypeName + private val state: u.TermName = newTermName("state$async") + private val execContext = newTermName("execContext$async") + private val tr: u.TermName = newTermName("tr$async") + private val t: u.TermName = newTermName("throwable$async") + + final class NameSource[N <: U#Name](cache: NameCache[N]) { + private val count = new AtomicInteger(0) + def apply(): N = cache(count.getAndIncrement()) + } + + class AsyncName { + final val matchRes = new NameSource[U#TermName](self.matchRes) + final val ifRes = new NameSource[U#TermName](self.matchRes) + final val await = new NameSource[U#TermName](self.await) + final val completed = self.completed + final val result = self.result + final val apply = self.apply + final val stateMachine = self.stateMachine + final val stateMachineT = self.stateMachineT + final val state: u.TermName = self.state + final val execContext = self.execContext + final val tr: u.TermName = self.tr + final val t: u.TermName = self.t + + private val seenPrefixes = mutable.AnyRefMap[Name, AtomicInteger]() + private val freshened = mutable.HashSet[Name]() + + final def freshenIfNeeded(name: TermName): TermName = { + seenPrefixes.getOrNull(name) match { + case null => + seenPrefixes.put(name, new AtomicInteger()) + name + case counter => + freshen(name, counter) + } + } + final def freshenIfNeeded(name: TypeName): TypeName = { + seenPrefixes.getOrNull(name) match { + case null => + seenPrefixes.put(name, new AtomicInteger()) + name + case counter => + freshen(name, counter) + } + } + final def freshen(name: TermName): TermName = { + val counter = seenPrefixes.getOrElseUpdate(name, new AtomicInteger()) + freshen(name, counter) + } + final def freshen(name: TypeName): TypeName = { + val counter = seenPrefixes.getOrElseUpdate(name, new AtomicInteger()) + freshen(name, counter) + } + private def freshen(name: TermName, counter: AtomicInteger): TermName = { + if (freshened.contains(name)) name + else TermName(freshenString(name.toString, counter.incrementAndGet())) + } + private def freshen(name: TypeName, counter: AtomicInteger): TypeName = { + if (freshened.contains(name)) name + else TypeName(freshenString(name.toString, counter.incrementAndGet())) + } + } + + private def freshenString(name: String, counter: Int): String = name.toString + "$async$" + counter +} diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 7ef63f70..ba0b522b 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -70,9 +70,6 @@ trait AsyncTransform { buildAsyncBlock(anfTree, symLookup) } - if(AsyncUtils.verbose) - logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) - val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates) // live variables analysis @@ -114,10 +111,15 @@ trait AsyncTransform { futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` else startStateMachine + + if(AsyncUtils.verbose) { + logDiagnostics(anfTree, asyncBlock, asyncBlock.asyncStates.map(_.toString)) + } + futureSystemOps.dot(enclosingOwner, body).foreach(f => f(asyncBlock.toDot)) cleanupContainsAwaitAttachments(result) } - def logDiagnostics(anfTree: Tree, states: Seq[String]): Unit = { + def logDiagnostics(anfTree: Tree, block: AsyncBlock, states: Seq[String]): Unit = { def location = try { macroPos.source.path } catch { @@ -129,6 +131,8 @@ trait AsyncTransform { AsyncUtils.vprintln(s"${c.macroApplication}") AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") states foreach (s => AsyncUtils.vprintln(s)) + AsyncUtils.vprintln("===== DOT =====") + AsyncUtils.vprintln(block.toDot) } /** diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index e1ab6c86..5fbf63c8 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -3,6 +3,8 @@ */ package scala.async.internal +import java.util.function.IntUnaryOperator + import scala.collection.mutable import scala.collection.mutable.ListBuffer import language.existentials @@ -23,7 +25,7 @@ trait ExprBuilder { trait AsyncState { def state: Int - def nextStates: List[Int] + def nextStates: Array[Int] def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef @@ -55,8 +57,8 @@ trait ExprBuilder { final class SimpleAsyncState(var stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) extends AsyncState { - def nextStates: List[Int] = - List(nextState) + val nextStates: Array[Int] = + Array(nextState) def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { mkHandlerCase(state, treesThenStats(mkStateTree(nextState, symLookup) :: Nil)) @@ -69,23 +71,23 @@ trait ExprBuilder { /** A sequence of statements with a conditional transition to the next state, which will represent * a branch of an `if` or a `match`. */ - final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState { + final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: Array[Int]) extends AsyncState { override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = mkHandlerCase(state, stats) override val toString: String = - s"AsyncStateWithoutAwait #$state, nextStates = $nextStates" + s"AsyncStateWithoutAwait #$state, nextStates = ${nextStates.toList}" } /** A sequence of statements that concludes with an `await` call. The `onComplete` * handler will unconditionally transition to `nextState`. */ - final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, onCompleteState: Int, nextState: Int, + final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, val onCompleteState: Int, nextState: Int, val awaitable: Awaitable, symLookup: SymLookup) extends AsyncState { - def nextStates: List[Int] = - List(nextState) + val nextStates: Array[Int] = + Array(nextState) override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { val fun = This(tpnme.EMPTY) @@ -93,7 +95,7 @@ trait ExprBuilder { c.Expr[futureSystem.Tryy[Any] => Unit](fun), c.Expr[futureSystem.ExecContext](Ident(name.execContext))).tree val tryGetOrCallOnComplete: List[Tree] = if (futureSystemOps.continueCompletedFutureOnSameThread) { - val tempName = name.fresh(name.completed) + val tempName = name.completed val initTemp = ValDef(NoMods, tempName, TypeTree(futureSystemOps.tryType[Any]), futureSystemOps.getCompleted[Any](c.Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) val ifTree = If(Apply(Select(Literal(Constant(null)), TermName("ne")), Ident(tempName) :: Nil), adaptToUnit(ifIsFailureTree[T](Ident(tempName)) :: Nil), @@ -191,7 +193,7 @@ trait ExprBuilder { def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { def mkBranch(state: Int) = mkStateTree(state, symLookup) this += If(condTree, mkBranch(thenState), mkBranch(elseState)) - new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState)) + new AsyncStateWithoutAwait(stats.toList, state, Array(thenState, elseState)) } /** @@ -204,7 +206,7 @@ trait ExprBuilder { * @param caseStates starting state of the right-hand side of the each case * @return an `AsyncState` representing the match expression */ - def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { + def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: Array[Int], symLookup: SymLookup): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { case CaseDef(pat, guard, rhs) => @@ -218,7 +220,7 @@ trait ExprBuilder { def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { this += mkStateTree(startLabelState, symLookup) - new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState)) + new AsyncStateWithoutAwait(stats.toList, state, Array(startLabelState)) } override def toString: String = { @@ -266,11 +268,11 @@ trait ExprBuilder { } // populate asyncStates - def add(stat: Tree): Unit = stat match { + def add(stat: Tree, afterState: Option[Int] = None): Unit = stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val onCompleteState = nextState() - val afterAwaitState = nextState() + val afterAwaitState = afterState.getOrElse(nextState()) val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await currState = afterAwaitState @@ -281,7 +283,7 @@ trait ExprBuilder { val thenStartState = nextState() val elseStartState = nextState() - val afterIfState = nextState() + val afterIfState = afterState.getOrElse(nextState()) asyncStates += // the two Int arguments are the start state of the then branch and the else branch, respectively @@ -299,8 +301,11 @@ trait ExprBuilder { case Match(scrutinee, cases) if containsAwait(stat) => checkForUnsupportedAwait(scrutinee) - val caseStates = cases.map(_ => nextState()) - val afterMatchState = nextState() + val caseStates = new Array[Int](cases.length) + java.util.Arrays.setAll(caseStates, new IntUnaryOperator { + override def applyAsInt(operand: Int): Int = nextState() + }) + val afterMatchState = afterState.getOrElse(nextState()) asyncStates += stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) @@ -318,7 +323,7 @@ trait ExprBuilder { if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) => val startLabelState = stateIdForLabel(ld.symbol) - val afterLabelState = nextState() + val afterLabelState = afterState.getOrElse(nextState()) asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) labelDefStates(ld.symbol) = startLabelState val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) @@ -326,7 +331,8 @@ trait ExprBuilder { currState = afterLabelState stateBuilder = new AsyncStateBuilder(currState, symLookup) case b @ Block(stats, expr) => - (stats :+ expr) foreach (add) + for (stat <- stats) add(stat) + add(expr, afterState = Some(endState)) case _ => checkForUnsupportedAwait(stat) stateBuilder += stat @@ -340,6 +346,8 @@ trait ExprBuilder { def asyncStates: List[AsyncState] def onCompleteHandler[T: WeakTypeTag]: Tree + + def toDot: String } case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { @@ -364,7 +372,106 @@ trait ExprBuilder { val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) new AsyncBlock { - def asyncStates = blockBuilder.asyncStates.toList + val switchIds = mutable.AnyRefMap[Integer, Integer]() + + // render with http://graphviz.it/#/new + def toDot: String = { + val states = asyncStates + def toHtmlLabel(label: String, preText: String, builder: StringBuilder): Unit = { + val br = "
" + builder.append("").append(label).append("").append("
") + builder.append("") + preText.split("\n").foreach { + (line: String) => + builder.append(br) + builder.append(line.replaceAllLiterally("\"", """).replaceAllLiterally("<", "<").replaceAllLiterally(">", ">").replaceAllLiterally(" ", " ")) + } + builder.append(br) + builder.append("") + } + val dotBuilder = new StringBuilder() + dotBuilder.append("digraph {\n") + def stateLabel(s: Int) = { + if (s == 0) "INITIAL" else if (s == Int.MaxValue) "TERMINAL" else switchIds.getOrElse(s, s).toString + } + val length = states.size + for ((state, i) <- asyncStates.zipWithIndex) { + dotBuilder.append(s"""${stateLabel(state.state)} [label=""").append("<") + def show(t: Tree): String = { + (t match { + case Block(stats, expr) => stats ::: expr :: Nil + case t => t :: Nil + }).iterator.map(t => showCode(t)).mkString("\n") + } + if (i != length - 1) { + val CaseDef(_, _, body) = state.mkHandlerCaseForState + toHtmlLabel(stateLabel(state.state), show(compactStateTransform.transform(body)), dotBuilder) + } else { + toHtmlLabel(stateLabel(state.state), state.allStats.map(show(_)).mkString("\n"), dotBuilder) + } + dotBuilder.append("> ]\n") + state match { + case s: AsyncStateWithAwait => + val CaseDef(_, _, body) = s.mkOnCompleteHandler.get + dotBuilder.append(s"""${stateLabel(s.onCompleteState)} [label=""").append("<") + toHtmlLabel(stateLabel(s.onCompleteState), show(compactStateTransform.transform(body)), dotBuilder) + dotBuilder.append("> ]\n") + case _ => + } + } + for (state <- states) { + state match { + case s: AsyncStateWithAwait => + dotBuilder.append(s"""${stateLabel(state.state)} -> ${stateLabel(s.onCompleteState)} [style=dashed color=red]""") + dotBuilder.append("\n") + for (succ <- state.nextStates) { + dotBuilder.append(s"""${stateLabel(s.onCompleteState)} -> ${stateLabel(succ)}""") + dotBuilder.append("\n") + } + case _ => + for (succ <- state.nextStates) { + dotBuilder.append(s"""${stateLabel(state.state)} -> ${stateLabel(succ)}""") + dotBuilder.append("\n") + } + } + } + dotBuilder.append("}\n") + dotBuilder.toString + } + + lazy val asyncStates: List[AsyncState] = filterStates + + def filterStates = { + val all = blockBuilder.asyncStates.toList + val (initial :: rest) = all + val map = all.iterator.map(x => (x.state, x)).toMap + var seen = mutable.HashSet[Int]() + def loop(state: AsyncState): Unit = { + seen.add(state.state) + for (i <- state.nextStates) { + if (i != Int.MaxValue && !seen.contains(i)) { + loop(map(i)) + } + } + } + loop(initial) + val live = rest.filter(state => seen(state.state)) + var nextSwitchId = 0 + (initial :: live).foreach { state => + val switchId = nextSwitchId + switchIds(state.state) = switchId + nextSwitchId += 1 + state match { + case state: AsyncStateWithAwait => + val switchId = nextSwitchId + switchIds(state.onCompleteState) = switchId + nextSwitchId += 1 + case _ => + } + } + initial :: live + + } def mkCombinedHandlerCases[T: WeakTypeTag]: List[CaseDef] = { val caseForLastState: CaseDef = { @@ -408,7 +515,7 @@ trait ExprBuilder { val stateMemberSymbol = symLookup.stateMachineMember(name.state) val stateMemberRef = symLookup.memberRef(name.state) val body = Match(stateMemberRef, mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ++ List(CaseDef(Ident(nme.WILDCARD), EmptyTree, Throw(Apply(Select(New(Ident(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), List()))))) - val body1 = eliminateDeadStates(body) + val body1 = compactStates(body) maybeTry( body1, @@ -427,62 +534,24 @@ trait ExprBuilder { })), EmptyTree) } - // Identify dead states: `case => { state = nextId; (); (); ... }, eliminated, and compact state ids to - // enable emission of a tableswitch. - private def eliminateDeadStates(m: Match): Tree = { - object DeadState { - private val liveStates = mutable.AnyRefMap[Integer, Integer]() - private val deadStates = mutable.AnyRefMap[Integer, Integer]() - private var compactedStateId = 1 - for (CaseDef(Literal(Constant(stateId: Integer)), EmptyTree, body) <- m.cases) { - body match { - case _ if (stateId == 0) => liveStates(stateId) = stateId - case Block(Assign(_, Literal(Constant(nextState: Integer))) :: rest, expr) if (expr :: rest).forall(t => isLiteralUnit(t)) => - deadStates(stateId) = nextState - case _ => - liveStates(stateId) = compactedStateId - compactedStateId += 1 - } - } - if (deadStates.nonEmpty) - AsyncUtils.vprintln(s"${deadStates.size} dead states eliminated") - def isDead(i: Integer) = deadStates.contains(i) - def translatedStateId(i: Integer, tree: Tree): Integer = { - def chaseDead(i: Integer): Integer = { - val replacement = deadStates.getOrNull(i) - if (replacement == null) i - else chaseDead(replacement) - } - - val live = chaseDead(i) - liveStates.get(live) match { - case Some(x) => x - case None => sys.error(s"$live, $liveStates \n$deadStates\n$m\n\n====\n$tree") - } - } - } - val stateMemberSymbol = symLookup.stateMachineMember(name.state) - // - remove CaseDef-s for dead states - // - rewrite state transitions to dead states to instead transition to the - // non-dead successor. - val elimDeadStateTransform = new Transformer { - override def transform(tree: Tree): Tree = tree match { - case as @ Assign(lhs, Literal(Constant(i: Integer))) if lhs.symbol == stateMemberSymbol => - val replacement = DeadState.translatedStateId(i, as) - treeCopy.Assign(tree, lhs, Literal(Constant(replacement))) - case _: Match | _: CaseDef | _: Block | _: If => - super.transform(tree) - case _ => tree - } + private lazy val stateMemberSymbol = symLookup.stateMachineMember(name.state) + private val compactStateTransform = new Transformer { + override def transform(tree: Tree): Tree = tree match { + case as @ Assign(lhs, Literal(Constant(i: Integer))) if lhs.symbol == stateMemberSymbol => + val replacement = switchIds(i) + treeCopy.Assign(tree, lhs, Literal(Constant(replacement))) + case _: Match | _: CaseDef | _: Block | _: If => + super.transform(tree) + case _ => tree } + } + + private def compactStates(m: Match): Tree = { val cases1 = m.cases.flatMap { case cd @ CaseDef(Literal(Constant(i: Integer)), EmptyTree, rhs) => - if (DeadState.isDead(i)) Nil - else { - val replacement = DeadState.translatedStateId(i, cd) - val rhs1 = elimDeadStateTransform.transform(rhs) - treeCopy.CaseDef(cd, Literal(Constant(replacement)), EmptyTree, rhs1) :: Nil - } + val replacement = switchIds(i) + val rhs1 = compactStateTransform.transform(rhs) + treeCopy.CaseDef(cd, Literal(Constant(replacement)), EmptyTree, rhs1) :: Nil case x => x :: Nil } treeCopy.Match(m, m.selector, cases1) @@ -515,7 +584,7 @@ trait ExprBuilder { } private def isSyntheticBindVal(tree: Tree) = tree match { - case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) + case vd@ValDef(_, lname, _, Ident(rname)) => attachments(vd.symbol).contains[SyntheticBindVal.type] case _ => false } diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala index 3ca9c834..9a9d7ef9 100644 --- a/src/main/scala/scala/async/internal/FutureSystem.scala +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -71,12 +71,17 @@ trait FutureSystem { /** A hook for custom macros to transform the tree post-ANF transform */ def postAnfTransform(tree: Block): Block = tree + + /** A hook for custom macros to selectively generate and process a Graphviz visualization of the transformed state machine */ + def dot(enclosingOwner: Symbol, macroApplication: Tree): Option[(String => Unit)] = None } def mkOps(c0: Context): Ops { val c: c0.type } + @deprecated("No longer honoured by the macro, all generated names now contain $async to avoid accidental clashes with lambda lifted names", "0.9.7") def freshenAllNames: Boolean = false def emitTryCatch: Boolean = true + @deprecated("No longer honoured by the macro, all generated names now contain $async to avoid accidental clashes with lambda lifted names", "0.9.7") def resultFieldName: String = "result" } diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala index ff905768..db015d13 100644 --- a/src/main/scala/scala/async/internal/Lifter.scala +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -120,13 +120,13 @@ trait Lifter { val rhs1 = if (sym.asTerm.isLazy) rhs else EmptyTree treeCopy.ValDef(vd, Modifiers(sym.flags), sym.name, TypeTree(tpe(sym)).setPos(t.pos), rhs1) case dd@DefDef(_, _, tparams, vparamss, tpt, rhs) => - sym.setName(this.name.fresh(sym.name.toTermName)) + sym.setName(this.name.freshen(sym.name.toTermName)) sym.setFlag(PRIVATE | LOCAL) // Was `DefDef(sym, rhs)`, but this ran afoul of `ToughTypeSpec.nestedMethodWithInconsistencyTreeAndInfoParamSymbols` // due to the handling of type parameter skolems in `thisMethodType` in `Namers` treeCopy.DefDef(dd, Modifiers(sym.flags), sym.name, tparams, vparamss, tpt, rhs) case cd@ClassDef(_, _, tparams, impl) => - sym.setName(newTypeName(name.fresh(sym.name.toString).toString)) + sym.setName(name.freshen(sym.name.toTypeName)) companionship.companionOf(cd.symbol) match { case NoSymbol => case moduleSymbol => @@ -137,13 +137,13 @@ trait Lifter { case md@ModuleDef(_, _, impl) => companionship.companionOf(md.symbol) match { case NoSymbol => - sym.setName(name.fresh(sym.name.toTermName)) + sym.setName(name.freshen(sym.name.toTermName)) sym.asModule.moduleClass.setName(sym.name.toTypeName) case classSymbol => // will be renamed by `case ClassDef` above. } treeCopy.ModuleDef(md, Modifiers(sym.flags), sym.name, impl) case td@TypeDef(_, _, tparams, rhs) => - sym.setName(newTypeName(name.fresh(sym.name.toString).toString)) + sym.setName(name.freshen(sym.name.toTypeName)) treeCopy.TypeDef(td, Modifiers(sym.flags), sym.name, tparams, rhs) } atPos(t.pos)(treeLifted) diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala index 692d0bf6..8df998c2 100644 --- a/src/main/scala/scala/async/internal/LiveVariables.scala +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -1,5 +1,10 @@ package scala.async.internal +import java.util +import java.util.function.{IntConsumer, IntPredicate} + +import scala.collection.immutable.IntMap + trait LiveVariables { self: AsyncMacro => import c.universe._ @@ -17,19 +22,22 @@ trait LiveVariables { def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = { // live variables analysis: // the result map indicates in which states a given field should be nulled out - val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables) + val liveVarsMap: Map[Tree, StateSet] = liveVars(asyncStates, liftables) var assignsOf = Map[Int, List[Tree]]() - for ((fld, where) <- liveVarsMap; state <- where) - assignsOf get state match { - case None => - assignsOf += (state -> List(fld)) - case Some(trees) if !trees.exists(_.symbol == fld.symbol) => - assignsOf += (state -> (fld +: trees)) - case _ => - /* do nothing */ - } + for ((fld, where) <- liveVarsMap) { + where.foreach { new IntConsumer { def accept(state: Int): Unit = { + assignsOf get state match { + case None => + assignsOf += (state -> List(fld)) + case Some(trees) if !trees.exists(_.symbol == fld.symbol) => + assignsOf += (state -> (fld +: trees)) + case _ => + // do nothing + } + }}} + } assignsOf } @@ -46,9 +54,9 @@ trait LiveVariables { * @param liftables the lifted fields * @return a map which indicates for a given field (the key) the states in which it should be nulled out */ - def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = { + def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, StateSet] = { val liftedSyms: Set[Symbol] = // include only vars - liftables.filter { + liftables.iterator.filter { case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE) case _ => false }.map(_.symbol).toSet @@ -122,20 +130,30 @@ trait LiveVariables { * A state `i` is contained in the list that is the value to which * key `j` maps iff control can flow from state `j` to state `i`. */ - val cfg: Map[Int, List[Int]] = asyncStates.map(as => as.state -> as.nextStates).toMap + val cfg: Map[Int, Array[Int]] = { + var res = IntMap.empty[Array[Int]] + + for (as <- asyncStates) res = res.updated(as.state, as.nextStates) + res + } /** Tests if `state1` is a predecessor of `state2`. */ def isPred(state1: Int, state2: Int): Boolean = { - val seen = scala.collection.mutable.HashSet[Int]() + val seen = new StateSet() def isPred0(state1: Int, state2: Int): Boolean = if(state1 == state2) false - else if (seen(state1)) false // breaks cycles in the CFG + else if (seen.contains(state1)) false // breaks cycles in the CFG else cfg get state1 match { case Some(nextStates) => seen += state1 - nextStates.contains(state2) || nextStates.exists(isPred0(_, state2)) + var i = 0 + while (i < nextStates.length) { + if (nextStates(i) == state2 || isPred0(nextStates(i), state2)) return true + i += 1 + } + false case None => false } @@ -164,8 +182,8 @@ trait LiveVariables { * 7. repeat if something has changed */ - var LVentry = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]() - var LVexit = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]() + var LVentry = IntMap[Set[Symbol]]() withDefaultValue Set[Symbol]() + var LVexit = IntMap[Set[Symbol]]() withDefaultValue Set[Symbol]() // All fields are declared to be dead at the exit of the final async state, except for the ones // that cannot be nulled out at all (those in noNull), because they have been captured by a nested def. @@ -174,6 +192,14 @@ trait LiveVariables { var currStates = List(finalState) // start at final state var captured: Set[Symbol] = Set() + def contains(as: Array[Int], a: Int): Boolean = { + var i = 0 + while (i < as.length) { + if (as(i) == a) return true + i += 1 + } + false + } while (!currStates.isEmpty) { var entryChanged: List[AsyncState] = Nil @@ -183,19 +209,19 @@ trait LiveVariables { captured ++= referenced.captured val LVentryNew = LVexit(cs.state) ++ referenced.used if (!LVentryNew.sameElements(LVentryOld)) { - LVentry = LVentry + (cs.state -> LVentryNew) + LVentry = LVentry.updated(cs.state, LVentryNew) entryChanged ::= cs } } - val pred = entryChanged.flatMap(cs => asyncStates.filter(_.nextStates.contains(cs.state))) + val pred = entryChanged.flatMap(cs => asyncStates.filter(state => contains(state.nextStates, cs.state))) var exitChanged: List[AsyncState] = Nil for (p <- pred) { val LVexitOld = LVexit(p.state) val LVexitNew = p.nextStates.flatMap(succ => LVentry(succ)).toSet if (!LVexitNew.sameElements(LVexitOld)) { - LVexit = LVexit + (p.state -> LVexitNew) + LVexit = LVexit.updated(p.state, LVexitNew) exitChanged ::= p } } @@ -210,53 +236,64 @@ trait LiveVariables { } } - def lastUsagesOf(field: Tree, at: AsyncState): Set[Int] = { + def lastUsagesOf(field: Tree, at: AsyncState): StateSet = { val avoid = scala.collection.mutable.HashSet[AsyncState]() - def lastUsagesOf0(field: Tree, at: AsyncState): Set[Int] = { - if (avoid(at)) Set() + val result = new StateSet + def lastUsagesOf0(field: Tree, at: AsyncState): Unit = { + if (avoid(at)) () else if (captured(field.symbol)) { - Set() + () } else LVentry get at.state match { case Some(fields) if fields.contains(field.symbol) => - Set(at.state) + result += at.state case _ => avoid += at - val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet - preds.flatMap(p => lastUsagesOf0(field, p)) + for (state <- asyncStates) { + if (contains(state.nextStates, at.state)) { + lastUsagesOf0(field, state) + } + } } } lastUsagesOf0(field, at) + result } - val lastUsages: Map[Tree, Set[Int]] = - liftables.map(fld => fld -> lastUsagesOf(fld, finalState)).toMap + val lastUsages: Map[Tree, StateSet] = + liftables.iterator.map(fld => fld -> lastUsagesOf(fld, finalState)).toMap if(AsyncUtils.verbose) { for ((fld, lastStates) <- lastUsages) - AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.mkString(", ")}") + AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.iterator.mkString(", ")}") } - val nullOutAt: Map[Tree, Set[Int]] = + val nullOutAt: Map[Tree, StateSet] = for ((fld, lastStates) <- lastUsages) yield { - val killAt = lastStates.flatMap { s => - if (s == finalState.state) Set() - else { + var result = new StateSet + lastStates.foreach(new IntConsumer { def accept(s: Int): Unit = { + if (s != finalState.state) { val lastAsyncState = asyncStates.find(_.state == s).get val succNums = lastAsyncState.nextStates // all successor states that are not indirect predecessors // filter out successor states where the field is live at the entry - succNums.filter(num => !isPred(num, s)).filterNot(num => LVentry(num).contains(fld.symbol)) + var i = 0 + while (i < succNums.length) { + val num = succNums(i) + if (!isPred(num, s) && !LVentry(num).contains(fld.symbol)) + result += num + i += 1 + } } - } - (fld, killAt) + }}) + (fld, result) } if(AsyncUtils.verbose) { for ((fld, killAt) <- nullOutAt) - AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.mkString(", ")}") + AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.iterator.mkString(", ")}") } nullOutAt diff --git a/src/main/scala/scala/async/internal/StateSet.scala b/src/main/scala/scala/async/internal/StateSet.scala new file mode 100644 index 00000000..2dc61e7c --- /dev/null +++ b/src/main/scala/scala/async/internal/StateSet.scala @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2018 Lightbend Inc. + */ +package scala.async.internal + +import java.util +import java.util.function.{Consumer, IntConsumer} + +import scala.collection.JavaConverters.{asScalaIteratorConverter, iterableAsScalaIterableConverter} +import scala.collection.mutable + +// Set for StateIds, which are either small positive integers or -symbolID. +final class StateSet { + private var bitSet = new java.util.BitSet() + private var caseSet = new util.HashSet[Integer]() + def +=(stateId: Int): Unit = if (stateId > 0) bitSet.set(stateId) else caseSet.add(stateId) + def contains(stateId: Int): Boolean = if (stateId > 0 && stateId < 1024) bitSet.get(stateId) else caseSet.contains(stateId) + def iterator: Iterator[Integer] = { + bitSet.stream().iterator().asScala ++ caseSet.asScala.iterator + } + def foreach(f: IntConsumer): Unit = { + bitSet.stream().forEach(f) + caseSet.stream().forEach(new Consumer[Integer] { + override def accept(value: Integer): Unit = f.accept(value) + }) + } +} diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 855cbd28..49148894 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -17,42 +17,8 @@ private[async] trait TransformUtils { import c.internal._ import decorators._ - private object baseNames { - - val matchRes = "matchres" - val ifRes = "ifres" - val bindSuffix = "$bind" - val completed = newTermName("completed") - - val state = newTermName("state") - val result = newTermName(self.futureSystem.resultFieldName) - val execContext = newTermName("execContext") - val tr = newTermName("tr") - val t = newTermName("throwable") - } - - object name { - def matchRes = maybeFresh(baseNames.matchRes) - def ifRes = maybeFresh(baseNames.ifRes) - def bindSuffix = maybeFresh(baseNames.bindSuffix) - def completed = maybeFresh(baseNames.completed) - - val state = maybeFresh(baseNames.state) - val result = baseNames.result - val execContext = maybeFresh(baseNames.execContext) - val tr = maybeFresh(baseNames.tr) - val t = maybeFresh(baseNames.t) - - val await = "await" - val resume = newTermName("resume") - val apply = newTermName("apply") - val stateMachine = newTermName(fresh("stateMachine")) - val stateMachineT = stateMachine.toTypeName - - def maybeFresh(name: TermName): TermName = if (self.asyncBase.futureSystem.freshenAllNames) fresh(name) else name - def maybeFresh(name: String): String = if (self.asyncBase.futureSystem.freshenAllNames) fresh(name) else name - def fresh(name: TermName): TermName = c.freshName(name) - + object name extends asyncNames.AsyncName { + def fresh(name: TermName): TermName = freshenIfNeeded(name) def fresh(name: String): String = c.freshName(name) } @@ -162,10 +128,10 @@ private[async] trait TransformUtils { (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) } } - private def argName(fun: Tree): ((Int, Int) => String) = { + private def argName(fun: Tree): ((Int, Int) => TermName) = { val paramss = fun.tpe.paramss - val namess = paramss.map(_.map(_.name.toString)) - (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") + val namess = paramss.map(_.map(_.name.toTermName)) + (i, j) => util.Try(namess(i)(j)).getOrElse(TermName(s"arg_${i}_${j}")) } object defn { @@ -246,7 +212,7 @@ private[async] trait TransformUtils { } } - case class Arg(expr: Tree, isByName: Boolean, argName: String) + case class Arg(expr: Tree, isByName: Boolean, argName: TermName) /** * Transform a list of argument lists, producing the transformed lists, and lists of auxillary @@ -261,7 +227,7 @@ private[async] trait TransformUtils { */ def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { val isByNamess: (Int, Int) => Boolean = isByName(fun) - val argNamess: (Int, Int) => String = argName(fun) + val argNamess: (Int, Int) => TermName = argName(fun) argss.zipWithIndex.map { case (args, i) => mapArguments[A](args) { (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 3b685c82..cc4febc2 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -38,7 +38,7 @@ class TreeInterrogation { val varDefs = tree1.collect { case vd @ ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) && vd.symbol.owner.isClass => name } - varDefs.map(_.decoded.trim).toSet.toList.sorted mustStartWith (List("await$macro$", "await$macro$", "state")) + varDefs.map(_.decoded.trim).toSet.toList.sorted mustStartWith (List("await$async$", "await$async", "state$async")) val defDefs = tree1.collect { case t: Template => @@ -49,11 +49,11 @@ class TreeInterrogation { && !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name } }.flatten - defDefs.map(_.decoded.trim) mustStartWith List("foo$macro$", "", "apply", "apply") + defDefs.map(_.decoded.trim) mustStartWith List("foo$async$", "", "apply", "apply") } } -object TreeInterrogation extends App { +object TreeInterrogationApp extends App { def withDebug[T](t: => T): T = { def set(level: String, value: Boolean) = System.setProperty(s"scala.async.$level", value.toString) val levels = Seq("trace", "debug") @@ -65,7 +65,7 @@ object TreeInterrogation extends App { withDebug { val cm = reflect.runtime.currentMirror - val tb = mkToolbox(s"-cp ${toolboxClasspath} -Xprint:typer -uniqid") + val tb = mkToolbox(s"-cp ${toolboxClasspath} -Xprint:typer") import scala.async.internal.AsyncId._ val tree = tb.parse( """ @@ -75,6 +75,9 @@ object TreeInterrogation extends App { | while(await(b)) { | b = false | } + | (1, 1) match { + | case (x, y) => await(2); println(x) + | } | await(b) | } | diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 16321cdb..2b54e169 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -403,7 +403,8 @@ class AnfTransformSpec { """.stripMargin }) val applyImplicitView = tree.collect { case x if x.getClass.getName.endsWith("ApplyImplicitView") => x } - applyImplicitView.map(_.toString) mustStartWith List("view(a$macro$") + println(applyImplicitView) + applyImplicitView.map(_.toString) mustStartWith List("view(") } @Test