Skip to content

Commit 9522b0f

Browse files
authored
Merge pull request #11839 from dotty-staging/fix-SAM-lambda-syntax
Fix bytecode generation for Single Abstract Method lambdas
2 parents f5ced11 + 2ab6f5d commit 9522b0f

File tree

10 files changed

+231
-24
lines changed

10 files changed

+231
-24
lines changed

compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,7 +1414,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
14141414
def genLoadTry(tree: Try): BType
14151415

14161416
def genInvokeDynamicLambda(ctor: Symbol, lambdaTarget: Symbol, environmentSize: Int, functionalInterface: Symbol): BType = {
1417-
import java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE
1417+
import java.lang.invoke.LambdaMetafactory.{FLAG_BRIDGES, FLAG_SERIALIZABLE}
14181418

14191419
report.debuglog(s"Using invokedynamic rather than `new ${ctor.owner}`")
14201420
val generatedType = classBTypeFromSymbol(functionalInterface)
@@ -1445,9 +1445,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
14451445
val functionalInterfaceDesc: String = generatedType.descriptor
14461446
val desc = capturedParamsTypes.map(tpe => toTypeKind(tpe)).mkString(("("), "", ")") + functionalInterfaceDesc
14471447
// TODO specialization
1448-
val constrainedType = new MethodBType(lambdaParamTypes.map(p => toTypeKind(p)), toTypeKind(lambdaTarget.info.resultType)).toASMType
1448+
val instantiatedMethodType = new MethodBType(lambdaParamTypes.map(p => toTypeKind(p)), toTypeKind(lambdaTarget.info.resultType)).toASMType
14491449

1450-
val abstractMethod = atPhase(erasurePhase) {
1450+
val samMethod = atPhase(erasurePhase) {
14511451
val samMethods = toDenot(functionalInterface).info.possibleSamMethods.toList
14521452
samMethods match {
14531453
case x :: Nil => x.symbol
@@ -1457,21 +1457,40 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
14571457
}
14581458
}
14591459

1460-
val methodName = abstractMethod.javaSimpleName
1461-
val applyN = {
1462-
val mt = asmMethodType(abstractMethod)
1463-
mt.toASMType
1460+
val methodName = samMethod.javaSimpleName
1461+
val samMethodType = asmMethodType(samMethod).toASMType
1462+
// scala/bug#10334: make sure that a lambda object for `T => U` has a method `apply(T)U`, not only the `(Object)Object`
1463+
// version. Using the lambda a structural type `{def apply(t: T): U}` causes a reflective lookup for this method.
1464+
val needsGenericBridge = samMethodType != instantiatedMethodType
1465+
val bridgeMethods = atPhase(erasurePhase){
1466+
samMethod.allOverriddenSymbols.toList
14641467
}
1465-
val bsmArgs0 = Seq(applyN, targetHandle, constrainedType)
1466-
val bsmArgs =
1467-
if (isSerializable)
1468-
bsmArgs0 :+ Int.box(FLAG_SERIALIZABLE)
1468+
val overriddenMethodTypes = bridgeMethods.map(b => asmMethodType(b).toASMType)
1469+
1470+
// any methods which `samMethod` overrides need bridges made for them
1471+
// this is done automatically during erasure for classes we generate, but LMF needs to have them explicitly mentioned
1472+
// so we have to compute them at this relatively late point.
1473+
val bridgeTypes = (
1474+
if (needsGenericBridge)
1475+
instantiatedMethodType +: overriddenMethodTypes
14691476
else
1470-
bsmArgs0
1477+
overriddenMethodTypes
1478+
).distinct.filterNot(_ == samMethodType)
1479+
1480+
val needsBridges = bridgeTypes.nonEmpty
1481+
1482+
def flagIf(b: Boolean, flag: Int): Int = if (b) flag else 0
1483+
val flags = flagIf(isSerializable, FLAG_SERIALIZABLE) | flagIf(needsBridges, FLAG_BRIDGES)
1484+
1485+
val bsmArgs0 = Seq(samMethodType, targetHandle, instantiatedMethodType)
1486+
val bsmArgs1 = if (flags != 0) Seq(Int.box(flags)) else Seq.empty
1487+
val bsmArgs2 = if needsBridges then bridgeTypes.length +: bridgeTypes else Seq.empty
1488+
1489+
val bsmArgs = bsmArgs0 ++ bsmArgs1 ++ bsmArgs2
14711490

14721491
val metafactory =
1473-
if (isSerializable)
1474-
lambdaMetaFactoryAltMetafactoryHandle // altMetafactory needed to be able to pass the SERIALIZABLE flag
1492+
if (flags != 0)
1493+
lambdaMetaFactoryAltMetafactoryHandle // altMetafactory required to be able to pass the flags and additional arguments if needed
14751494
else
14761495
lambdaMetaFactoryMetafactoryHandle
14771496

compiler/src/dotty/tools/backend/jvm/GenBCode.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,11 +334,13 @@ class GenBCodePipeline(val int: DottyBackendInterface, val primitives: DottyPrim
334334
val insn = iter.next()
335335
insn match {
336336
case indy: InvokeDynamicInsnNode
337-
// No need to check the exact bsmArgs because we only generate
338-
// altMetafactory indy calls for serializable lambdas.
339-
if indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryAltMetafactoryHandle =>
340-
val implMethod = indy.bsmArgs(1).asInstanceOf[Handle]
341-
indyLambdaBodyMethods += implMethod
337+
if indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryAltMetafactoryHandle =>
338+
import java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE
339+
val metafactoryFlags = indy.bsmArgs(3).asInstanceOf[Integer].toInt
340+
val isSerializable = (metafactoryFlags & FLAG_SERIALIZABLE) != 0
341+
if isSerializable then
342+
val implMethod = indy.bsmArgs(1).asInstanceOf[Handle]
343+
indyLambdaBodyMethods += implMethod
342344
case _ =>
343345
}
344346
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -928,8 +928,10 @@ object Types {
928928
*/
929929
final def possibleSamMethods(using Context): Seq[SingleDenotation] = {
930930
record("possibleSamMethods")
931-
abstractTermMembers.toList.filterConserve(m =>
932-
!m.symbol.matchingMember(defn.ObjectType).exists && !m.symbol.isSuperAccessor)
931+
atPhaseNoLater(erasurePhase) {
932+
abstractTermMembers.toList.filterConserve(m =>
933+
!m.symbol.matchingMember(defn.ObjectType).exists && !m.symbol.isSuperAccessor)
934+
}.map(_.current)
933935
}
934936

935937
/** The set of abstract type members of this type. */

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,9 @@ object Erasure {
431431
val implParamTypes = implType.paramInfos
432432
val implResultType = implType.resultType
433433
val implReturnsUnit = implResultType.classSymbol eq defn.UnitClass
434-
// The SAM that this closure should implement
435-
val SAMType(sam) = lambdaType: @unchecked
434+
// The SAM that this closure should implement.
435+
// At this point it should be already guaranteed that there's only one method to implement
436+
val Seq(sam: MethodType) = lambdaType.possibleSamMethods.map(_.info)
436437
val samParamTypes = sam.paramInfos
437438
val samResultType = sam.resultType
438439

@@ -503,7 +504,7 @@ object Erasure {
503504
implType.derivedLambdaType(paramInfos = samParamTypes)
504505
else
505506
implType.derivedLambdaType(resType = samResultType)
506-
val bridge = newSymbol(ctx.owner, AdaptedClosureName(meth.symbol.name.asTermName), Flags.Synthetic | Flags.Method, bridgeType)
507+
val bridge = newSymbol(ctx.owner, AdaptedClosureName(meth.symbol.name.asTermName), Flags.Synthetic | Flags.Method | Flags.Bridge, bridgeType)
507508
Closure(bridge, bridgeParamss =>
508509
inContext(ctx.withOwner(bridge)) {
509510
val List(bridgeParams) = bridgeParamss

tests/run/i10068a.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
42
2+
Foo
3+
Foo

tests/run/i10068a.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
sealed trait Partial
2+
sealed trait Total extends Partial
3+
4+
case object Foo extends Total
5+
6+
trait P[A] {
7+
def bar(a: A): Partial
8+
}
9+
10+
trait T[A] extends P[A] {
11+
def bar(a: A): Total
12+
}
13+
14+
object T {
15+
def make[A](x: Total): T[A] =
16+
a => x
17+
}
18+
19+
object Test {
20+
def total[A](a: A)(ev: T[A]): Total = ev.bar(a)
21+
def partial[A](a: A)(ev: P[A]): Partial = ev.bar(a)
22+
23+
def go[A](a: A)(ev: T[A]): Unit = {
24+
println(a)
25+
println(total(a)(ev))
26+
println(partial(a)(ev))
27+
}
28+
29+
def main(args: Array[String]): Unit =
30+
go(42)(T.make(Foo))
31+
}

tests/run/i10068b.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
trait Foo[A] {
2+
def xxx(a1: A, a2: A): A
3+
def xxx(a: A): A = xxx(a, a)
4+
}
5+
6+
trait Bar[A] extends Foo[A] {
7+
def yyy(a1: A, a2: A) = xxx(a1, a2)
8+
}
9+
10+
trait Baz[A] extends Bar[A]
11+
12+
object Test:
13+
def main(args: Array[String]): Unit =
14+
val foo: Foo[String] = { (s1, s2) => s1 ++ s2 }
15+
val bar: Bar[String] = { (s1, s2) => s1 ++ s2 }
16+
val baz: Baz[String] = { (s1, s2) => s1 ++ s2 }
17+
18+
val s = "abc"
19+
val ss = "abcabc"
20+
assert(foo.xxx(s) == ss)
21+
assert(bar.yyy(s, s) == ss)
22+
assert(baz.xxx(s) == ss)
23+
assert(baz.yyy(s, s) == ss)

tests/run/i10068c.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Taken from: https://github.com/scala/scala/pull/6087
2+
3+
trait JsonValue
4+
class JsonObject extends JsonValue
5+
class JsonString extends JsonValue
6+
7+
trait JsonEncoder[A] {
8+
def encode(value: A): JsonValue
9+
}
10+
11+
trait JsonObjectEncoder[A] extends JsonEncoder[A] {
12+
def encode(value: A): JsonObject
13+
}
14+
15+
object JsonEncoderInstances {
16+
17+
val seWorks: JsonEncoder[String] =
18+
new JsonEncoder[String] {
19+
def encode(value: String) = new JsonString
20+
}
21+
22+
implicit val stringEncoder: JsonEncoder[String] =
23+
s => new JsonString
24+
//new JsonEncoder[String] {
25+
// def encode(value: String) = new JsonString
26+
//}
27+
28+
def leWorks[A](implicit encoder: JsonEncoder[A]): JsonObjectEncoder[List[A]] =
29+
new JsonObjectEncoder[List[A]] {
30+
def encode(value: List[A]) = new JsonObject
31+
}
32+
33+
implicit def listEncoder[A](implicit encoder: JsonEncoder[A]): JsonObjectEncoder[List[A]] =
34+
l => new JsonObject
35+
// new JsonObjectEncoder[List[A]] {
36+
// def encode(value: List[A]) = new JsonObject
37+
// }
38+
39+
}
40+
41+
object Test extends App {
42+
import JsonEncoderInstances._
43+
44+
implicitly[JsonEncoder[List[String]]].encode("" :: Nil)
45+
}

tests/run/i10068d.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Taken from: https://github.com/scala/scala/pull/6087
2+
3+
trait A
4+
trait B extends A
5+
trait C extends B
6+
object it extends C
7+
8+
/* try as many weird diamondy things as I can think of */
9+
trait SAM_A { def apply(): A }
10+
trait SAM_A1 extends SAM_A { def apply(): A }
11+
trait SAM_B extends SAM_A1 { def apply(): B }
12+
trait SAM_B1 extends SAM_A1 { def apply(): B }
13+
trait SAM_B2 extends SAM_B with SAM_B1
14+
trait SAM_C extends SAM_B2 { def apply(): C }
15+
16+
trait SAM_F extends (() => A) with SAM_C
17+
trait SAM_F1 extends (() => C) with SAM_F
18+
19+
20+
object Test extends App {
21+
22+
val s1: SAM_A = () => it
23+
val s2: SAM_A1 = () => it
24+
val s3: SAM_B = () => it
25+
val s4: SAM_B1 = () => it
26+
val s5: SAM_B2 = () => it
27+
val s6: SAM_C = () => it
28+
val s7: SAM_F = () => it
29+
val s8: SAM_F1 = () => it
30+
31+
(s1(): A)
32+
33+
(s2(): A)
34+
35+
(s3(): B)
36+
(s3(): A)
37+
38+
(s4(): B)
39+
(s4(): A)
40+
41+
(s5(): B)
42+
(s5(): A)
43+
44+
(s6(): C)
45+
(s6(): B)
46+
(s6(): A)
47+
48+
(s7(): C)
49+
(s7(): B)
50+
(s7(): A)
51+
52+
(s8(): C)
53+
(s8(): B)
54+
(s8(): A)
55+
56+
}

tests/run/i11676.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
sealed trait PartialOrdering
2+
sealed trait Ordering extends PartialOrdering
3+
4+
object Ordering {
5+
def fromCompare(n: Int): Ordering = new Ordering {}
6+
}
7+
8+
trait PartialOrd[-A] {
9+
def checkCompare(l: A, r: A): PartialOrdering
10+
}
11+
12+
trait Ord[-A] extends PartialOrd[A] {
13+
def checkCompare(l: A, r: A): Ordering
14+
}
15+
16+
object Ord {
17+
def fromScala[A](implicit ordering: scala.math.Ordering[A]): Ord[A] =
18+
(l: A, r: A) => Ordering.fromCompare(ordering.compare(l, r))
19+
}
20+
21+
object Test {
22+
def main(args: Array[String]): Unit =
23+
val intOrd = Ord.fromScala[Int]
24+
intOrd.checkCompare(1, 3)
25+
}

0 commit comments

Comments
 (0)