diff --git a/src/dotty/tools/dotc/transform/FullParameterization.scala b/src/dotty/tools/dotc/transform/FullParameterization.scala index e9057e885c47..be64df384063 100644 --- a/src/dotty/tools/dotc/transform/FullParameterization.scala +++ b/src/dotty/tools/dotc/transform/FullParameterization.scala @@ -12,6 +12,8 @@ import NameOps._ import ast._ import ast.Trees._ +import scala.reflect.internal.util.Collections + /** Provides methods to produce fully parameterized versions of instance methods, * where the `this` of the enclosing class is abstracted out in an extra leading * `$this` parameter and type parameters of the class become additional type @@ -86,9 +88,12 @@ trait FullParameterization { * } * * If a self type is present, $this has this self type as its type. + * * @param abstractOverClass if true, include the type parameters of the class in the method's list of type parameters. + * @param liftThisType if true, require created $this to be $this: (Foo[A] & Foo,this). + * This is needed if created member stays inside scope of Foo(as in tailrec) */ - def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true)(implicit ctx: Context): Type = { + def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true, liftThisType: Boolean = false)(implicit ctx: Context): Type = { val (mtparamCount, origResult) = info match { case info @ PolyType(mtnames) => (mtnames.length, info.resultType) case info: ExprType => (0, info.resultType) @@ -100,7 +105,8 @@ trait FullParameterization { /** The method result type */ def resultType(mapClassParams: Type => Type) = { val thisParamType = mapClassParams(clazz.classInfo.selfType) - MethodType(nme.SELF :: Nil, thisParamType :: Nil)(mt => + val firstArgType = if (liftThisType) thisParamType & clazz.thisType else thisParamType + MethodType(nme.SELF :: Nil, firstArgType :: Nil)(mt => mapClassParams(origResult).substThisUnlessStatic(clazz, MethodParam(mt, 0))) } @@ -217,12 +223,26 @@ trait FullParameterization { * - the `this` of the enclosing class, * - the value parameters of the original method `originalDef`. */ - def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree = - ref(derived.termRef) - .appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef)) - .appliedTo(This(originalDef.symbol.enclosingClass.asClass)) - .appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol))) - .withPos(originalDef.rhs.pos) + def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true, liftThisType: Boolean = false)(implicit ctx: Context): Tree = { + val fun = + ref(derived.termRef) + .appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef)) + .appliedTo(This(originalDef.symbol.enclosingClass.asClass)) + + (if (!liftThisType) + fun.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol))) + else { + // this type could have changed on forwarding. Need to insert a cast. + val args = Collections.map2(originalDef.vparamss, fun.tpe.paramTypess)((vparams, paramTypes) => + Collections.map2(vparams, paramTypes)((vparam, paramType) => { + assert(vparam.tpe <:< paramType.widen) // type should still conform to widened type + ref(vparam.symbol).ensureConforms(paramType) + }) + ) + fun.appliedToArgss(args) + + }).withPos(originalDef.rhs.pos) + } } object FullParameterization { diff --git a/src/dotty/tools/dotc/transform/TailRec.scala b/src/dotty/tools/dotc/transform/TailRec.scala index 58fe7a6c909c..23686b522be3 100644 --- a/src/dotty/tools/dotc/transform/TailRec.scala +++ b/src/dotty/tools/dotc/transform/TailRec.scala @@ -1,7 +1,7 @@ package dotty.tools.dotc.transform import dotty.tools.dotc.ast.Trees._ -import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.{TreeTypeMap, tpd} import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.Decorators._ import dotty.tools.dotc.core.DenotTransformers.DenotTransformer @@ -10,13 +10,12 @@ import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.core.Types._ import dotty.tools.dotc.core._ import dotty.tools.dotc.transform.TailRec._ -import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform} +import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo} /** * A Tail Rec Transformer - * * @author Erik Stenman, Iulian Dragos, - * ported to dotty by Dmitry Petrashko + * ported and heavily modified for dotty by Dmitry Petrashko * @version 1.1 * * What it does: @@ -77,7 +76,9 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = { val name = c.freshName(labelPrefix) - c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass)) + if (method.owner.isClass) + c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false)) + else c.newSymbol(method, name.toTermName, labelFlags, method.info) } override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { @@ -103,7 +104,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete // and second one will actually apply, // now this speculatively transforms tree and throws away result in many cases val rhsSemiTransformed = { - val transformer = new TailRecElimination(origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel) + val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel) val rhs = atGroupEnd(transformer.transform(dd.rhs)(_)) rewrote = transformer.rewrote rhs @@ -111,10 +112,25 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete if (rewrote) { val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed) - val res = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel) - val call = forwarder(label, dd, abstractOverClass = defIsTopLevel) - Block(List(res), call) - } else { + if (tree.symbol.owner.isClass) { + val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel) + val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true) + Block(List(labelDef), call) + } else { // inner method. Tail recursion does not change `this` + val labelDef = polyDefDef(label, trefs => vrefss => { + val origMeth = tree.symbol + val origTParams = tree.tparams.map(_.symbol) + val origVParams = tree.vparamss.flatten map (_.symbol) + new TreeTypeMap( + typeMap = identity(_) + .substDealias(origTParams, trefs) + .subst(origVParams, vrefss.flatten.map(_.tpe)), + oldOwners = origMeth :: Nil, + newOwners = label :: Nil + ).transform(rhsSemiTransformed) + }) + Block(List(labelDef), ref(label).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol))))) + }} else { if (mandatory) ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos) dd.rhs @@ -132,7 +148,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete } - class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap { + class TailRecElimination(method: Symbol, methTparams: List[Tree], enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap { import dotty.tools.dotc.ast.tpd._ @@ -175,8 +191,9 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete case x => (x, x, accArgs, accT, x.symbol) } - val (reciever, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree) - val recv = noTailTransform(reciever) + val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree) + val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe} + val recv = noTailTransform(prefix) val targs = typeArguments.map(noTailTransform) val argumentss = arguments.map(noTailTransforms) @@ -215,20 +232,21 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete targs ::: classTypeArgs.map(x => ref(x.typeSymbol)) } else targs - val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef), - List(receiver)) + val method = if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef) + val thisPassed = if(this.method.owner.isClass) method appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head)) else method val res = - if (method.tpe.widen.isParameterless) method - else argumentss.foldLeft(method) { - (met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet. - } + if (thisPassed.tpe.widen.isParameterless) thisPassed + else argumentss.foldLeft(thisPassed) { + (met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet. + } res } if (isRecursiveCall) { if (ctx.tailPos) { - if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass)) + if (!hasConformingTargs) fail("it changes type arguments on a polymorphic recursive call") + else if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass)) else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv) else fail("it changes type of 'this' on a polymorphic recursive call") } diff --git a/tests/neg/tailcall/t6574.scala b/tests/neg/tailcall/t6574.scala index 7030b3b4ad05..d9ba2882ddab 100644 --- a/tests/neg/tailcall/t6574.scala +++ b/tests/neg/tailcall/t6574.scala @@ -4,7 +4,7 @@ class Bad[X, Y](val v: Int) extends AnyVal { println("tail") } - @annotation.tailrec final def differentTypeArgs : Unit = { - {(); new Bad[String, Unit](0)}.differentTypeArgs + @annotation.tailrec final def differentTypeArgs : Unit = { // error + {(); new Bad[String, Unit](0)}.differentTypeArgs // error } } diff --git a/tests/pos/tailcall/i1089.scala b/tests/pos/tailcall/i1089.scala new file mode 100644 index 000000000000..8eb69cb9bb75 --- /dev/null +++ b/tests/pos/tailcall/i1089.scala @@ -0,0 +1,26 @@ +package hello + +import scala.annotation.tailrec + +class Enclosing { + class SomeData(val x: Int) + + def localDef(): Unit = { + def foo(data: SomeData): Int = data.x + + @tailrec + def test(i: Int, data: SomeData): Unit = { + if (i != 0) { + println(foo(data)) + test(i - 1, data) + } + } + + test(3, new SomeData(42)) + } +} + +object world extends App { + println("hello dotty!") + new Enclosing().localDef() +}