Skip to content

Commit edeb7e4

Browse files
committed
Make trees after TailRec type correct
TailRec now relies on FullParameterization and uses two passes for transformation. First one decides weather the method will be transformed and if yes, starts rewriting calls in taill position. Second one lifts the method body to a fully parameterized one, correcting types.
1 parent 601b11e commit edeb7e4

File tree

1 file changed

+73
-90
lines changed

1 file changed

+73
-90
lines changed

src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 73 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,16 @@
11
package dotty.tools.dotc.transform
22

3-
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, TreeTransformer}
4-
import dotty.tools.dotc.ast.{Trees, tpd}
3+
import dotty.tools.dotc.ast.Trees._
4+
import dotty.tools.dotc.ast.tpd
55
import dotty.tools.dotc.core.Contexts.Context
6-
import scala.collection.mutable.ListBuffer
7-
import dotty.tools.dotc.core._
8-
import dotty.tools.dotc.core.Symbols.NoSymbol
9-
import scala.annotation.tailrec
10-
import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._
11-
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._
12-
import Decorators._
13-
import Symbols._
14-
import scala.Some
15-
import dotty.tools.dotc.transform.TreeTransforms.{NXTransformations, TransformerInfo, TreeTransform, TreeTransformer}
16-
import dotty.tools.dotc.core.Contexts.Context
17-
import scala.collection.mutable
18-
import dotty.tools.dotc.core.Names.Name
19-
import NameOps._
20-
import dotty.tools.dotc.CompilationUnit
21-
import dotty.tools.dotc.util.Positions.{Position, Coord}
22-
import dotty.tools.dotc.util.Positions.NoPosition
6+
import dotty.tools.dotc.core.Decorators._
237
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
248
import dotty.tools.dotc.core.Denotations.SingleDenotation
9+
import dotty.tools.dotc.core.Symbols._
10+
import dotty.tools.dotc.core.Types._
11+
import dotty.tools.dotc.core._
2512
import dotty.tools.dotc.transform.TailRec._
13+
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform}
2614

2715
/**
2816
* A Tail Rec Transformer
@@ -74,9 +62,9 @@ import dotty.tools.dotc.transform.TailRec._
7462
* self recursive functions, that's why it's renamed to tailrec
7563
* </p>
7664
*/
77-
class TailRec extends TreeTransform with DenotTransformer {
65+
class TailRec extends TreeTransform with DenotTransformer with FullParameterization {
7866

79-
import tpd._
67+
import dotty.tools.dotc.ast.tpd._
8068

8169
override def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = ref
8270

@@ -85,54 +73,44 @@ class TailRec extends TreeTransform with DenotTransformer {
8573
final val labelPrefix = "tailLabel"
8674
final val labelFlags = Flags.Synthetic | Flags.Label
8775

88-
private def mkLabel(method: Symbol, tp: Type)(implicit c: Context): TermSymbol = {
76+
private def mkLabel(method: Symbol)(implicit c: Context): TermSymbol = {
8977
val name = c.freshName(labelPrefix)
90-
c.newSymbol(method, name.toTermName, labelFlags , tp)
78+
79+
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass))
9180
}
9281

9382
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
9483
tree match {
9584
case dd@DefDef(mods, name, tparams, vparamss0, tpt, rhs0)
96-
if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree)) =>
85+
if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree) || (dd.symbol is Flags.Label)) =>
9786
val mandatory = dd.symbol.hasAnnotation(defn.TailrecAnnotationClass)
9887
cpy.DefDef(tree, mods, name, tparams, vparamss0, tpt, rhs = {
99-
val owner = ctx.owner.enclosingClass.asClass
10088

101-
val thisTpe = owner.thisType
102-
103-
val newType: Type = dd.tpe.widen match {
104-
case t: PolyType => PolyType(t.paramNames)(x => t.paramBounds,
105-
x => MethodType(List(nme.THIS), List(thisTpe), t.resultType))
106-
case t => MethodType(List(nme.THIS), List(thisTpe), t)
107-
}
89+
val origMeth = tree.symbol
90+
val label = mkLabel(dd.symbol)
91+
val owner = ctx.owner.enclosingClass.asClass
92+
val thisTpe = owner.thisType.widen
10893

109-
val label = mkLabel(dd.symbol, newType)
11094
var rewrote = false
11195

11296
// Note: this can be split in two separate transforms(in different groups),
11397
// than first one will collect info about which transformations and rewritings should be applied
11498
// and second one will actually apply,
11599
// now this speculatively transforms tree and throws away result in many cases
116-
val res = tpd.DefDef(label, args => {
117-
val thiz = args.head.head
118-
val argMapping: Map[Symbol, Tree] = (vparamss0.flatten.map(_.symbol) zip args.tail.flatten).toMap
119-
val transformer = new TailRecElimination(dd.symbol, thiz, argMapping, owner, mandatory, label)
100+
val rhsSemiTransformed = {
101+
val transformer = new TailRecElimination(dd.symbol, owner, thisTpe, mandatory, label)
120102
val rhs = transformer.transform(rhs0)(ctx.withPhase(ctx.phase.next))
121103
rewrote = transformer.rewrote
122104
rhs
123-
})
105+
}
124106

125107
if (rewrote) {
126-
val call =
127-
if (tparams.isEmpty) Ident(label.termRef)
128-
else TypeApply(Ident(label.termRef), tparams)
129-
Block(
130-
List(res),
131-
vparamss0.foldLeft(Apply(call, List(This(owner))))
132-
{(call, args) => Apply(call, args.map(x => Ident(x.symbol.termRef)))}
133-
)
134-
}
135-
else {
108+
val dummyDefDef = cpy.DefDef(tree, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt,
109+
rhsSemiTransformed)
110+
val res = fullyParameterizedDef(label, dummyDefDef)
111+
val call = forwarder(label, dd)
112+
Block(List(res), call)
113+
} else {
136114
if (mandatory)
137115
ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos)
138116
rhs0
@@ -149,11 +127,9 @@ class TailRec extends TreeTransform with DenotTransformer {
149127

150128
}
151129

152-
class TailRecElimination(method: Symbol, thiz: Tree, argMapping: Map[Symbol, Tree],
153-
enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
154-
155-
import tpd._
130+
class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol) extends tpd.RetypingTreeMap {
156131

132+
import dotty.tools.dotc.ast.tpd._
157133

158134
var rewrote = false
159135

@@ -182,7 +158,6 @@ class TailRec extends TreeTransform with DenotTransformer {
182158
def noTailTransforms(trees: List[Tree])(implicit c: Context) =
183159
trees map (noTailTransform)
184160

185-
186161
override def transform(tree: Tree)(implicit c: Context): Tree = {
187162
/* A possibly polymorphic apply to be considered for tail call transformation. */
188163
def rewriteApply(tree: Tree, sym: Symbol): Tree = {
@@ -204,7 +179,7 @@ class TailRec extends TreeTransform with DenotTransformer {
204179

205180
val receiverIsSame = enclosingClass.typeRef.widen =:= recv.tpe.widen
206181
val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recv.tpe.widen
207-
val receiverIsThis = recv.tpe.widen =:= thiz.tpe.widen
182+
val receiverIsThis = recv.tpe.widen =:= thisType
208183

209184
val isRecursiveCall = (method eq sym)
210185

@@ -226,17 +201,24 @@ class TailRec extends TreeTransform with DenotTransformer {
226201
def rewriteTailCall(recv: Tree): Tree = {
227202
c.debuglog("Rewriting tail recursive call: " + tree.pos)
228203
rewrote = true
229-
val method = if (targs.nonEmpty) TypeApply(Ident(label.termRef), targs) else Ident(label.termRef)
230-
val recv = noTailTransform(reciever)
231-
if (recv.tpe.widen.isParameterless) method
232-
else argumentss.foldLeft(Apply(method, List(recv))) {
233-
(method, args) => Apply(method, args) // Dotty deviation no auto-detupling yet.
204+
val reciever = noTailTransform(recv)
205+
val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos
206+
val trz = classTypeArgs.map(x => ref(x.typeSymbol))
207+
val callTargs: List[tpd.Tree] = targs ::: trz
208+
val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef),
209+
List(reciever))
210+
211+
val res =
212+
if (method.tpe.widen.isParameterless) method
213+
else argumentss.foldLeft(method) {
214+
(met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet.
234215
}
216+
res
235217
}
236218

237219
if (isRecursiveCall) {
238220
if (ctx.tailPos) {
239-
if (recv eq EmptyTree) rewriteTailCall(thiz)
221+
if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass))
240222
else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv)
241223
else fail("it changes type of 'this' on a polymorphic recursive call")
242224
}
@@ -247,7 +229,7 @@ class TailRec extends TreeTransform with DenotTransformer {
247229
}
248230
}
249231

250-
def rewriteTry(tree: Try): Tree = {
232+
def rewriteTry(tree: Try): Try = {
251233
def transformHandlers(t: Tree): Tree = {
252234
t match {
253235
case Block(List((d: DefDef)), cl@Closure(Nil, _, EmptyTree)) =>
@@ -274,65 +256,61 @@ class TailRec extends TreeTransform with DenotTransformer {
274256
}
275257

276258
val res: Tree = tree match {
277-
case Block(stats, expr) =>
278-
tpd.cpy.Block(tree,
259+
260+
case tree@Block(stats, expr) =>
261+
val tree1 = tpd.cpy.Block(tree,
279262
noTailTransforms(stats),
280263
transform(expr)
281264
)
265+
propagateType(tree, tree1)
282266

283-
case t@CaseDef(pat, guard, body) =>
284-
cpy.CaseDef(t, pat, guard, transform(body))
267+
case tree@CaseDef(pat, guard, body) =>
268+
val tree1 = cpy.CaseDef(tree, pat, guard, transform(body))
269+
propagateType(tree, tree1)
285270

286-
case If(cond, thenp, elsep) =>
287-
tpd.cpy.If(tree,
288-
transform(cond),
271+
case tree@If(cond, thenp, elsep) =>
272+
val tree1 = tpd.cpy.If(tree,
273+
noTailTransform(cond),
289274
transform(thenp),
290275
transform(elsep)
291276
)
277+
propagateType(tree, tree1)
292278

293-
case Match(selector, cases) =>
294-
tpd.cpy.Match(tree,
279+
case tree@Match(selector, cases) =>
280+
val tree1 = tpd.cpy.Match(tree,
295281
noTailTransform(selector),
296282
transformSub(cases)
297283
)
284+
propagateType(tree, tree1)
298285

299-
case t: Try =>
300-
rewriteTry(t)
286+
case tree: Try =>
287+
val tree1 = rewriteTry(tree)
288+
propagateType(tree, tree1)
301289

302290
case Apply(fun, args) if fun.symbol == defn.Boolean_or || fun.symbol == defn.Boolean_and =>
303291
tpd.cpy.Apply(tree, fun, transform(args))
304292

305293
case Apply(fun, args) =>
306294
rewriteApply(tree, fun.symbol)
295+
307296
case Alternative(_) | Bind(_, _) =>
308297
assert(false, "We should've never gotten inside a pattern")
309298
tree
310-
case This(cls) if cls eq enclosingClass =>
311-
thiz
312-
case Select(qual, name) =>
299+
300+
case tree: Select =>
313301
val sym = tree.symbol
314302
if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
315-
else tpd.cpy.Select(tree, noTailTransform(qual), name)
303+
else propagateType(tree, tpd.cpy.Select(tree, noTailTransform(tree.qualifier), tree.name))
304+
316305
case ValDef(_, _, _, _) | EmptyTree | Super(_, _) | This(_) |
317306
Literal(_) | TypeTree(_) | DefDef(_, _, _, _, _, _) | TypeDef(_, _, _) =>
318307
tree
308+
319309
case Ident(qual) =>
320310
val sym = tree.symbol
321311
if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
322-
else argMapping.get(sym) match {
323-
case Some(rewrite) => rewrite
324-
case None => tree.tpe match {
325-
case TermRef(ThisType(`enclosingClass`), _) =>
326-
if (sym.flags is Flags.Local) {
327-
// trying to access private[this] member. toggle flag in order to access.
328-
val d = sym.denot
329-
val newDenot = d.copySymDenotation(initFlags = sym.flags &~ Flags.Local)
330-
newDenot.installAfter(TailRec.this)
331-
}
332-
thiz.select(sym)
333-
case _ => tree
334-
}
335-
}
312+
else tree
313+
336314
case _ =>
337315
super.transform(tree)
338316
}
@@ -341,6 +319,11 @@ class TailRec extends TreeTransform with DenotTransformer {
341319
}
342320
}
343321

322+
/** If references to original `target` from fully parameterized method `derived` should be
323+
* rewired to some fully parameterized method, that method symbol,
324+
* otherwise NoSymbol.
325+
*/
326+
override protected def rewiredTarget(target: Symbol, derived: Symbol)(implicit ctx: Context): Symbol = NoSymbol
344327
}
345328

346329
object TailRec {

0 commit comments

Comments
 (0)