Skip to content

Commit 354addd

Browse files
committed
Simplify TailRec
1 parent ca28446 commit 354addd

File tree

1 file changed

+57
-64
lines changed

1 file changed

+57
-64
lines changed

compiler/src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 57 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import ast.{TreeTypeMap, tpd}
66
import core._
77
import Contexts.Context
88
import Decorators._
9-
import DenotTransformers.DenotTransformer
109
import Denotations.SingleDenotation
1110
import Symbols._
1211
import Types._
@@ -62,13 +61,11 @@ import TreeTransforms.{MiniPhaseTransform, TransformerInfo}
6261
* self recursive functions, that's why it's renamed to tailrec
6362
* </p>
6463
*/
65-
class TailRec extends MiniPhaseTransform with DenotTransformer with FullParameterization { thisTransform =>
64+
class TailRec extends MiniPhaseTransform with FullParameterization { thisTransform =>
6665
import TailRec._
6766

6867
import dotty.tools.dotc.ast.tpd._
6968

70-
override def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = ref
71-
7269
override def phaseName: String = "tailrec"
7370

7471
override def treeTransformPhase(implicit ctx: Context, info: TransformerInfo) =
@@ -98,71 +95,69 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
9895
else ctx.newSymbol(method, name.toTermName, labelFlags, method.info)
9996
}
10097

98+
/** Note: This method should be run atGroupEnd */
10199
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
102100
val sym = tree.symbol
103101
tree match {
104102
case dd@DefDef(name, tparams, vparamss0, tpt, _)
105103
if (sym.isEffectivelyFinal) && !((sym is Flags.Accessor) || (dd.rhs eq EmptyTree) || (sym is Flags.Label)) =>
106104
val mandatory = sym.hasAnnotation(defn.TailrecAnnot)
107-
atGroupEnd { implicit ctx: Context =>
108-
109-
cpy.DefDef(dd)(rhs = {
110-
111-
val defIsTopLevel = sym.owner.isClass
112-
val origMeth = sym
113-
val label = mkLabel(sym, abstractOverClass = defIsTopLevel)
114-
val owner = ctx.owner.enclosingClass.asClass
115-
val thisTpe = owner.thisType.widen
116-
117-
var rewrote = false
118-
119-
// Note: this can be split in two separate transforms(in different groups),
120-
// than first one will collect info about which transformations and rewritings should be applied
121-
// and second one will actually apply,
122-
// now this speculatively transforms tree and throws away result in many cases
123-
val rhsSemiTransformed = {
124-
val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
125-
val rhs = atGroupEnd(implicit ctx => transformer.transform(dd.rhs))
126-
rewrote = transformer.rewrote
127-
rhs
128-
}
105+
cpy.DefDef(dd)(rhs = {
106+
107+
val defIsTopLevel = sym.owner.isClass
108+
val origMeth = sym
109+
val label = mkLabel(sym, abstractOverClass = defIsTopLevel)
110+
val owner = ctx.owner.enclosingClass.asClass
111+
val thisTpe = owner.thisType.widen
112+
113+
var rewrote = false
114+
115+
// Note: this can be split in two separate transforms(in different groups),
116+
// than first one will collect info about which transformations and rewritings should be applied
117+
// and second one will actually apply,
118+
// now this speculatively transforms tree and throws away result in many cases
119+
val rhsSemiTransformed = {
120+
val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
121+
val rhs = atGroupEnd(implicit ctx => transformer.transform(dd.rhs))
122+
rewrote = transformer.rewrote
123+
rhs
124+
}
129125

130-
if (rewrote) {
131-
val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
132-
if (tree.symbol.owner.isClass) {
133-
val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
134-
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true)
135-
Block(List(labelDef), call)
136-
} else { // inner method. Tail recursion does not change `this`
137-
val labelDef = polyDefDef(label, trefs => vrefss => {
138-
val origMeth = tree.symbol
139-
val origTParams = tree.tparams.map(_.symbol)
140-
val origVParams = tree.vparamss.flatten map (_.symbol)
141-
new TreeTypeMap(
142-
typeMap = identity(_)
143-
.substDealias(origTParams, trefs)
144-
.subst(origVParams, vrefss.flatten.map(_.tpe)),
145-
oldOwners = origMeth :: Nil,
146-
newOwners = label :: Nil
147-
).transform(rhsSemiTransformed)
148-
})
149-
val callIntoLabel = (
150-
if (dd.tparams.isEmpty) ref(label)
151-
else ref(label).appliedToTypes(dd.tparams.map(_.tpe))
152-
).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol))))
153-
Block(List(labelDef), callIntoLabel)
154-
}} else {
155-
if (mandatory) ctx.error(
156-
"TailRec optimisation not applicable, method not tail recursive",
157-
// FIXME: want to report this error on `dd.namePos`, but
158-
// because of extension method getting a weird pos, it is
159-
// better to report on symbol so there's no overlap
160-
sym.pos
161-
)
162-
dd.rhs
163-
}
164-
})
165-
}
126+
if (rewrote) {
127+
val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
128+
if (tree.symbol.owner.isClass) {
129+
val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
130+
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true)
131+
Block(List(labelDef), call)
132+
} else { // inner method. Tail recursion does not change `this`
133+
val labelDef = polyDefDef(label, trefs => vrefss => {
134+
val origMeth = tree.symbol
135+
val origTParams = tree.tparams.map(_.symbol)
136+
val origVParams = tree.vparamss.flatten map (_.symbol)
137+
new TreeTypeMap(
138+
typeMap = identity(_)
139+
.substDealias(origTParams, trefs)
140+
.subst(origVParams, vrefss.flatten.map(_.tpe)),
141+
oldOwners = origMeth :: Nil,
142+
newOwners = label :: Nil
143+
).transform(rhsSemiTransformed)
144+
})
145+
val callIntoLabel = (
146+
if (dd.tparams.isEmpty) ref(label)
147+
else ref(label).appliedToTypes(dd.tparams.map(_.tpe))
148+
).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol))))
149+
Block(List(labelDef), callIntoLabel)
150+
}} else {
151+
if (mandatory) ctx.error(
152+
"TailRec optimisation not applicable, method not tail recursive",
153+
// FIXME: want to report this error on `dd.namePos`, but
154+
// because of extension method getting a weird pos, it is
155+
// better to report on symbol so there's no overlap
156+
sym.pos
157+
)
158+
dd.rhs
159+
}
160+
})
166161
case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) =>
167162
ctx.error("TailRec optimisation not applicable, method is neither private nor final so can be overridden", sym.pos)
168163
d
@@ -242,8 +237,6 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
242237
continue
243238
}
244239

245-
246-
247240
if (isRecursiveCall) {
248241
if (ctx.tailPos) {
249242
val receiverIsSame = enclosingClass.appliedRef.widenDealias =:= recvWiden

0 commit comments

Comments
 (0)