From f50f72fe862a3bf956f2e00cf266b9cf078e45bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Wed, 13 Apr 2022 09:52:49 +0200 Subject: [PATCH 1/2] Add bytecode tests with the status quo of codegen control flow. --- .../backend/jvm/DottyBytecodeTests.scala | 377 ++++++++++++++++++ 1 file changed, 377 insertions(+) diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index a85c28a9f878..f1798b69e59a 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -1039,6 +1039,383 @@ class TestBCode extends DottyBytecodeTest { } } + @Test def patmatControlFlow(): Unit = { + val source = + s"""class Foo { + | def m1(xs: List[Int]): Int = xs match + | case x :: xr => x + | case Nil => 20 + | + | def m2(xs: List[Int]): Int = xs match + | case (1 | 2) :: xr => 10 + | case x :: xr => x + | case _ => 20 + |} + """.stripMargin + + checkBCode(source) { dir => + val fooClass = loadClassNode(dir.lookupName("Foo.class", directory = false).input) + + // --------------- + + val m1Meth = getMethod(fooClass, "m1") + + assertSameCode(m1Meth, List( + VarOp(ALOAD, 1), + VarOp(ASTORE, 2), + VarOp(ALOAD, 2), + TypeOp(INSTANCEOF, "scala/collection/immutable/$colon$colon"), + Jump(IFEQ, Label(19)), + VarOp(ALOAD, 2), + TypeOp(CHECKCAST, "scala/collection/immutable/$colon$colon"), + VarOp(ASTORE, 3), + VarOp(ALOAD, 3), + Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "next$access$1", "()Lscala/collection/immutable/List;", false), + VarOp(ASTORE, 4), + VarOp(ALOAD, 3), + Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "head", "()Ljava/lang/Object;", false), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToInt", "(Ljava/lang/Object;)I", false), + VarOp(ISTORE, 5), + VarOp(ALOAD, 4), + VarOp(ASTORE, 6), + VarOp(ILOAD, 5), + Jump(GOTO, Label(47)), + Label(19), + Field(GETSTATIC, "scala/package$", "MODULE$", "Lscala/package$;"), + Invoke(INVOKEVIRTUAL, "scala/package$", "Nil", "()Lscala/collection/immutable/Nil$;", false), + VarOp(ALOAD, 2), + VarOp(ASTORE, 7), + Op(DUP), + Jump(IFNONNULL, Label(31)), + Op(POP), + VarOp(ALOAD, 7), + Jump(IFNULL, Label(36)), + Jump(GOTO, Label(40)), + Label(31), + VarOp(ALOAD, 7), + Invoke(INVOKEVIRTUAL, "java/lang/Object", "equals", "(Ljava/lang/Object;)Z", false), + Jump(IFEQ, Label(40)), + Label(36), + IntOp(BIPUSH, 20), + Jump(GOTO, Label(47)), + Label(40), + TypeOp(NEW, "scala/MatchError"), + Op(DUP), + VarOp(ALOAD, 2), + Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), + Op(ATHROW), + Label(47), + Op(IRETURN), + )) + + // --------------- + + val m2Meth = getMethod(fooClass, "m2") + + assertSameCode(m2Meth, List( + VarOp(ALOAD, 1), + VarOp(ASTORE, 2), + VarOp(ALOAD, 2), + TypeOp(INSTANCEOF, "scala/collection/immutable/$colon$colon"), + Jump(IFEQ, Label(42)), + VarOp(ALOAD, 2), + TypeOp(CHECKCAST, "scala/collection/immutable/$colon$colon"), + VarOp(ASTORE, 3), + VarOp(ALOAD, 3), + Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "head", "()Ljava/lang/Object;", false), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "unboxToInt", "(Ljava/lang/Object;)I", false), + VarOp(ISTORE, 4), + VarOp(ALOAD, 3), + Invoke(INVOKEVIRTUAL, "scala/collection/immutable/$colon$colon", "next$access$1", "()Lscala/collection/immutable/List;", false), + VarOp(ASTORE, 5), + Op(ICONST_1), + VarOp(ILOAD, 4), + Jump(IF_ICMPNE, Label(19)), + Jump(GOTO, Label(28)), + Label(19), + Op(ICONST_2), + VarOp(ILOAD, 4), + Jump(IF_ICMPNE, Label(25)), + Jump(GOTO, Label(28)), + Label(25), + Jump(GOTO, Label(34)), + Label(28), + VarOp(ALOAD, 5), + VarOp(ASTORE, 6), + IntOp(BIPUSH, 10), + Jump(GOTO, Label(46)), + Label(34), + VarOp(ILOAD, 4), + VarOp(ISTORE, 7), + VarOp(ALOAD, 5), + VarOp(ASTORE, 8), + VarOp(ILOAD, 7), + Jump(GOTO, Label(46)), + Label(42), + IntOp(BIPUSH, 20), + Jump(GOTO, Label(46)), + Label(46), + Op(IRETURN), + )) + } + } + + @Test def switchControlFlow(): Unit = { + val source = + s"""import scala.annotation.switch + | + |class Foo { + | def m1(x: Int): Int = (x: @switch) match + | case 1 => 10 + | case 7 => 20 + | case 8 => 30 + | case 9 => 40 + | case _ => x + | + | def m2(x: Int): Int = (x: @switch) match + | case (1 | 2) => 10 + | case 7 => 20 + | case 8 => 30 + | case c if c > 100 => 20 + |} + """.stripMargin + + checkBCode(source) { dir => + val fooClass = loadClassNode(dir.lookupName("Foo.class", directory = false).input) + + // --------------- + + val m1Meth = getMethod(fooClass, "m1") + + assertSameCode(m1Meth, List( + VarOp(ILOAD, 1), + VarOp(ISTORE, 2), + VarOp(ILOAD, 2), + LookupSwitch(LOOKUPSWITCH, Label(40), List(1, 7, 8, 9), List(Label(4), Label(13), Label(22), Label(31))), + Label(4), + IntOp(BIPUSH, 10), + Jump(GOTO, Label(52)), + Op(NOP), + Op(NOP), + Op(ATHROW), + Label(13), + IntOp(BIPUSH, 20), + Jump(GOTO, Label(52)), + Op(NOP), + Op(NOP), + Op(ATHROW), + Label(22), + IntOp(BIPUSH, 30), + Jump(GOTO, Label(52)), + Op(NOP), + Op(NOP), + Op(ATHROW), + Label(31), + IntOp(BIPUSH, 40), + Jump(GOTO, Label(52)), + Op(NOP), + Op(NOP), + Op(ATHROW), + Label(40), + VarOp(ILOAD, 1), + Jump(GOTO, Label(52)), + Op(NOP), + Op(NOP), + Op(ATHROW), + Op(ATHROW), + Label(52), + Op(IRETURN), + )) + + // --------------- + + val m2Meth = getMethod(fooClass, "m2") + + assertSameCode(m2Meth, List( + VarOp(ILOAD, 1), + VarOp(ISTORE, 2), + VarOp(ILOAD, 2), + LookupSwitch(LOOKUPSWITCH, Label(31), List(1, 2, 7, 8), List(Label(4), Label(4), Label(13), Label(22))), + Label(4), + IntOp(BIPUSH, 10), + Jump(GOTO, Label(56)), + Op(NOP), + Op(NOP), + Op(ATHROW), + Label(13), + IntOp(BIPUSH, 20), + Jump(GOTO, Label(56)), + Op(NOP), + Op(NOP), + Op(ATHROW), + Label(22), + IntOp(BIPUSH, 30), + Jump(GOTO, Label(56)), + Op(NOP), + Op(NOP), + Op(ATHROW), + Label(31), + VarOp(ILOAD, 2), + VarOp(ISTORE, 3), + VarOp(ILOAD, 3), + IntOp(BIPUSH, 100), + Jump(IF_ICMPLE, Label(40)), + IntOp(BIPUSH, 20), + Jump(GOTO, Label(56)), + Label(40), + TypeOp(NEW, "scala/MatchError"), + Op(DUP), + VarOp(ILOAD, 2), + Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "boxToInteger", "(I)Ljava/lang/Integer;", false), + Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), + Op(ATHROW), + Op(NOP), + Op(NOP), + Op(ATHROW), + Op(ATHROW), + Label(56), + Op(IRETURN), + )) + } + } + + @Test def ifThenElseControlFlow(): Unit = { + /* This is a test case coming from the Scala.js linker, where in Scala 2 we + * had to introduce a "useless" `return` to make the bytecode size smaller, + * measurably increasing performance (!). + */ + + val source = + s"""import java.io.Writer + | + |final class SourceMapWriter(out: Writer) { + | private val Base64Map = + | "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + | "abcdefghijklmnopqrstuvwxyz" + + | "0123456789+/" + | + | private final val VLQBaseShift = 5 + | private final val VLQBase = 1 << VLQBaseShift + | private final val VLQBaseMask = VLQBase - 1 + | private final val VLQContinuationBit = VLQBase + | + | def entryPoint(value: Int): Unit = writeBase64VLQ(value) + | + | private def writeBase64VLQ(value0: Int): Unit = { + | val signExtended = value0 >> 31 + | val value = (((value0 ^ signExtended) - signExtended) << 1) | (signExtended & 1) + | if (value < 26) { + | out.write('A' + value) // was `return out...` + | } else { + | def writeBase64VLQSlowPath(value0: Int): Unit = { + | var value = value0 + | while ({ + | // do { + | var digit = value & VLQBaseMask + | value = value >>> VLQBaseShift + | if (value != 0) + | digit |= VLQContinuationBit + | out.write(Base64Map.charAt(digit)) + | // } while ( + | value != 0 + | // ) + | }) () + | } + | writeBase64VLQSlowPath(value) + | } + | } + |} + """.stripMargin + + checkBCode(source) { dir => + val sourceMapWriterClass = loadClassNode(dir.lookupName("SourceMapWriter.class", directory = false).input) + + // --------------- + + val writeBase64VLQMeth = getMethod(sourceMapWriterClass, "writeBase64VLQ") + + assertSameCode(writeBase64VLQMeth, List( + VarOp(ILOAD, 1), + IntOp(BIPUSH, 31), + Op(ISHR), + VarOp(ISTORE, 2), + VarOp(ILOAD, 1), + VarOp(ILOAD, 2), + Op(IXOR), + VarOp(ILOAD, 2), + Op(ISUB), + Op(ICONST_1), + Op(ISHL), + VarOp(ILOAD, 2), + Op(ICONST_1), + Op(IAND), + Op(IOR), + VarOp(ISTORE, 3), + VarOp(ILOAD, 3), + IntOp(BIPUSH, 26), + Jump(IF_ICMPGE, Label(26)), + VarOp(ALOAD, 0), + Field(GETFIELD, "SourceMapWriter", "out", "Ljava/io/Writer;"), + IntOp(BIPUSH, 65), + VarOp(ILOAD, 3), + Op(IADD), + Invoke(INVOKEVIRTUAL, "java/io/Writer", "write", "(I)V", false), + Jump(GOTO, Label(31)), + Label(26), + VarOp(ALOAD, 0), + VarOp(ILOAD, 3), + Invoke(INVOKESPECIAL, "SourceMapWriter", "writeBase64VLQSlowPath$1", "(I)V", false), + Label(31), + Op(RETURN), + )) + + // --------------- + + val writeBase64VLQSlowPathMeth = getMethod(sourceMapWriterClass, "writeBase64VLQSlowPath$1") + + assertSameCode(writeBase64VLQSlowPathMeth, List( + VarOp(ILOAD, 1), + VarOp(ISTORE, 2), + Label(2), + VarOp(ILOAD, 2), + IntOp(BIPUSH, 31), + Op(IAND), + VarOp(ISTORE, 3), + VarOp(ILOAD, 2), + Op(ICONST_5), + Op(IUSHR), + VarOp(ISTORE, 2), + VarOp(ILOAD, 2), + Op(ICONST_0), + Jump(IF_ICMPEQ, Label(19)), + VarOp(ILOAD, 3), + IntOp(BIPUSH, 32), + Op(IOR), + VarOp(ISTORE, 3), + Label(19), + VarOp(ALOAD, 0), + Field(GETFIELD, "SourceMapWriter", "out", "Ljava/io/Writer;"), + Field(GETSTATIC, "scala/Char$", "MODULE$", "Lscala/Char$;"), + VarOp(ALOAD, 0), + Field(GETFIELD, "SourceMapWriter", "Base64Map", "Ljava/lang/String;"), + VarOp(ILOAD, 3), + Invoke(INVOKEVIRTUAL, "java/lang/String", "charAt", "(I)C", false), + Invoke(INVOKEVIRTUAL, "scala/Char$", "char2int", "(C)I", false), + Invoke(INVOKEVIRTUAL, "java/io/Writer", "write", "(I)V", false), + VarOp(ILOAD, 2), + Op(ICONST_0), + Jump(IF_ICMPEQ, Label(35)), + Op(ICONST_1), + Jump(GOTO, Label(38)), + Label(35), + Op(ICONST_0), + Label(38), + Jump(IFNE, Label(2)), + Op(RETURN), + )) + } + } + @Test def getClazz: Unit = { val source = """ From 4a2889f93a46372c3551beef8b17d4ba8f289ddf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Wed, 13 Apr 2022 11:26:33 +0200 Subject: [PATCH 2/2] Use explicit destinations in codegen to avoid uselessly jumping around. Previously, the codegen's main method `genLoad` always generated code that loaded the value on the stack before continuing. There were a number of situations where `genLoad` would be directly followed by unconditional jumps to instructions performing more jumps, returns and throws. This generated more spurious jumps than necessary, along with artifact dead code. We solve these limitations by introducing `LoadDestination`s that specify the destination of a loaded value: * FallThrough: as previously, load the value on the stack and continue. * Jump(label): load the value on the stack and jump to the given label. * Return: return the value from the enclosing method. * Throw: throw the value. We generalize `genLoad` as `genLoadTo`, taking a specific destination for the loaded value. `genLoadTo` can "push down" its destination into all control flow structures (except `Try`s, because of their cleanups). With that, when we get to the end of what amounts to "basic blocks", we know exactly the ultimate destination of the loaded value. We can therefore directly jump, return or throw to the final destination. This produces less bytecode, notably because fewer labels are necessary. For example, the method: def abs(x: Int): Int = if x < 0 then -x else x previously generated bytecode like ILOAD 1 ICONST_0 IF_ICMPGE Label(1) ILOAD 1 INEG GOTO Label(2) Label(1): ILOAD 1 Label(2): IRETURN Now, instead of jumping to Label(2), we directly perform an IRETURN: ILOAD 1 ICONST_0 IF_ICMPGE Label(1) ILOAD 1 INEG IRETURN Label(1): ILOAD 1 IRETURN While the changes are not very impressive on that simple example, they become more important in more complex cases, notably with pattern matching. Examples can be found in the changed bytecode tests. An added benefit is that `genLoadTo` knows when loading a value results in an unconditional control flow change (jump, return or throw). It can then avoid inserting any useless adaptation. This removes all the dead bytecode that the codegen used to generate as artifacts of its own compilation scheme. (It will still generate dead bytecode if the original source code/inlined code contains dead code.) --- .../tools/backend/jvm/BCodeBodyBuilder.scala | 220 ++++++++++-------- .../tools/backend/jvm/BCodeSkelBuilder.scala | 64 ++--- .../tools/backend/jvm/BCodeSyncAndTry.scala | 2 +- .../backend/jvm/DottyBytecodeTests.scala | 107 +++------ 4 files changed, 192 insertions(+), 201 deletions(-) diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala index a7de18f7b4b2..6b9cccec7967 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala @@ -113,16 +113,6 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } } - def genThrow(expr: Tree): Unit = { - val thrownKind = tpeTK(expr) - // `throw null` is valid although scala.Null (as defined in src/libray-aux) isn't a subtype of Throwable. - // Similarly for scala.Nothing (again, as defined in src/libray-aux). - assert(thrownKind.isNullType || thrownKind.isNothingType || thrownKind.asClassBType.isSubtypeOf(ThrowableReference)) - genLoad(expr, thrownKind) - lineNumber(expr) - emit(asm.Opcodes.ATHROW) // ICode enters here into enterIgnoreMode, we'll rely instead on DCE at ClassNode level. - } - /* Generate code for primitive arithmetic operations. */ def genArithmeticOp(tree: Tree, code: Int): BType = tree match{ case Apply(fun @ DesugaredSelect(larg, _), args) => @@ -211,7 +201,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { generatedType } - def genLoadIf(tree: If, expectedType: BType): BType = tree match{ + def genLoadIfTo(tree: If, expectedType: BType, dest: LoadDestination): BType = tree match{ case If(condp, thenp, elsep) => val success = new asm.Label @@ -221,25 +211,37 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { case Literal(value) if value.tag == UnitTag => false case _ => true }) - val postIf = if (hasElse) new asm.Label else failure genCond(condp, success, failure, targetIfNoJump = success) markProgramPoint(success) - val thenKind = tpeTK(thenp) - val elseKind = if (!hasElse) UNIT else tpeTK(elsep) - def hasUnitBranch = (thenKind == UNIT || elseKind == UNIT) && expectedType == UNIT - val resKind = if (hasUnitBranch) UNIT else tpeTK(tree) - - genLoad(thenp, resKind) - if (hasElse) { bc goTo postIf } - markProgramPoint(failure) - if (hasElse) { - genLoad(elsep, resKind) - markProgramPoint(postIf) - } - - resKind + if dest == LoadDestination.FallThrough then + if hasElse then + val thenKind = tpeTK(thenp) + val elseKind = tpeTK(elsep) + def hasUnitBranch = (thenKind == UNIT || elseKind == UNIT) && expectedType == UNIT + val resKind = if (hasUnitBranch) UNIT else tpeTK(tree) + + val postIf = new asm.Label + genLoadTo(thenp, resKind, LoadDestination.Jump(postIf)) + markProgramPoint(failure) + genLoadTo(elsep, resKind, LoadDestination.FallThrough) + markProgramPoint(postIf) + resKind + else + genLoad(thenp, UNIT) + markProgramPoint(failure) + UNIT + end if + else + genLoadTo(thenp, expectedType, dest) + markProgramPoint(failure) + if hasElse then + genLoadTo(elsep, expectedType, dest) + else + genAdaptAndSendToDest(UNIT, expectedType, dest) + expectedType + end if } def genPrimitiveOp(tree: Apply, expectedType: BType): BType = (tree: @unchecked) match { @@ -285,8 +287,13 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } /* Generate code for trees that produce values on the stack */ - def genLoad(tree: Tree, expectedType: BType): Unit = { + def genLoad(tree: Tree, expectedType: BType): Unit = + genLoadTo(tree, expectedType, LoadDestination.FallThrough) + + /* Generate code for trees that produce values, sent to a given `LoadDestination`. */ + def genLoadTo(tree: Tree, expectedType: BType, dest: LoadDestination): Unit = var generatedType = expectedType + var generatedDest = LoadDestination.FallThrough lineNumber(tree) @@ -307,24 +314,29 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { generatedType = UNIT case t @ If(_, _, _) => - generatedType = genLoadIf(t, expectedType) + generatedType = genLoadIfTo(t, expectedType, dest) + generatedDest = dest case t @ Labeled(_, _) => - generatedType = genLabeled(t) + generatedType = genLabeledTo(t, expectedType, dest) + generatedDest = dest case r: Return => genReturn(r) - generatedType = expectedType + generatedDest = LoadDestination.Return case t @ WhileDo(_, _) => - generatedType = genWhileDo(t, expectedType) + generatedDest = genWhileDo(t) + generatedType = UNIT case t @ Try(_, _, _) => generatedType = genLoadTry(t) case t: Apply if t.fun.symbol eq defn.throwMethod => - genThrow(t.args.head) - generatedType = expectedType + val thrownExpr = t.args.head + val thrownKind = tpeTK(thrownExpr) + genLoadTo(thrownExpr, thrownKind, LoadDestination.Throw) + generatedDest = LoadDestination.Throw case New(tpt) => abort(s"Unexpected New(${tpt.tpe.showSummary()}/$tpt) reached GenBCode.\n" + @@ -425,12 +437,18 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { case blck @ Block(stats, expr) => if(stats.isEmpty) - genLoad(expr, expectedType) - else genBlock(blck, expectedType) + genLoadTo(expr, expectedType, dest) + else + genBlockTo(blck, expectedType, dest) + generatedDest = dest - case Typed(Super(_, _), _) => genLoad(tpd.This(claszSymbol.asClass), expectedType) + case Typed(Super(_, _), _) => + genLoadTo(tpd.This(claszSymbol.asClass), expectedType, dest) + generatedDest = dest - case Typed(expr, _) => genLoad(expr, expectedType) + case Typed(expr, _) => + genLoadTo(expr, expectedType, dest) + generatedDest = dest case Assign(_, _) => generatedType = UNIT @@ -440,7 +458,8 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { generatedType = genArrayValue(av) case mtch @ Match(_, _) => - generatedType = genMatch(mtch) + generatedType = genMatchTo(mtch, expectedType, dest) + generatedDest = dest case tpd.EmptyTree => if (expectedType != UNIT) { emitZeroOf(expectedType) } @@ -451,12 +470,29 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { case _ => abort(s"Unexpected tree in genLoad: $tree/${tree.getClass} at: ${tree.span}") } - // emit conversion - if (generatedType != expectedType) { + // emit conversion and send to the right destination + if generatedDest == LoadDestination.FallThrough then + genAdaptAndSendToDest(generatedType, expectedType, dest) + end genLoadTo + + def genAdaptAndSendToDest(generatedType: BType, expectedType: BType, dest: LoadDestination): Unit = + if generatedType != expectedType then adapt(generatedType, expectedType) - } - } // end of GenBCode.genLoad() + dest match + case LoadDestination.FallThrough => + () + case LoadDestination.Jump(label) => + bc goTo label + case LoadDestination.Return => + bc emitRETURN returnType + case LoadDestination.Throw => + val thrownType = expectedType + // `throw null` is valid although scala.Null (as defined in src/libray-aux) isn't a subtype of Throwable. + // Similarly for scala.Nothing (again, as defined in src/libray-aux). + assert(thrownType.isNullType || thrownType.isNothingType || thrownType.asClassBType.isSubtypeOf(ThrowableReference)) + emit(asm.Opcodes.ATHROW) + end genAdaptAndSendToDest // ---------------- field load and store ---------------- @@ -533,13 +569,23 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } } - private def genLabeled(tree: Labeled): BType = tree match { + private def genLabeledTo(tree: Labeled, expectedType: BType, dest: LoadDestination): BType = tree match { case Labeled(bind, expr) => - val resKind = tpeTK(tree) - genLoad(expr, resKind) - markProgramPoint(programPoint(bind.symbol)) - resKind + val labelSym = bind.symbol + + if dest == LoadDestination.FallThrough then + val resKind = tpeTK(tree) + val jumpTarget = new asm.Label + registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget)) + genLoad(expr, resKind) + markProgramPoint(jumpTarget) + resKind + else + registerJumpDest(labelSym, expectedType, dest) + genLoadTo(expr, expectedType, dest) + expectedType + end if } private def genReturn(r: Return): Unit = { @@ -548,17 +594,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { if (NoSymbol == fromSym) { // return from enclosing method - val returnedKind = tpeTK(expr) - genLoad(expr, returnedKind) - adapt(returnedKind, returnType) - val saveReturnValue = (returnType != UNIT) - lineNumber(r) - cleanups match { case Nil => // not an assertion: !shouldEmitCleanup (at least not yet, pendingCleanups() may still have to run, and reset `shouldEmitCleanup`. - bc emitRETURN returnType + genLoadTo(expr, returnType, LoadDestination.Return) case nextCleanup :: rest => + genLoad(expr, returnType) + lineNumber(r) + val saveReturnValue = (returnType != UNIT) if (saveReturnValue) { // regarding return value, the protocol is: in place of a `return-stmt`, a sequence of `adapt, store, jump` are inserted. if (earlyReturnVar == null) { @@ -578,54 +621,39 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { * that cross cleanup boundaries. However, in theory such crossings are valid, so we should take care * of them. */ - val resultKind = toTypeKind(fromSym.info) - genLoad(expr, resultKind) - lineNumber(r) - bc goTo programPoint(fromSym) + val (exprExpectedType, exprDest) = findJumpDest(fromSym) + genLoadTo(expr, exprExpectedType, exprDest) } } // end of genReturn() - def genWhileDo(tree: WhileDo, expectedType: BType): BType = tree match{ + def genWhileDo(tree: WhileDo): LoadDestination = tree match{ case WhileDo(cond, body) => val isInfinite = cond == tpd.EmptyTree + val loop = new asm.Label + markProgramPoint(loop) + if isInfinite then - body match - case Labeled(bind, expr) if tpeTK(body) == UNIT => - // this is the shape of tailrec methods - val loop = programPoint(bind.symbol) - markProgramPoint(loop) - genLoad(expr, UNIT) - bc goTo loop - case _ => - val loop = new asm.Label - markProgramPoint(loop) - genLoad(body, UNIT) - bc goTo loop - end match - expectedType + val dest = LoadDestination.Jump(loop) + genLoadTo(body, UNIT, dest) + dest else body match case Literal(value) if value.tag == UnitTag => // this is the shape of do..while loops - val loop = new asm.Label - markProgramPoint(loop) val exitLoop = new asm.Label genCond(cond, loop, exitLoop, targetIfNoJump = exitLoop) markProgramPoint(exitLoop) case _ => - val loop = new asm.Label val success = new asm.Label val failure = new asm.Label - markProgramPoint(loop) genCond(cond, success, failure, targetIfNoJump = success) markProgramPoint(success) - genLoad(body, UNIT) - bc goTo loop + genLoadTo(body, UNIT, LoadDestination.Jump(loop)) markProgramPoint(failure) end match - UNIT + LoadDestination.FallThrough } def genTypeApply(t: TypeApply): BType = (t: @unchecked) match { @@ -848,11 +876,16 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { * Int/String values to use as keys, and a code block. The exception is the "default" case * clause which doesn't list any key (there is exactly one of these per match). */ - private def genMatch(tree: Match): BType = tree match { + private def genMatchTo(tree: Match, expectedType: BType, dest: LoadDestination): BType = tree match { case Match(selector, cases) => lineNumber(tree) - val generatedType = tpeTK(tree) - val postMatch = new asm.Label + + val (generatedType, postMatch, postMatchDest) = + if dest == LoadDestination.FallThrough then + val postMatch = new asm.Label + (tpeTK(tree), postMatch, LoadDestination.Jump(postMatch)) + else + (expectedType, null, dest) // Only two possible selector types exist in `Match` trees at this point: Int and String if (tpeTK(selector) == INT) { @@ -902,8 +935,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { for (sb <- switchBlocks.reverse) { val (caseLabel, caseBody) = sb markProgramPoint(caseLabel) - genLoad(caseBody, generatedType) - bc goTo postMatch + genLoadTo(caseBody, generatedType, postMatchDest) } } else { @@ -968,13 +1000,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } // Push the hashCode of the string (or `0` it is `null`) onto the stack and switch on it - genLoadIf( + genLoadIfTo( If( tree.selector.select(defn.Any_==).appliedTo(nullLiteral), Literal(Constant(0)), tree.selector.select(defn.Any_hashCode).appliedToNone ), - INT + INT, + LoadDestination.FallThrough ) bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY) @@ -993,8 +1026,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val thisCaseMatches = new asm.Label genCond(condp, thisCaseMatches, keepGoing, targetIfNoJump = thisCaseMatches) markProgramPoint(thisCaseMatches) - genLoad(caseBody, generatedType) - bc goTo postMatch + genLoadTo(caseBody, generatedType, postMatchDest) } markProgramPoint(keepGoing) } @@ -1004,22 +1036,22 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { // emit blocks for common patterns for ((caseLabel, caseBody) <- indirectBlocks.reverse) { markProgramPoint(caseLabel) - genLoad(caseBody, generatedType) - bc goTo postMatch + genLoadTo(caseBody, generatedType, postMatchDest) } } - markProgramPoint(postMatch) + if postMatch != null then + markProgramPoint(postMatch) generatedType } - def genBlock(tree: Block, expectedType: BType) = tree match { + def genBlockTo(tree: Block, expectedType: BType, dest: LoadDestination): Unit = tree match { case Block(stats, expr) => val savedScope = varsInScope varsInScope = Nil stats foreach genStat - genLoad(expr, expectedType) + genLoadTo(expr, expectedType, dest) val end = currProgramPoint() if (emitVars) { // add entries to LocalVariableTable JVM attribute diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala index 824a93ed506f..829b156be428 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala @@ -40,6 +40,18 @@ trait BCodeSkelBuilder extends BCodeHelpers { lazy val NativeAttr: Symbol = requiredClass[scala.native] + /** The destination of a value generated by `genLoadTo`. */ + enum LoadDestination: + /** The value is put on the stack, and control flows through to the next opcode. */ + case FallThrough + /** The value is put on the stack, and control flow is transferred to the given `label`. */ + case Jump(label: asm.Label) + /** The value is RETURN'ed from the enclosing method. */ + case Return + /** The value is ATHROW'n. */ + case Throw + end LoadDestination + /* * There's a dedicated PlainClassBuilder for each CompilationUnit, * which simplifies the initialization of per-class data structures in `genPlainClass()` which in turn delegates to `initJClass()` @@ -379,21 +391,21 @@ trait BCodeSkelBuilder extends BCodeHelpers { /* ---------------- Part 1 of program points, ie Labels in the ASM world ---------------- */ /* - * A jump is represented as an Apply node whose symbol denotes a LabelDef, the target of the jump. - * The `jumpDest` map is used to: - * (a) find the asm.Label for the target, given an Apply node's symbol; - * (b) anchor an asm.Label in the instruction stream, given a LabelDef node. - * In other words, (a) is necessary when visiting a jump-source, and (b) when visiting a jump-target. - * A related map is `labelDef`: it has the same keys as `jumpDest` but its values are LabelDef nodes not asm.Labels. - * + * A jump is represented as a Return node whose `from` symbol denotes a Labeled's Bind node, the target of the jump. + * The `jumpDest` map is used to find the `LoadDestination` at the end of the `Labeled` block, as well as the + * corresponding expected type. The `LoadDestination` can never be `FallThrough` here. */ - var jumpDest: immutable.Map[ /* Labeled or LabelDef */ Symbol, asm.Label ] = null - def programPoint(labelSym: Symbol): asm.Label = { + var jumpDest: immutable.Map[ /* Labeled */ Symbol, (BType, LoadDestination) ] = null + def registerJumpDest(labelSym: Symbol, expectedType: BType, dest: LoadDestination): Unit = { + assert(labelSym.is(Label), s"trying to register a jump-dest for a non-label symbol, at: ${labelSym.span}") + assert(dest != LoadDestination.FallThrough, s"trying to register a FallThrough dest for label, at: ${labelSym.span}") + assert(!jumpDest.contains(labelSym), s"trying to register a second jump-dest for label, at: ${labelSym.span}") + jumpDest += (labelSym -> (expectedType, dest)) + } + def findJumpDest(labelSym: Symbol): (BType, LoadDestination) = { assert(labelSym.is(Label), s"trying to map a non-label symbol to an asm.Label, at: ${labelSym.span}") jumpDest.getOrElse(labelSym, { - val pp = new asm.Label - jumpDest += (labelSym -> pp) - pp + abort(s"unknown label symbol, for label at: ${labelSym.span}") }) } @@ -566,7 +578,7 @@ trait BCodeSkelBuilder extends BCodeHelpers { def resetMethodBookkeeping(dd: DefDef) = { val rhs = dd.rhs locals.reset(isStaticMethod = methSymbol.isStaticMember) - jumpDest = immutable.Map.empty[ /* LabelDef */ Symbol, asm.Label ] + jumpDest = immutable.Map.empty // check previous invocation of genDefDef exited as many varsInScope as it entered. assert(varsInScope == null, "Unbalanced entering/exiting of GenBCode's genBlock().") @@ -799,20 +811,16 @@ trait BCodeSkelBuilder extends BCodeHelpers { def emitNormalMethodBody(): Unit = { val veryFirstProgramPoint = currProgramPoint() - genLoad(trimmedRhs, returnType) - - trimmedRhs match { - case (_: Return) | Block(_, (_: Return)) => () - case (_: Apply) | Block(_, (_: Apply)) if trimmedRhs.symbol eq defn.throwMethod => () - case tpd.EmptyTree => - report.error("Concrete method has no definition: " + dd + ( - if (ctx.settings.Ydebug.value) "(found: " + methSymbol.owner.info.decls.toList.mkString(", ") + ")" - else ""), - ctx.source.atSpan(NoSpan) - ) - case _ => - bc emitRETURN returnType - } + + if trimmedRhs == tpd.EmptyTree then + report.error("Concrete method has no definition: " + dd + ( + if (ctx.settings.Ydebug.value) "(found: " + methSymbol.owner.info.decls.toList.mkString(", ") + ")" + else ""), + ctx.source.atSpan(NoSpan) + ) + else + genLoadTo(trimmedRhs, returnType, LoadDestination.Return) + if (emitVars) { // add entries to LocalVariableTable JVM attribute val onePastLastProgramPoint = currProgramPoint() @@ -905,7 +913,7 @@ trait BCodeSkelBuilder extends BCodeHelpers { } } - def genLoad(tree: Tree, expectedType: BType): Unit + def genLoadTo(tree: Tree, expectedType: BType, dest: LoadDestination): Unit } // end of class PlainSkelBuilder diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeSyncAndTry.scala b/compiler/src/dotty/tools/backend/jvm/BCodeSyncAndTry.scala index 3afa13d18c98..5bcc4e0efead 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeSyncAndTry.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeSyncAndTry.scala @@ -397,7 +397,7 @@ trait BCodeSyncAndTry extends BCodeBodyBuilder { /* `tmp` (if non-null) is the symbol of the local-var used to preserve the result of the try-body, see `guardResult` */ def emitFinalizer(finalizer: Tree, tmp: Symbol, isDuplicate: Boolean): Unit = { - var saved: immutable.Map[ /* LabelDef */ Symbol, asm.Label ] = null + var saved: immutable.Map[ /* Labeled */ Symbol, (BType, LoadDestination) ] = null if (isDuplicate) { saved = jumpDest } diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index f1798b69e59a..4b8e59852942 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -597,7 +597,7 @@ class TestBCode extends DottyBytecodeTest { val clsIn = dir.lookupName("Test.class", directory = false).input val clsNode = loadClassNode(clsIn) val method = getMethod(clsNode, "test") - assertEquals(93, instructionsFromMethod(method).size) + assertEquals(88, instructionsFromMethod(method).size) } } @@ -938,7 +938,7 @@ class TestBCode extends DottyBytecodeTest { Label(0), Ldc(LDC, ""), VarOp(ASTORE, 1), Label(5), VarOp(ALOAD, 1), Jump(IFNULL, Label(19)), Label(10), VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "C", "foo", "()V", false), Label(14), Op(ACONST_NULL), VarOp(ASTORE, 1), Jump(GOTO, Label(5)), - Label(19), VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "C", "bar", "()V", false), Label(24), Op(RETURN), Label(26))) + Label(19), VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "C", "bar", "()V", false), Op(RETURN), Label(25))) val labels = instructions collect { case l: Label => l } val x = convertMethod(t).localVars.find(_.name == "x").get assertEquals(x.start, labels(1)) @@ -976,7 +976,7 @@ class TestBCode extends DottyBytecodeTest { Op(ICONST_0), Jump(IF_ICMPNE, Label(7)), VarOp(ILOAD, 2), - Jump(GOTO, Label(22)), + Op(IRETURN), Label(7), VarOp(ILOAD, 1), Op(ICONST_1), @@ -991,12 +991,6 @@ class TestBCode extends DottyBytecodeTest { VarOp(ILOAD, 4), VarOp(ISTORE, 2), Jump(GOTO, Label(0)), - Label(22), - Op(IRETURN), - Op(NOP), - Op(NOP), - Op(NOP), - Op(ATHROW), )) // The mutable local vars for this and acc reuse the slots of `this` and of the param acc @@ -1015,7 +1009,7 @@ class TestBCode extends DottyBytecodeTest { VarOp(ALOAD, 0), Field(GETFIELD, "IntList", "head", "I"), Op(IADD), - Jump(GOTO, Label(26)), + Op(IRETURN), Label(12), VarOp(ALOAD, 2), VarOp(ASTORE, 3), @@ -1029,12 +1023,6 @@ class TestBCode extends DottyBytecodeTest { VarOp(ILOAD, 4), VarOp(ISTORE, 1), Jump(GOTO, Label(0)), - Label(26), - Op(IRETURN), - Op(NOP), - Op(NOP), - Op(ATHROW), - Op(ATHROW), )) } } @@ -1079,7 +1067,7 @@ class TestBCode extends DottyBytecodeTest { VarOp(ALOAD, 4), VarOp(ASTORE, 6), VarOp(ILOAD, 5), - Jump(GOTO, Label(47)), + Op(IRETURN), Label(19), Field(GETSTATIC, "scala/package$", "MODULE$", "Lscala/package$;"), Invoke(INVOKEVIRTUAL, "scala/package$", "Nil", "()Lscala/collection/immutable/Nil$;", false), @@ -1097,15 +1085,13 @@ class TestBCode extends DottyBytecodeTest { Jump(IFEQ, Label(40)), Label(36), IntOp(BIPUSH, 20), - Jump(GOTO, Label(47)), + Op(IRETURN), Label(40), TypeOp(NEW, "scala/MatchError"), Op(DUP), VarOp(ALOAD, 2), Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), Op(ATHROW), - Label(47), - Op(IRETURN), )) // --------------- @@ -1143,18 +1129,16 @@ class TestBCode extends DottyBytecodeTest { VarOp(ALOAD, 5), VarOp(ASTORE, 6), IntOp(BIPUSH, 10), - Jump(GOTO, Label(46)), + Op(IRETURN), Label(34), VarOp(ILOAD, 4), VarOp(ISTORE, 7), VarOp(ALOAD, 5), VarOp(ASTORE, 8), VarOp(ILOAD, 7), - Jump(GOTO, Label(46)), + Op(IRETURN), Label(42), IntOp(BIPUSH, 20), - Jump(GOTO, Label(46)), - Label(46), Op(IRETURN), )) } @@ -1191,39 +1175,21 @@ class TestBCode extends DottyBytecodeTest { VarOp(ILOAD, 1), VarOp(ISTORE, 2), VarOp(ILOAD, 2), - LookupSwitch(LOOKUPSWITCH, Label(40), List(1, 7, 8, 9), List(Label(4), Label(13), Label(22), Label(31))), + LookupSwitch(LOOKUPSWITCH, Label(20), List(1, 7, 8, 9), List(Label(4), Label(8), Label(12), Label(16))), Label(4), IntOp(BIPUSH, 10), - Jump(GOTO, Label(52)), - Op(NOP), - Op(NOP), - Op(ATHROW), - Label(13), + Op(IRETURN), + Label(8), IntOp(BIPUSH, 20), - Jump(GOTO, Label(52)), - Op(NOP), - Op(NOP), - Op(ATHROW), - Label(22), + Op(IRETURN), + Label(12), IntOp(BIPUSH, 30), - Jump(GOTO, Label(52)), - Op(NOP), - Op(NOP), - Op(ATHROW), - Label(31), + Op(IRETURN), + Label(16), IntOp(BIPUSH, 40), - Jump(GOTO, Label(52)), - Op(NOP), - Op(NOP), - Op(ATHROW), - Label(40), + Op(IRETURN), + Label(20), VarOp(ILOAD, 1), - Jump(GOTO, Label(52)), - Op(NOP), - Op(NOP), - Op(ATHROW), - Op(ATHROW), - Label(52), Op(IRETURN), )) @@ -1235,46 +1201,31 @@ class TestBCode extends DottyBytecodeTest { VarOp(ILOAD, 1), VarOp(ISTORE, 2), VarOp(ILOAD, 2), - LookupSwitch(LOOKUPSWITCH, Label(31), List(1, 2, 7, 8), List(Label(4), Label(4), Label(13), Label(22))), + LookupSwitch(LOOKUPSWITCH, Label(16), List(1, 2, 7, 8), List(Label(4), Label(4), Label(8), Label(12))), Label(4), IntOp(BIPUSH, 10), - Jump(GOTO, Label(56)), - Op(NOP), - Op(NOP), - Op(ATHROW), - Label(13), + Op(IRETURN), + Label(8), IntOp(BIPUSH, 20), - Jump(GOTO, Label(56)), - Op(NOP), - Op(NOP), - Op(ATHROW), - Label(22), + Op(IRETURN), + Label(12), IntOp(BIPUSH, 30), - Jump(GOTO, Label(56)), - Op(NOP), - Op(NOP), - Op(ATHROW), - Label(31), + Op(IRETURN), + Label(16), VarOp(ILOAD, 2), VarOp(ISTORE, 3), VarOp(ILOAD, 3), IntOp(BIPUSH, 100), - Jump(IF_ICMPLE, Label(40)), + Jump(IF_ICMPLE, Label(25)), IntOp(BIPUSH, 20), - Jump(GOTO, Label(56)), - Label(40), + Op(IRETURN), + Label(25), TypeOp(NEW, "scala/MatchError"), Op(DUP), VarOp(ILOAD, 2), Invoke(INVOKESTATIC, "scala/runtime/BoxesRunTime", "boxToInteger", "(I)Ljava/lang/Integer;", false), Invoke(INVOKESPECIAL, "scala/MatchError", "", "(Ljava/lang/Object;)V", false), Op(ATHROW), - Op(NOP), - Op(NOP), - Op(ATHROW), - Op(ATHROW), - Label(56), - Op(IRETURN), )) } } @@ -1283,6 +1234,7 @@ class TestBCode extends DottyBytecodeTest { /* This is a test case coming from the Scala.js linker, where in Scala 2 we * had to introduce a "useless" `return` to make the bytecode size smaller, * measurably increasing performance (!). + * In dotc, with or without the explicit `return`, the generated code is the same. */ val source = @@ -1360,12 +1312,11 @@ class TestBCode extends DottyBytecodeTest { VarOp(ILOAD, 3), Op(IADD), Invoke(INVOKEVIRTUAL, "java/io/Writer", "write", "(I)V", false), - Jump(GOTO, Label(31)), + Op(RETURN), Label(26), VarOp(ALOAD, 0), VarOp(ILOAD, 3), Invoke(INVOKESPECIAL, "SourceMapWriter", "writeBase64VLQSlowPath$1", "(I)V", false), - Label(31), Op(RETURN), ))