Skip to content

Commit 5ee5ee0

Browse files
committed
Use Labeled blocks in TailRec, instead of label-defs.
It's easier to first explain on an example. Consider the following tail-recursive method: def fact(n: Int, acc: Int): Int = if (n == 0) acc else fact(n - 1, n * acc) It is now translated as follows by the `tailrec` transform: def fact(n: Int, acc: Int): Int = { var n$tailLocal1: Int = n var acc$tailLocal1: Int = acc while (true) { tailLabel1[Unit]: { return { if (n$tailLocal1 == 0) { acc } else { val n$tailLocal1$tmp1: Int = n$tailLocal1 - 1 val acc$tailLocal1$tmp1: Int = n$tailLocal1 * acc$tailLocal1 n$tailLocal1 = n$tailLocal1$tmp1 acc$tailLocal1 = acc$tailLocal1$tmp1 (return[tailLabel1] ()): Int } } } } throw null // unreachable code } First, we allocate local `var`s for every parameter, as well as `this` if necessary. When we find a tail-recursive call, we evaluate the arguments into temporaries, then assign them to the `var`s. It is necessary to use temporaries in order not to use the new contents of a param local when computing the new value of another param local. We avoid reassigning param locals if their rhs (i.e., the actual argument to the recursive call) is itself, which does happen quite often in practice. In particular, we thus avoid reassigning the local var for `this` if the prefix is empty. We could further optimize this by avoiding the reassignment if the prefix is non-empty but equivalent to `this`. If only one parameter ends up changing value in any particular tail-recursive call, we can avoid the temporaries and directly assign it. This is also a fairly common situation, especially after discarding useless assignments to the local for `this`. After all that, we `return` from a labeled block, which is right inside an infinite `while` loop. The net result is to loop back to the beginning, implementing the jump. The `return` node is explicitly ascribed with the previous result type, so that lubs upstream are not affected (not doing so can cause Ycheck errors). For control flows that do *not* end up in a tail-recursive call, the result value is given to an explicit `return` out of the enclosing method, which prevents the looping. There is one pretty ugly artifact: after the `while` loop, we must insert a `throw null` for the body to still typecheck as an `Int` (the result type of the `def`). This could be avoided if we dared type a `WhileDo(Literal(Constant(true)), body)` as having type `Nothing` rather than `Unit`. This is probably dangerous, though, as we have no guarantee that further transformations will leave the `true` alone, especially in the presence of compiler plugins. If the `true` gets wrapped in any way, the type of the `WhileDo` will be altered, and chaos will ensue. In the future, we could enhance the codegen to avoid emitting that dead code. This should not be too difficult: * emitting a `WhileDo` whose argument is `true` would set the generated `BType` to `Nothing`. * then, when emitting a `Block`, we would drop any statements and expr following a statement whose generated `BType` was `Nothing`. This commit does not go to such lengths, however. This change removes the last source of label-defs in the compiler. After this commit, we will be able to entirely remove label-defs.
1 parent 97c63aa commit 5ee5ee0

File tree

2 files changed

+146
-86
lines changed

2 files changed

+146
-86
lines changed

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ object NameKinds {
286286
val NonLocalReturnKeyName: UniqueNameKind = new UniqueNameKind("nonLocalReturnKey")
287287
val WildcardParamName: UniqueNameKind = new UniqueNameKind("_$")
288288
val TailLabelName: UniqueNameKind = new UniqueNameKind("tailLabel")
289+
val TailLocalName: UniqueNameKind = new UniqueNameKind("$tailLocal")
290+
val TailTempName: UniqueNameKind = new UniqueNameKind("$tmp")
289291
val ExceptionBinderName: UniqueNameKind = new UniqueNameKind("ex")
290292
val SkolemName: UniqueNameKind = new UniqueNameKind("?")
291293
val LiftedTreeName: UniqueNameKind = new UniqueNameKind("liftedTree")

compiler/src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 144 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,22 @@ package transform
44
import ast.Trees._
55
import ast.{TreeTypeMap, tpd}
66
import core._
7+
import Constants.Constant
78
import Contexts.Context
89
import Decorators._
910
import Symbols._
1011
import StdNames.nme
1112
import Types._
12-
import NameKinds.TailLabelName
13+
import NameKinds.{TailLabelName, TailLocalName, TailTempName}
1314
import MegaPhase.MiniPhase
1415
import reporting.diagnostic.messages.TailrecNotApplicable
1516

1617
/**
1718
* A Tail Rec Transformer
1819
* @author Erik Stenman, Iulian Dragos,
1920
* ported and heavily modified for dotty by Dmitry Petrashko
20-
* moved after erasure by Sébastien Doeraene
21+
* moved after erasure and adapted to emit `Labeled` blocks
22+
* by Sébastien Doeraene
2123
* @version 1.1
2224
*
2325
* What it does:
@@ -32,10 +34,29 @@ import reporting.diagnostic.messages.TailrecNotApplicable
3234
* contain such calls are not transformed).
3335
* </p>
3436
* <p>
35-
* Self-recursive calls in tail-position are replaced by jumps to a
36-
* label at the beginning of the method. As the JVM provides no way to
37-
* jump from a method to another one, non-recursive calls in
38-
* tail-position are not optimized.
37+
* When a method contains at least one tail-recursive call, its rhs
38+
* is wrapped in the following structure:
39+
* </p>
40+
* <pre>
41+
* var localForParam1: T1 = param1
42+
* ...
43+
* while (true) {
44+
* tailResult[ResultType]: {
45+
* return {
46+
* // original rhs
47+
* }
48+
* }
49+
* }
50+
* </pre>
51+
* <p>
52+
* Self-recursive calls in tail-position are then replaced by (a)
53+
* reassigning the local `var`s substituting formal parameters and
54+
* (b) a `return` from the `tailResult` labeled block, which has the
55+
* net effect of looping back to the beginning of the method.
56+
* </p>
57+
* <p>
58+
* As the JVM provides no way to jump from a method to another one,
59+
* non-recursive calls in tail-position are not optimized.
3960
* </p>
4061
* <p>
4162
* A method call is self-recursive if it calls the current method and
@@ -46,14 +67,8 @@ import reporting.diagnostic.messages.TailrecNotApplicable
4667
* </p>
4768
* <p>
4869
* This phase has been moved after erasure to allow the use of vars
49-
* for the parameters combined with a `WhileDo` (upcoming change).
50-
* This is also beneficial to support polymorphic tail-recursive
51-
* calls.
52-
* </p>
53-
* <p>
54-
* If a method contains self-recursive calls, a label is added to at
55-
* the beginning of its body and the calls are replaced by jumps to
56-
* that label.
70+
* for the parameters combined with a `WhileDo`. This is also
71+
* beneficial to support polymorphic tail-recursive calls.
5772
* </p>
5873
* <p>
5974
* In scalac, if the method had type parameters, the call must contain
@@ -72,25 +87,6 @@ class TailRec extends MiniPhase {
7287

7388
override def runsAfter: Set[String] = Set(Erasure.name) // tailrec assumes erased types
7489

75-
final val labelFlags: Flags.FlagSet = Flags.Synthetic | Flags.Label | Flags.Method
76-
77-
private def mkLabel(method: Symbol)(implicit ctx: Context): TermSymbol = {
78-
val name = TailLabelName.fresh()
79-
80-
if (method.owner.isClass) {
81-
val MethodTpe(paramNames, paramInfos, resultType) = method.info
82-
83-
val enclosingClass = method.enclosingClass.asClass
84-
val thisParamType =
85-
if (enclosingClass.is(Flags.Module)) enclosingClass.thisType
86-
else enclosingClass.classInfo.selfType
87-
88-
ctx.newSymbol(method, name.toTermName, labelFlags,
89-
MethodType(nme.SELF :: paramNames, thisParamType :: paramInfos, resultType))
90-
}
91-
else ctx.newSymbol(method, name.toTermName, labelFlags, method.info)
92-
}
93-
9490
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = {
9591
val sym = tree.symbol
9692
tree match {
@@ -100,62 +96,60 @@ class TailRec extends MiniPhase {
10096
cpy.DefDef(dd)(rhs = {
10197
val defIsTopLevel = sym.owner.isClass
10298
val origMeth = sym
103-
val label = mkLabel(sym)
10499
val owner = ctx.owner.enclosingClass.asClass
105100

106-
var rewrote = false
107-
108101
// Note: this can be split in two separate transforms(in different groups),
109102
// than first one will collect info about which transformations and rewritings should be applied
110103
// and second one will actually apply,
111104
// now this speculatively transforms tree and throws away result in many cases
112-
val rhsSemiTransformed = {
113-
val transformer = new TailRecElimination(origMeth, owner, mandatory, label)
114-
val rhs = transformer.transform(dd.rhs)
115-
rewrote = transformer.rewrote
116-
rhs
117-
}
118-
119-
if (rewrote) {
120-
if (tree.symbol.owner.isClass) {
121-
val classSym = tree.symbol.owner.asClass
122-
123-
val labelDef = DefDef(label, vrefss => {
124-
assert(vrefss.size == 1, vrefss)
125-
val vrefs = vrefss.head
126-
val thisRef = vrefs.head
127-
val origMeth = tree.symbol
128-
val origVParams = vparams.map(_.symbol)
105+
val transformer = new TailRecElimination(origMeth, owner, vparams.map(_.symbol), mandatory)
106+
val rhsSemiTransformed = transformer.transform(dd.rhs)
107+
108+
if (transformer.rewrote) {
109+
val varForRewrittenThis = transformer.varForRewrittenThis
110+
val rewrittenParamSyms = transformer.rewrittenParamSyms
111+
val varsForRewrittenParamSyms = transformer.varsForRewrittenParamSyms
112+
113+
val initialValDefs = {
114+
val initialParamValDefs = for ((param, local) <- rewrittenParamSyms.zip(varsForRewrittenParamSyms)) yield {
115+
ValDef(local.asTerm, ref(param))
116+
}
117+
varForRewrittenThis match {
118+
case Some(local) => ValDef(local.asTerm, This(tree.symbol.owner.asClass)) :: initialParamValDefs
119+
case none => initialParamValDefs
120+
}
121+
}
122+
123+
val rhsFullyTransformed = varForRewrittenThis match {
124+
case Some(localThisSym) =>
125+
val classSym = tree.symbol.owner.asClass
126+
val thisRef = localThisSym.termRef
129127
new TreeTypeMap(
130-
typeMap = identity(_)
131-
.substThisUnlessStatic(classSym, thisRef.tpe)
132-
.subst(origVParams, vrefs.tail.map(_.tpe)),
128+
typeMap = _.substThisUnlessStatic(classSym, thisRef)
129+
.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef)),
133130
treeMap = {
134-
case tree: This if tree.symbol == classSym => thisRef
131+
case tree: This if tree.symbol == classSym => Ident(thisRef)
135132
case tree => tree
136-
},
137-
oldOwners = origMeth :: Nil,
138-
newOwners = label :: Nil
133+
}
139134
).transform(rhsSemiTransformed)
140-
})
141-
val callIntoLabel = ref(label).appliedToArgs(This(classSym) :: vparams.map(x => ref(x.symbol)))
142-
Block(List(labelDef), callIntoLabel)
143-
} else { // inner method. Tail recursion does not change `this`
144-
val labelDef = DefDef(label, vrefss => {
145-
assert(vrefss.size == 1, vrefss)
146-
val vrefs = vrefss.head
147-
val origMeth = tree.symbol
148-
val origVParams = vparams.map(_.symbol)
135+
136+
case none =>
149137
new TreeTypeMap(
150-
typeMap = identity(_)
151-
.subst(origVParams, vrefs.map(_.tpe)),
152-
oldOwners = origMeth :: Nil,
153-
newOwners = label :: Nil
138+
typeMap = _.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef))
154139
).transform(rhsSemiTransformed)
155-
})
156-
val callIntoLabel = ref(label).appliedToArgs(vparams.map(x => ref(x.symbol)))
157-
Block(List(labelDef), callIntoLabel)
158-
}} else {
140+
}
141+
142+
Block(
143+
initialValDefs :::
144+
WhileDo(Literal(Constant(true)), {
145+
Labeled(transformer.continueLabel.get.asTerm, {
146+
Return(rhsFullyTransformed, ref(origMeth))
147+
})
148+
}) ::
149+
Nil,
150+
Throw(Literal(Constant(null))) // unreachable code
151+
)
152+
} else {
159153
if (mandatory) ctx.error(
160154
"TailRec optimisation not applicable, method not tail recursive",
161155
// FIXME: want to report this error on `dd.namePos`, but
@@ -174,12 +168,51 @@ class TailRec extends MiniPhase {
174168

175169
}
176170

177-
class TailRecElimination(method: Symbol, enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
171+
class TailRecElimination(method: Symbol, enclosingClass: Symbol, paramSyms: List[Symbol], isMandatory: Boolean) extends tpd.TreeMap {
178172

179173
import dotty.tools.dotc.ast.tpd._
180174

181175
var rewrote: Boolean = false
182176

177+
var continueLabel: Option[Symbol] = None
178+
var varForRewrittenThis: Option[Symbol] = None
179+
var rewrittenParamSyms: List[Symbol] = Nil
180+
var varsForRewrittenParamSyms: List[Symbol] = Nil
181+
182+
private def getContinueLabel()(implicit c: Context): Symbol = {
183+
continueLabel match {
184+
case Some(sym) => sym
185+
case none =>
186+
val sym = c.newSymbol(method, TailLabelName.fresh(), Flags.Label, defn.UnitType)
187+
continueLabel = Some(sym)
188+
sym
189+
}
190+
}
191+
192+
private def getVarForRewrittenThis()(implicit c: Context): Symbol = {
193+
varForRewrittenThis match {
194+
case Some(sym) => sym
195+
case none =>
196+
val tpe =
197+
if (enclosingClass.is(Flags.Module)) enclosingClass.thisType
198+
else enclosingClass.asClass.classInfo.selfType
199+
val sym = c.newSymbol(method, nme.SELF, Flags.Synthetic | Flags.Mutable, tpe)
200+
varForRewrittenThis = Some(sym)
201+
sym
202+
}
203+
}
204+
205+
private def getVarForRewrittenParam(param: Symbol)(implicit c: Context): Symbol = {
206+
rewrittenParamSyms.indexOf(param) match {
207+
case -1 =>
208+
val sym = c.newSymbol(method, TailLocalName.fresh(param.name.toTermName), Flags.Synthetic | Flags.Mutable, param.info)
209+
rewrittenParamSyms ::= param
210+
varsForRewrittenParamSyms ::= sym
211+
sym
212+
case index => varsForRewrittenParamSyms(index)
213+
}
214+
}
215+
183216
/** Symbols of Labeled blocks that are in tail position. */
184217
private val tailPositionLabeledSyms = new collection.mutable.HashSet[Symbol]()
185218

@@ -233,15 +266,40 @@ class TailRec extends MiniPhase {
233266
if (ctx.tailPos) {
234267
c.debuglog("Rewriting tail recursive call: " + tree.pos)
235268
rewrote = true
236-
def receiver =
237-
if (prefix eq EmptyTree) This(enclosingClass.asClass)
238-
else noTailTransform(prefix)
239269

240-
val argumentsWithReceiver =
241-
if (this.method.owner.isClass) receiver :: arguments
242-
else arguments
243-
244-
tpd.cpy.Apply(tree)(ref(label), argumentsWithReceiver)
270+
val assignParamPairs = for {
271+
(param, arg) <- paramSyms.zip(arguments)
272+
if (arg match {
273+
case arg: Ident => arg.symbol != param
274+
case _ => true
275+
})
276+
} yield {
277+
(getVarForRewrittenParam(param), arg)
278+
}
279+
280+
val assignThisAndParamPairs = {
281+
if (prefix eq EmptyTree) assignParamPairs
282+
else {
283+
// TODO Opt: also avoid assigning `this` if the prefix is `this.`
284+
(getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs
285+
}
286+
}
287+
288+
val assignments = assignThisAndParamPairs match {
289+
case (lhs, rhs) :: Nil =>
290+
Assign(ref(lhs), rhs) :: Nil
291+
case _ :: _ =>
292+
val (tempValDefs, assigns) = (for ((lhs, rhs) <- assignThisAndParamPairs) yield {
293+
val temp = c.newSymbol(method, TailTempName.fresh(lhs.name.toTermName), Flags.Synthetic, lhs.info)
294+
(ValDef(temp, rhs), Assign(ref(lhs), ref(temp)).withPos(tree.pos))
295+
}).unzip
296+
tempValDefs ::: assigns
297+
case nil =>
298+
Nil
299+
}
300+
301+
val tpt = TypeTree(method.info.resultType)
302+
Block(assignments, Typed(Return(Literal(Constant(())).withPos(tree.pos), ref(getContinueLabel())), tpt))
245303
}
246304
else fail("it is not in tail position")
247305
} else {

0 commit comments

Comments
 (0)