diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 765bc607a4a1..d5027192b269 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -1383,7 +1383,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) { val expanded = expandMacro(res.args.head, tree.srcPos) typedExpr(expanded) // Inline calls and constant fold code generated by the macro case res => - inlineIfNeeded(res) + specializeEq(inlineIfNeeded(res)) } if res.symbol == defn.QuotedRuntime_exprQuote then ctx.compilationUnit.needsQuotePickling = true @@ -1465,6 +1465,21 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) { case tree => tree } + def specializeEq(tree: Tree): Tree = + tree match + case Apply(sel @ Select(arg1, opName), arg2 :: Nil) + if sel.symbol == defn.Any_== || sel.symbol == defn.Any_!= => + defn.ScalaValueClasses().find { cls => + arg1.tpe.derivesFrom(cls) && arg2.tpe.derivesFrom(cls) + } match { + case Some(cls) => + val newOp = cls.requiredMethod(opName, List(cls.typeRef)) + arg1.select(newOp).withSpan(sel.span).appliedTo(arg2).withSpan(tree.span) + case None => tree + } + case _ => + tree + /** Drop any side-effect-free bindings that are unused in expansion or other reachable bindings. * Inline def bindings that are used only once. */ diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index 829bc2607feb..3e9f191ab279 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -615,4 +615,108 @@ class InlineBytecodeTests extends DottyBytecodeTest { } } + + @Test def any_eq_specialization = { + val source = """class Test: + | inline def eql(x: Any, y: Any) = x == y + | + | def testAny(x: Any, y: Any) = eql(x, y) + | def testAnyExpected(x: Any, y: Any) = x == y + | + | def testBoolean(x: Boolean, y: Boolean) = eql(x, y) + | def testBooleanExpected(x: Boolean, y: Boolean) = x == y + | + | def testByte(x: Byte, y: Byte) = eql(x, y) + | def testByteExpected(x: Byte, y: Byte) = x == y + | + | def testShort(x: Short, y: Short) = eql(x, y) + | def testShortExpected(x: Short, y: Short) = x == y + | + | def testInt(x: Int, y: Int) = eql(x, y) + | def testIntExpected(x: Int, y: Int) = x == y + | + | def testLong(x: Long, y: Long) = eql(x, y) + | def testLongExpected(x: Long, y: Long) = x == y + | + | def testFloat(x: Float, y: Float) = eql(x, y) + | def testFloatExpected(x: Float, y: Float) = x == y + | + | def testDouble(x: Double, y: Double) = eql(x, y) + | def testDoubleExpected(x: Double, y: Double) = x == y + | + | def testChar(x: Char, y: Char) = eql(x, y) + | def testCharExpected(x: Char, y: Char) = x == y + | + | def testUnit(x: Unit, y: Unit) = eql(x, y) + | def testUnitExpected(x: Unit, y: Unit) = x == y + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + for cls <- List("Boolean", "Byte", "Short", "Int", "Long", "Float", "Double", "Char", "Unit") do + val meth1 = getMethod(clsNode, s"test$cls") + val meth2 = getMethod(clsNode, s"test${cls}Expected") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + s"`==` was not properly specialized when inlined in `test$cls`\n" + + diffInstructions(instructions1, instructions2)) + } + } + + @Test def any_neq_specialization = { + val source = """class Test: + | inline def neql(x: Any, y: Any) = x != y + | + | def testAny(x: Any, y: Any) = neql(x, y) + | def testAnyExpected(x: Any, y: Any) = x != y + | + | def testBoolean(x: Boolean, y: Boolean) = neql(x, y) + | def testBooleanExpected(x: Boolean, y: Boolean) = x != y + | + | def testByte(x: Byte, y: Byte) = neql(x, y) + | def testByteExpected(x: Byte, y: Byte) = x != y + | + | def testShort(x: Short, y: Short) = neql(x, y) + | def testShortExpected(x: Short, y: Short) = x != y + | + | def testInt(x: Int, y: Int) = neql(x, y) + | def testIntExpected(x: Int, y: Int) = x != y + | + | def testLong(x: Long, y: Long) = neql(x, y) + | def testLongExpected(x: Long, y: Long) = x != y + | + | def testFloat(x: Float, y: Float) = neql(x, y) + | def testFloatExpected(x: Float, y: Float) = x != y + | + | def testDouble(x: Double, y: Double) = neql(x, y) + | def testDoubleExpected(x: Double, y: Double) = x != y + | + | def testChar(x: Char, y: Char) = neql(x, y) + | def testCharExpected(x: Char, y: Char) = x != y + | + | def testUnit(x: Unit, y: Unit) = neql(x, y) + | def testUnitExpected(x: Unit, y: Unit) = x != y + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + + for cls <- List("Boolean", "Byte", "Short", "Int", "Long", "Float", "Double", "Char", "Unit") do + val meth1 = getMethod(clsNode, s"test$cls") + val meth2 = getMethod(clsNode, s"test${cls}Expected") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + s"`!=` was not properly specialized when inlined in `test$cls`\n" + + diffInstructions(instructions1, instructions2)) + } + } }