diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 149100a4fd6b..40d3a74cbe39 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -95,6 +95,7 @@ class Compiler { new ArrayConstructors) :: // Intercept creation of (non-generic) arrays and intrinsify. List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types + new PureStats, // Remove pure stats from blocks new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]` new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses diff --git a/compiler/src/dotty/tools/dotc/transform/PureStats.scala b/compiler/src/dotty/tools/dotc/transform/PureStats.scala new file mode 100644 index 000000000000..82b495170576 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/PureStats.scala @@ -0,0 +1,32 @@ +package dotty.tools.dotc +package transform + +import ast.{Trees, tpd} +import core._, core.Decorators._ +import MegaPhase._ +import Types._, Contexts._, Flags._, DenotTransformers._ +import Symbols._, StdNames._, Trees._ + +object PureStats { + val name: String = "pureStats" +} + +/** Remove pure statements in blocks */ +class PureStats extends MiniPhase { + + import tpd._ + + override def phaseName: String = PureStats.name + + override def runsAfter: Set[String] = Set(Erasure.name) + + override def transformBlock(tree: Block)(implicit ctx: Context): Tree = + val stats = tree.stats.mapConserve { + case Typed(Block(stats, expr), _) if isPureExpr(expr) => Thicket(stats) + case stat if !stat.symbol.isConstructor && isPureExpr(stat) => EmptyTree + case stat => stat + } + if stats eq tree.stats then tree + else cpy.Block(tree)(Trees.flatten(stats), tree.expr) + +} diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index d50a88056756..4b4fb328e75d 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -418,4 +418,70 @@ class InlineBytecodeTests extends DottyBytecodeTest { } } + + @Test def i6800a = { + val source = """class Foo: + | inline def inlined(f: => Unit): Unit = f + | def test: Unit = inlined { println("") } + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = List(Invoke(INVOKESTATIC, "Foo", "f$1", "()V", false), Op(RETURN)) + assert(instructions == expected, + "`inlined` was not properly inlined in `test`\n" + diffInstructions(instructions, expected)) + + } + } + + @Test def i6800b = { + val source = """class Foo: + | inline def printIfZero(x: Int): Unit = inline x match + | case 0 => println("zero") + | case _ => () + | def test: Unit = printIfZero(0) + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = List( + Field(GETSTATIC, "scala/Predef$", "MODULE$", "Lscala/Predef$;"), + Ldc(LDC, "zero"), + Invoke(INVOKEVIRTUAL, "scala/Predef$", "println", "(Ljava/lang/Object;)V", false), + Op(RETURN) + ) + assert(instructions == expected, + "`printIfZero` was not properly inlined in `test`\n" + diffInstructions(instructions, expected)) + } + } + + + @Test def i9246 = { + val source = """class Foo: + | inline def check(v:Double): Unit = if(v==0) throw new Exception() + | inline def divide(v: Double, d: Double): Double = { check(d); v / d } + | def test = divide(10,2) + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + val fun = getMethod(clsNode, "test") + val instructions = instructionsFromMethod(fun) + val expected = List(Ldc(LDC, 5.0), Op(DRETURN)) + assert(instructions == expected, + "`divide` was not properly inlined in `test`\n" + diffInstructions(instructions, expected)) + + } + } + }