Skip to content

Commit aee7a4f

Browse files
authored
Merge pull request #3203 from dotty-staging/fix-lambda-nulls
Various fixes related to lambda adaptation
2 parents 71f9efb + fe3c357 commit aee7a4f

File tree

90 files changed

+388
-237
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+388
-237
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,34 @@ class Definitions {
911911
arity >= 0 && isFunctionClass(sym) && tp.isRef(FunctionType(arity, sym.name.isImplicitFunction).typeSymbol)
912912
}
913913

914+
// Specialized type parameters defined for scala.Function{0,1,2}.
915+
private lazy val Function1SpecializedParams: collection.Set[Type] =
916+
Set(IntType, LongType, FloatType, DoubleType)
917+
private lazy val Function2SpecializedParams: collection.Set[Type] =
918+
Set(IntType, LongType, DoubleType)
919+
private lazy val Function0SpecializedReturns: collection.Set[Type] =
920+
ScalaNumericValueTypeList.toSet[Type] + UnitType + BooleanType
921+
private lazy val Function1SpecializedReturns: collection.Set[Type] =
922+
Set(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
923+
private lazy val Function2SpecializedReturns: collection.Set[Type] =
924+
Function1SpecializedReturns
925+
926+
def isSpecializableFunction(cls: ClassSymbol, paramTypes: List[Type], retType: Type)(implicit ctx: Context) =
927+
isFunctionClass(cls) && (paramTypes match {
928+
case Nil =>
929+
Function0SpecializedReturns.contains(retType)
930+
case List(paramType0) =>
931+
Function1SpecializedParams.contains(paramType0) &&
932+
Function1SpecializedReturns.contains(retType)
933+
case List(paramType0, paramType1) =>
934+
Function2SpecializedParams.contains(paramType0) &&
935+
Function2SpecializedParams.contains(paramType1) &&
936+
Function2SpecializedReturns.contains(retType)
937+
case _ =>
938+
false
939+
})
940+
941+
914942
def functionArity(tp: Type)(implicit ctx: Context) = tp.dealias.argInfos.length - 1
915943

916944
def isImplicitFunctionType(tp: Type)(implicit ctx: Context) =

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ object NameKinds {
361361
val ModuleVarName = new SuffixNameKind(OBJECTVAR, "$module")
362362
val ModuleClassName = new SuffixNameKind(OBJECTCLASS, "$", optInfoString = "ModuleClass")
363363
val ImplMethName = new SuffixNameKind(IMPLMETH, "$")
364+
val AdaptedClosureName = new SuffixNameKind(ADAPTEDCLOSURE, "$adapted") { override def definesNewName = true }
364365

365366
/** A name together with a signature. Used in Tasty trees. */
366367
object SignedName extends NameKind(SIGNED) {

compiler/src/dotty/tools/dotc/core/tasty/TastyFormat.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Macro-format:
4848
DIRECT Length underlying_NameRef
4949
FIELD Length underlying_NameRef
5050
EXTMETH Length underlying_NameRef
51+
ADAPTEDCLOSURE Length underlying_NameRef
5152
OBJECTVAR Length underlying_NameRef
5253
OBJECTCLASS Length underlying_NameRef
5354
SIGNED Length original_NameRef resultSig_NameRef paramSig_NameRef*
@@ -253,6 +254,7 @@ object TastyFormat {
253254
final val DIRECT = 31
254255
final val FIELD = 32
255256
final val EXTMETH = 33
257+
final val ADAPTEDCLOSURE = 34
256258
final val OBJECTVAR = 39
257259
final val OBJECTCLASS = 40
258260

@@ -471,6 +473,7 @@ object TastyFormat {
471473
case DIRECT => "DIRECT"
472474
case FIELD => "FIELD"
473475
case EXTMETH => "EXTMETH"
476+
case ADAPTEDCLOSURE => "ADAPTEDCLOSURE"
474477
case OBJECTVAR => "OBJECTVAR"
475478
case OBJECTCLASS => "OBJECTCLASS"
476479

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 69 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import core.Types._
1111
import core.Names._
1212
import core.StdNames._
1313
import core.NameOps._
14+
import core.NameKinds.AdaptedClosureName
1415
import core.Decorators._
1516
import core.Constants._
1617
import core.Definitions._
@@ -565,54 +566,85 @@ object Erasure {
565566
super.typedDefDef(ddef1, sym)
566567
}
567568

568-
/** After erasure, we may have to replace the closure method by a bridge.
569-
* LambdaMetaFactory handles this automatically for most types, but we have
570-
* to deal with boxing and unboxing of value classes ourselves.
571-
*/
572569
override def typedClosure(tree: untpd.Closure, pt: Type)(implicit ctx: Context) = {
573570
val xxl = defn.isXXLFunctionClass(tree.typeOpt.typeSymbol)
574571
var implClosure @ Closure(_, meth, _) = super.typedClosure(tree, pt)
575572
if (xxl) implClosure = cpy.Closure(implClosure)(tpt = TypeTree(defn.FunctionXXLType))
576573
implClosure.tpe match {
577574
case SAMType(sam) =>
578-
val implType = meth.tpe.widen
575+
val implType = meth.tpe.widen.asInstanceOf[MethodType]
579576

580-
val List(implParamTypes) = implType.paramInfoss
577+
val implParamTypes = implType.paramInfos
581578
val List(samParamTypes) = sam.info.paramInfoss
582579
val implResultType = implType.resultType
583580
val samResultType = sam.info.resultType
584581

585-
// Given a value class V with an underlying type U, the following code:
586-
// val f: Function1[V, V] = x => ...
587-
// results in the creation of a closure and a method:
588-
// def $anonfun(v1: V): V = ...
589-
// val f: Function1[V, V] = closure($anonfun)
590-
// After [[Erasure]] this method will look like:
591-
// def $anonfun(v1: ErasedValueType(V, U)): ErasedValueType(V, U) = ...
592-
// And after [[ElimErasedValueType]] it will look like:
593-
// def $anonfun(v1: U): U = ...
594-
// This method does not implement the SAM of Function1[V, V] anymore and
595-
// needs to be replaced by a bridge:
596-
// def $anonfun$2(v1: V): V = new V($anonfun(v1.underlying))
597-
// val f: Function1 = closure($anonfun$2)
598-
// In general, a bridge is needed when the signature of the closure method after
599-
// Erasure contains an ErasedValueType but the corresponding type in the functional
600-
// interface is not an ErasedValueType.
601-
val bridgeNeeded =
602-
(implResultType :: implParamTypes, samResultType :: samParamTypes).zipped.exists(
603-
(implType, samType) => implType.isErasedValueType && !samType.isErasedValueType
604-
)
605-
606-
if (bridgeNeeded) {
607-
val bridge = ctx.newSymbol(ctx.owner, nme.ANON_FUN, Flags.Synthetic | Flags.Method, sam.info)
608-
val bridgeCtx = ctx.withOwner(bridge)
609-
Closure(bridge, bridgeParamss => {
610-
implicit val ctx = bridgeCtx
611-
612-
val List(bridgeParams) = bridgeParamss
613-
val rhs = Apply(meth, (bridgeParams, implParamTypes).zipped.map(adapt(_, _)))
614-
adapt(rhs, sam.info.resultType)
615-
})
582+
// The following code:
583+
//
584+
// val f: Function1[Int, Any] = x => ...
585+
//
586+
// results in the creation of a closure and a method in the typer:
587+
//
588+
// def $anonfun(x: Int): Any = ...
589+
// val f: Function1[Int, Any] = closure($anonfun)
590+
//
591+
// Notice that `$anonfun` takes a primitive as argument, but the single abstract method
592+
// of `Function1` after erasure is:
593+
//
594+
// def apply(x: Object): Object
595+
//
596+
// which takes a reference as argument. Hence, some form of adaptation is required.
597+
//
598+
// If we do nothing, the LambdaMetaFactory bootstrap method will
599+
// automatically do the adaptation. Unfortunately, the result does not
600+
// implement the expected Scala semantics: null should be "unboxed" to
601+
// the default value of the value class, but LMF will throw a
602+
// NullPointerException instead. LMF is also not capable of doing
603+
// adaptation for derived value classes.
604+
//
605+
// Thus, we need to replace the closure method by a bridge method that
606+
// forwards to the original closure method with appropriate
607+
// boxing/unboxing. For our example above, this would be:
608+
//
609+
// def $anonfun1(x: Object): Object = $anonfun(BoxesRunTime.unboxToInt(x))
610+
// val f: Function1 = closure($anonfun1)
611+
//
612+
// In general, a bridge is needed when, after Erasure:
613+
// - one of the parameter type of the closure method is a non-reference type,
614+
// and the corresponding type in the SAM is a reference type
615+
// - or the result type of the closure method is an erased value type
616+
// and the result type in the SAM isn't
617+
// However, the following exception exists: If the SAM is replaced by
618+
// JFunction*mc* in [[FunctionalInterfaces]], no bridge is needed: the
619+
// SAM contains default methods to handle adaptation
620+
//
621+
// See test cases lambda-*.scala and t8017/ for concrete examples.
622+
623+
def isReferenceType(tp: Type) = !tp.isPrimitiveValueType && !tp.isErasedValueType
624+
625+
if (!defn.isSpecializableFunction(implClosure.tpe.widen.classSymbol.asClass, implParamTypes, implResultType)) {
626+
val paramAdaptationNeeded =
627+
(implParamTypes, samParamTypes).zipped.exists((implType, samType) =>
628+
!isReferenceType(implType) && isReferenceType(samType))
629+
val resultAdaptationNeeded =
630+
implResultType.isErasedValueType && !samResultType.isErasedValueType
631+
632+
if (paramAdaptationNeeded || resultAdaptationNeeded) {
633+
val bridgeType =
634+
if (paramAdaptationNeeded) {
635+
if (resultAdaptationNeeded) sam.info
636+
else implType.derivedLambdaType(paramInfos = samParamTypes)
637+
} else implType.derivedLambdaType(resType = samResultType)
638+
val bridge = ctx.newSymbol(ctx.owner, AdaptedClosureName(meth.symbol.name.asTermName), Flags.Synthetic | Flags.Method, bridgeType)
639+
val bridgeCtx = ctx.withOwner(bridge)
640+
Closure(bridge, bridgeParamss => {
641+
implicit val ctx = bridgeCtx
642+
643+
val List(bridgeParams) = bridgeParamss
644+
val rhs = Apply(meth, (bridgeParams, implParamTypes).zipped.map(adapt(_, _)))
645+
adapt(rhs, bridgeType.resultType)
646+
}, targetType = implClosure.tpt.tpe)
647+
} else implClosure
616648
} else implClosure
617649
case _ =>
618650
implClosure

compiler/src/dotty/tools/dotc/transform/FunctionalInterfaces.scala

Lines changed: 17 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -26,58 +26,26 @@ class FunctionalInterfaces extends MiniPhaseTransform {
2626

2727
def phaseName: String = "functionalInterfaces"
2828

29-
private var allowedReturnTypes: Set[Symbol] = _ // moved here to make it explicit what specializations are generated
30-
private var allowedArgumentTypes: Set[Symbol] = _
31-
val maxArgsCount = 2
32-
33-
def shouldSpecialize(m: MethodType)(implicit ctx: Context) =
34-
(m.paramInfos.size <= maxArgsCount) &&
35-
m.paramInfos.forall(x => allowedArgumentTypes.contains(x.typeSymbol)) &&
36-
allowedReturnTypes.contains(m.resultType.typeSymbol)
37-
3829
val functionName = "JFunction".toTermName
3930
val functionPackage = "scala.compat.java8.".toTermName
4031

41-
override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context): TreeTransform = {
42-
allowedReturnTypes = Set(defn.UnitClass,
43-
defn.BooleanClass,
44-
defn.IntClass,
45-
defn.FloatClass,
46-
defn.LongClass,
47-
defn.DoubleClass,
48-
/* only for Function0: */ defn.ByteClass,
49-
defn.ShortClass,
50-
defn.CharClass)
51-
52-
allowedArgumentTypes = Set(defn.IntClass,
53-
defn.LongClass,
54-
defn.DoubleClass,
55-
/* only for Function1: */ defn.FloatClass)
56-
57-
this
58-
}
59-
6032
override def transformClosure(tree: Closure)(implicit ctx: Context, info: TransformerInfo): Tree = {
61-
tree.tpt match {
62-
case EmptyTree =>
63-
val m = tree.meth.tpe.widen.asInstanceOf[MethodType]
64-
65-
if (shouldSpecialize(m)) {
66-
val functionSymbol = tree.tpe.widenDealias.classSymbol
67-
val names = ctx.atPhase(ctx.erasurePhase) {
68-
implicit ctx => functionSymbol.typeParams.map(_.name)
69-
}
70-
val interfaceName = (functionName ++ m.paramInfos.length.toString).specializedFor(m.paramInfos ::: m.resultType :: Nil, names, Nil, Nil)
71-
72-
// symbols loaded from classpath aren't defined in periods earlier than when they where loaded
73-
val interface = ctx.withPhase(ctx.typerPhase).getClassIfDefined(functionPackage ++ interfaceName)
74-
if (interface.exists) {
75-
val tpt = tpd.TypeTree(interface.asType.appliedRef)
76-
tpd.Closure(tree.env, tree.meth, tpt)
77-
} else tree
78-
} else tree
79-
case _ =>
80-
tree
81-
}
33+
val cls = tree.tpe.widen.classSymbol.asClass
34+
35+
val implType = tree.meth.tpe.widen
36+
val List(implParamTypes) = implType.paramInfoss
37+
val implResultType = implType.resultType
38+
39+
if (defn.isSpecializableFunction(cls, implParamTypes, implResultType)) {
40+
val names = ctx.atPhase(ctx.erasurePhase) {
41+
implicit ctx => cls.typeParams.map(_.name)
42+
}
43+
val interfaceName = (functionName ++ implParamTypes.length.toString).specializedFor(implParamTypes ::: implResultType :: Nil, names, Nil, Nil)
44+
45+
// symbols loaded from classpath aren't defined in periods earlier than when they where loaded
46+
val interface = ctx.withPhase(ctx.typerPhase).requiredClass(functionPackage ++ interfaceName)
47+
val tpt = tpd.TypeTree(interface.asType.appliedRef)
48+
tpd.Closure(tree.env, tree.meth, tpt)
49+
} else tree
8250
}
8351
}

library/src/scala/compat/java8/JFunction0.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,27 @@ public interface JFunction0<R> extends scala.Function0<R> {
1313
apply();
1414
}
1515
default byte apply$mcB$sp() {
16-
return (Byte) apply();
16+
return scala.runtime.BoxesRunTime.unboxToByte(apply());
1717
}
1818
default short apply$mcS$sp() {
19-
return (Short) apply();
19+
return scala.runtime.BoxesRunTime.unboxToShort(apply());
2020
}
2121
default int apply$mcI$sp() {
22-
return (Integer) apply();
22+
return scala.runtime.BoxesRunTime.unboxToInt(apply());
2323
}
2424
default long apply$mcJ$sp() {
25-
return (Long) apply();
25+
return scala.runtime.BoxesRunTime.unboxToLong(apply());
2626
}
2727
default char apply$mcC$sp() {
28-
return (Character) apply();
28+
return scala.runtime.BoxesRunTime.unboxToChar(apply());
2929
}
3030
default float apply$mcF$sp() {
31-
return (Float) apply();
31+
return scala.runtime.BoxesRunTime.unboxToFloat(apply());
3232
}
3333
default double apply$mcD$sp() {
34-
return (Double) apply();
34+
return scala.runtime.BoxesRunTime.unboxToDouble(apply());
3535
}
3636
default boolean apply$mcZ$sp() {
37-
return (Boolean) apply();
37+
return scala.runtime.BoxesRunTime.unboxToBoolean(apply());
3838
}
3939
}

library/src/scala/compat/java8/JFunction1$mcDD$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcDD$sp extends JFunction1 {
1010
abstract double apply$mcDD$sp(double v1);
1111

12-
default Object apply(Object t) { return (Double) apply$mcDD$sp((Double) t); }
12+
default Object apply(Object t) { return (Double) apply$mcDD$sp(scala.runtime.BoxesRunTime.unboxToDouble(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcDF$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcDF$sp extends JFunction1 {
1010
abstract double apply$mcDF$sp(float v1);
1111

12-
default Object apply(Object t) { return (Double) apply$mcDF$sp((Float) t); }
12+
default Object apply(Object t) { return (Double) apply$mcDF$sp(scala.runtime.BoxesRunTime.unboxToFloat(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcDI$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcDI$sp extends JFunction1 {
1010
abstract double apply$mcDI$sp(int v1);
1111

12-
default Object apply(Object t) { return (Double) apply$mcDI$sp((Integer) t); }
12+
default Object apply(Object t) { return (Double) apply$mcDI$sp(scala.runtime.BoxesRunTime.unboxToInt(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcDJ$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcDJ$sp extends JFunction1 {
1010
abstract double apply$mcDJ$sp(long v1);
1111

12-
default Object apply(Object t) { return (Double) apply$mcDJ$sp((Long) t); }
12+
default Object apply(Object t) { return (Double) apply$mcDJ$sp(scala.runtime.BoxesRunTime.unboxToLong(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcFD$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcFD$sp extends JFunction1 {
1010
abstract float apply$mcFD$sp(double v1);
1111

12-
default Object apply(Object t) { return (Float) apply$mcFD$sp((Double) t); }
12+
default Object apply(Object t) { return (Float) apply$mcFD$sp(scala.runtime.BoxesRunTime.unboxToDouble(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcFF$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcFF$sp extends JFunction1 {
1010
abstract float apply$mcFF$sp(float v1);
1111

12-
default Object apply(Object t) { return (Float) apply$mcFF$sp((Float) t); }
12+
default Object apply(Object t) { return (Float) apply$mcFF$sp(scala.runtime.BoxesRunTime.unboxToFloat(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcFI$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcFI$sp extends JFunction1 {
1010
abstract float apply$mcFI$sp(int v1);
1111

12-
default Object apply(Object t) { return (Float) apply$mcFI$sp((Integer) t); }
12+
default Object apply(Object t) { return (Float) apply$mcFI$sp(scala.runtime.BoxesRunTime.unboxToInt(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcFJ$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcFJ$sp extends JFunction1 {
1010
abstract float apply$mcFJ$sp(long v1);
1111

12-
default Object apply(Object t) { return (Float) apply$mcFJ$sp((Long) t); }
12+
default Object apply(Object t) { return (Float) apply$mcFJ$sp(scala.runtime.BoxesRunTime.unboxToLong(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcID$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcID$sp extends JFunction1 {
1010
abstract int apply$mcID$sp(double v1);
1111

12-
default Object apply(Object t) { return (Integer) apply$mcID$sp((Double) t); }
12+
default Object apply(Object t) { return (Integer) apply$mcID$sp(scala.runtime.BoxesRunTime.unboxToDouble(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcIF$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcIF$sp extends JFunction1 {
1010
abstract int apply$mcIF$sp(float v1);
1111

12-
default Object apply(Object t) { return (Integer) apply$mcIF$sp((Float) t); }
12+
default Object apply(Object t) { return (Integer) apply$mcIF$sp(scala.runtime.BoxesRunTime.unboxToFloat(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcII$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcII$sp extends JFunction1 {
1010
abstract int apply$mcII$sp(int v1);
1111

12-
default Object apply(Object t) { return (Integer) apply$mcII$sp((Integer) t); }
12+
default Object apply(Object t) { return (Integer) apply$mcII$sp(scala.runtime.BoxesRunTime.unboxToInt(t)); }
1313
}

library/src/scala/compat/java8/JFunction1$mcIJ$sp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
public interface JFunction1$mcIJ$sp extends JFunction1 {
1010
abstract int apply$mcIJ$sp(long v1);
1111

12-
default Object apply(Object t) { return (Integer) apply$mcIJ$sp((Long) t); }
12+
default Object apply(Object t) { return (Integer) apply$mcIJ$sp(scala.runtime.BoxesRunTime.unboxToLong(t)); }
1313
}

0 commit comments

Comments
 (0)