diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala index 17d8d8fef925..6e822a0592b2 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala @@ -2050,17 +2050,17 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend }} val argVals = argVals0.reverse val argRefs = argRefs0.reverse - def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match { + val expectedSig = Signature.NotAMethod.prependTermParams(argRefs.map(_.tpe), false) + def rec(fn: Tree, topAscription: Option[TypeTree]): Option[Tree] = fn match { case Typed(expr, tpt) => - // we need to retain any type ascriptions we see and: - // a) if we succeed, ascribe the result type of the ascription to the inlined body - // b) if we fail, re-ascribe the same type to whatever it was we couldn't inline + // we need to retain any type ascriptions we see and if we succeed, + // ascribe the result type of the ascription to the inlined body // note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects rec(expr, topAscription.orElse(Some(tpt))) case Inlined(call, bindings, expansion) => // this case must go before closureDef to avoid dropping the inline node - cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription)) - case closureDef(ddef) => + rec(expansion, topAscription).map(cpy.Inlined(fn)(call, bindings, _)) + case cl @ closureDef(ddef) if defn.isFunctionType(cl.tpe) => val paramSyms = ddef.vparamss.head.map(param => param.symbol) val paramToVals = paramSyms.zip(argRefs).toMap val result = new TreeTypeMap( @@ -2070,24 +2070,26 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend ).transform(ddef.rhs) topAscription match { case Some(tpt) => - // we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys) - val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType] + // we checked that this is a plain Function closure, so there will be an apply method with a MethodType + // and the expected signature based on param types + val methodType = tpt.tpe.member(nme.apply).atSignature(expectedSig).info.asInstanceOf[MethodType] // result might contain paramrefs, so we substitute them with arg termrefs val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe)) - Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span) + Some(Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)) case None => - result + Some(result) } case tpd.Block(stats, expr) => - seq(stats, rec(expr, topAscription)).withSpan(fn.span) + rec(expr, topAscription).map(seq(stats, _).withSpan(fn.span)) case _ => - val maybeAscribed = topAscription match { - case Some(tpt) => Typed(fn, tpt).withSpan(fn.span) - case None => fn - } - maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span) + None + } + rec(fn, None) match { + case Some(result) => seq(argVals, result) + case None => + val expectedSig = Signature.NotAMethod.prependTermParams(args.map(_.tpe), false) + fn.selectWithSig(nme.apply, expectedSig).appliedToArgs(args).withSpan(fn.span) } - seq(argVals, rec(fn, None)) } ///////////// diff --git a/tests/run-macros/beta-reduce-inline-result.check b/tests/run-macros/beta-reduce-inline-result.check index 082514df02f7..3735f7520b82 100644 --- a/tests/run-macros/beta-reduce-inline-result.check +++ b/tests/run-macros/beta-reduce-inline-result.check @@ -3,3 +3,6 @@ run-time: 4 compile-time: 1 run-time: 1 run-time: 5 +run-time: 7 +run-time: -1 +run-time: 9 diff --git a/tests/run-macros/beta-reduce-inline-result/Test_2.scala b/tests/run-macros/beta-reduce-inline-result/Test_2.scala index 978b3e5d2f41..fed7f3dbb96b 100644 --- a/tests/run-macros/beta-reduce-inline-result/Test_2.scala +++ b/tests/run-macros/beta-reduce-inline-result/Test_2.scala @@ -14,6 +14,36 @@ object Test { inline def dummy4: Int => Int = ??? + object I extends (Int => Int) { + def apply(i: Int): i.type = i + } + + abstract class II extends (Int => Int) { + val apply = 123 + } + + inline def dummy5: II = + (i: Int) => i + 1 + + abstract class III extends (Int => Int) { + def impl(i: Int): Int + def apply(i: Int): Int = -1 + } + + inline def dummy6: III = + (i: Int) => i + 1 + + abstract class IV extends (Int => Int) { + def apply(s: String): String + } + + abstract class V extends IV { + def apply(s: String): String = "gotcha" + } + + inline def dummy7: IV = + { (i: Int) => i + 1 } : V + def main(argv : Array[String]) : Unit = { println(code"compile-time: ${Macros.betaReduce(dummy1)(3)}") println(s"run-time: ${Macros.betaReduce(dummy1)(3)}") @@ -27,7 +57,21 @@ object Test { def throwsNotImplemented2 = Macros.betaReduce(dummy4)(6) // make sure paramref types work when inlining is not possible - println(s"run-time: ${Macros.betaReduce(Predef.identity)(5)}") + println(s"run-time: ${Macros.betaReduce(I)(5)}") + + // -- cases below are non-function types, which are currently not inlined for simplicity but may be in the future + // (also, this tests that we return something valid when we see a closure that we can't inline) + + // A non-function type with an apply value that can be confused with the apply method. + println(s"run-time: ${Macros.betaReduce(dummy5)(6)}") + + // should print -1 (without inlining), because the apparent apply method actually + // has nothing to do with the function literal + println(s"run-time: ${Macros.betaReduce(dummy6)(7)}") + + // the literal does contain the implementation of the apply method, but there are two abstract apply methods + // in the outermost abstract type + println(s"run-time: ${Macros.betaReduce(dummy7)(8)}") } } diff --git a/tests/run-macros/quote-inline-function.check b/tests/run-macros/quote-inline-function.check index 958a2c455b9d..208b29543304 100644 --- a/tests/run-macros/quote-inline-function.check +++ b/tests/run-macros/quote-inline-function.check @@ -3,13 +3,11 @@ Normal function var i: scala.Int = 0 val j: scala.Int = 5 while (i.<(j)) { - val x$1: scala.Int = i - f.apply(x$1) + f.apply(i) i = i.+(1) } while ({ - val x$2: scala.Int = i - f.apply(x$2) + f.apply(i) i = i.+(1) i.<(j) }) () @@ -20,13 +18,11 @@ By name function var i: scala.Int = 0 val j: scala.Int = 5 while (i.<(j)) { - val x$3: scala.Int = i - f.apply(x$3) + f.apply(i) i = i.+(1) } while ({ - val x$4: scala.Int = i - f.apply(x$4) + f.apply(i) i = i.+(1) i.<(j) }) () diff --git a/tests/run-staging/i3876-c.check b/tests/run-staging/i3876-c.check index dca23bcfdf11..38c85ed40818 100644 --- a/tests/run-staging/i3876-c.check +++ b/tests/run-staging/i3876-c.check @@ -6,5 +6,5 @@ (f: scala.Function1[scala.Int, scala.Int] { def apply(x: scala.Int): scala.Int - }).apply(3) -} + }) +}.apply(3)