From 30b35f19828306d4504ad60d6194373a2f4f1395 Mon Sep 17 00:00:00 2001 From: fhackett Date: Tue, 11 Feb 2020 16:05:59 -0500 Subject: [PATCH] Fix #8290: Make Expr.betaReduce give up when it sees a non-function typed closure expression This commit addresses 2 issues with existing betaReduce behaviour: - when given a non-function typed closure the previous iteration could easily fail to resolve the correct apply method, or even successfully inline the wrong code (see added test cases) - if betaReduce did not successfully inline, it would return a transformed tree. This was fine until the above change made it possible to give up while inside a closureDef, which could insert a type ascription inside the closureDef's block, leading to betaReduce returning invalid trees (the closureDef block can only contain a DefDef and Closure, no type ascriptions). Fixing this issue would add meaningless complexity, so instead this commit changes betaReduce to cleanly give up by returning the function tree unchanged, only generating the code necessary to call it. Note: this change affects a few tests that were checking for betaReduce's slight changes to the function tree. Testing the correctness of this change is done by adding cases to existing tests for betaReduce's treatment of type ascriptions. --- .../ReflectionCompilerInterface.scala | 36 ++++++++------- .../beta-reduce-inline-result.check | 3 ++ .../beta-reduce-inline-result/Test_2.scala | 46 ++++++++++++++++++- tests/run-macros/quote-inline-function.check | 12 ++--- tests/run-staging/i3876-c.check | 4 +- 5 files changed, 73 insertions(+), 28 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala index 17d8d8fef925..6e822a0592b2 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala @@ -2050,17 +2050,17 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend }} val argVals = argVals0.reverse val argRefs = argRefs0.reverse - def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match { + val expectedSig = Signature.NotAMethod.prependTermParams(argRefs.map(_.tpe), false) + def rec(fn: Tree, topAscription: Option[TypeTree]): Option[Tree] = fn match { case Typed(expr, tpt) => - // we need to retain any type ascriptions we see and: - // a) if we succeed, ascribe the result type of the ascription to the inlined body - // b) if we fail, re-ascribe the same type to whatever it was we couldn't inline + // we need to retain any type ascriptions we see and if we succeed, + // ascribe the result type of the ascription to the inlined body // note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects rec(expr, topAscription.orElse(Some(tpt))) case Inlined(call, bindings, expansion) => // this case must go before closureDef to avoid dropping the inline node - cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription)) - case closureDef(ddef) => + rec(expansion, topAscription).map(cpy.Inlined(fn)(call, bindings, _)) + case cl @ closureDef(ddef) if defn.isFunctionType(cl.tpe) => val paramSyms = ddef.vparamss.head.map(param => param.symbol) val paramToVals = paramSyms.zip(argRefs).toMap val result = new TreeTypeMap( @@ -2070,24 +2070,26 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend ).transform(ddef.rhs) topAscription match { case Some(tpt) => - // we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys) - val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType] + // we checked that this is a plain Function closure, so there will be an apply method with a MethodType + // and the expected signature based on param types + val methodType = tpt.tpe.member(nme.apply).atSignature(expectedSig).info.asInstanceOf[MethodType] // result might contain paramrefs, so we substitute them with arg termrefs val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe)) - Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span) + Some(Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)) case None => - result + Some(result) } case tpd.Block(stats, expr) => - seq(stats, rec(expr, topAscription)).withSpan(fn.span) + rec(expr, topAscription).map(seq(stats, _).withSpan(fn.span)) case _ => - val maybeAscribed = topAscription match { - case Some(tpt) => Typed(fn, tpt).withSpan(fn.span) - case None => fn - } - maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span) + None + } + rec(fn, None) match { + case Some(result) => seq(argVals, result) + case None => + val expectedSig = Signature.NotAMethod.prependTermParams(args.map(_.tpe), false) + fn.selectWithSig(nme.apply, expectedSig).appliedToArgs(args).withSpan(fn.span) } - seq(argVals, rec(fn, None)) } ///////////// diff --git a/tests/run-macros/beta-reduce-inline-result.check b/tests/run-macros/beta-reduce-inline-result.check index 082514df02f7..3735f7520b82 100644 --- a/tests/run-macros/beta-reduce-inline-result.check +++ b/tests/run-macros/beta-reduce-inline-result.check @@ -3,3 +3,6 @@ run-time: 4 compile-time: 1 run-time: 1 run-time: 5 +run-time: 7 +run-time: -1 +run-time: 9 diff --git a/tests/run-macros/beta-reduce-inline-result/Test_2.scala b/tests/run-macros/beta-reduce-inline-result/Test_2.scala index 978b3e5d2f41..fed7f3dbb96b 100644 --- a/tests/run-macros/beta-reduce-inline-result/Test_2.scala +++ b/tests/run-macros/beta-reduce-inline-result/Test_2.scala @@ -14,6 +14,36 @@ object Test { inline def dummy4: Int => Int = ??? + object I extends (Int => Int) { + def apply(i: Int): i.type = i + } + + abstract class II extends (Int => Int) { + val apply = 123 + } + + inline def dummy5: II = + (i: Int) => i + 1 + + abstract class III extends (Int => Int) { + def impl(i: Int): Int + def apply(i: Int): Int = -1 + } + + inline def dummy6: III = + (i: Int) => i + 1 + + abstract class IV extends (Int => Int) { + def apply(s: String): String + } + + abstract class V extends IV { + def apply(s: String): String = "gotcha" + } + + inline def dummy7: IV = + { (i: Int) => i + 1 } : V + def main(argv : Array[String]) : Unit = { println(code"compile-time: ${Macros.betaReduce(dummy1)(3)}") println(s"run-time: ${Macros.betaReduce(dummy1)(3)}") @@ -27,7 +57,21 @@ object Test { def throwsNotImplemented2 = Macros.betaReduce(dummy4)(6) // make sure paramref types work when inlining is not possible - println(s"run-time: ${Macros.betaReduce(Predef.identity)(5)}") + println(s"run-time: ${Macros.betaReduce(I)(5)}") + + // -- cases below are non-function types, which are currently not inlined for simplicity but may be in the future + // (also, this tests that we return something valid when we see a closure that we can't inline) + + // A non-function type with an apply value that can be confused with the apply method. + println(s"run-time: ${Macros.betaReduce(dummy5)(6)}") + + // should print -1 (without inlining), because the apparent apply method actually + // has nothing to do with the function literal + println(s"run-time: ${Macros.betaReduce(dummy6)(7)}") + + // the literal does contain the implementation of the apply method, but there are two abstract apply methods + // in the outermost abstract type + println(s"run-time: ${Macros.betaReduce(dummy7)(8)}") } } diff --git a/tests/run-macros/quote-inline-function.check b/tests/run-macros/quote-inline-function.check index 958a2c455b9d..208b29543304 100644 --- a/tests/run-macros/quote-inline-function.check +++ b/tests/run-macros/quote-inline-function.check @@ -3,13 +3,11 @@ Normal function var i: scala.Int = 0 val j: scala.Int = 5 while (i.<(j)) { - val x$1: scala.Int = i - f.apply(x$1) + f.apply(i) i = i.+(1) } while ({ - val x$2: scala.Int = i - f.apply(x$2) + f.apply(i) i = i.+(1) i.<(j) }) () @@ -20,13 +18,11 @@ By name function var i: scala.Int = 0 val j: scala.Int = 5 while (i.<(j)) { - val x$3: scala.Int = i - f.apply(x$3) + f.apply(i) i = i.+(1) } while ({ - val x$4: scala.Int = i - f.apply(x$4) + f.apply(i) i = i.+(1) i.<(j) }) () diff --git a/tests/run-staging/i3876-c.check b/tests/run-staging/i3876-c.check index dca23bcfdf11..38c85ed40818 100644 --- a/tests/run-staging/i3876-c.check +++ b/tests/run-staging/i3876-c.check @@ -6,5 +6,5 @@ (f: scala.Function1[scala.Int, scala.Int] { def apply(x: scala.Int): scala.Int - }).apply(3) -} + }) +}.apply(3)