Skip to content

Various fixes related to lambda adaptation #3203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,34 @@ class Definitions {
arity >= 0 && isFunctionClass(sym) && tp.isRef(FunctionType(arity, sym.name.isImplicitFunction).typeSymbol)
}

// Specialized type parameters defined for scala.Function{0,1,2}.
private lazy val Function1SpecializedParams: collection.Set[Type] =
Set(IntType, LongType, FloatType, DoubleType)
private lazy val Function2SpecializedParams: collection.Set[Type] =
Set(IntType, LongType, DoubleType)
private lazy val Function0SpecializedReturns: collection.Set[Type] =
ScalaNumericValueTypeList.toSet[Type] + UnitType + BooleanType
private lazy val Function1SpecializedReturns: collection.Set[Type] =
Set(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
private lazy val Function2SpecializedReturns: collection.Set[Type] =
Function1SpecializedReturns

def isSpecializableFunction(cls: ClassSymbol, paramTypes: List[Type], retType: Type)(implicit ctx: Context) =
isFunctionClass(cls) && (paramTypes match {
case Nil =>
Function0SpecializedReturns.contains(retType)
case List(paramType0) =>
Function1SpecializedParams.contains(paramType0) &&
Function1SpecializedReturns.contains(retType)
case List(paramType0, paramType1) =>
Function2SpecializedParams.contains(paramType0) &&
Function2SpecializedParams.contains(paramType1) &&
Function2SpecializedReturns.contains(retType)
case _ =>
false
})


def functionArity(tp: Type)(implicit ctx: Context) = tp.dealias.argInfos.length - 1

def isImplicitFunctionType(tp: Type)(implicit ctx: Context) =
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/NameKinds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ object NameKinds {
val ModuleVarName = new SuffixNameKind(OBJECTVAR, "$module")
val ModuleClassName = new SuffixNameKind(OBJECTCLASS, "$", optInfoString = "ModuleClass")
val ImplMethName = new SuffixNameKind(IMPLMETH, "$")
val AdaptedClosureName = new SuffixNameKind(ADAPTEDCLOSURE, "$adapted") { override def definesNewName = true }

/** A name together with a signature. Used in Tasty trees. */
object SignedName extends NameKind(SIGNED) {
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TastyFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Macro-format:
DIRECT Length underlying_NameRef
FIELD Length underlying_NameRef
EXTMETH Length underlying_NameRef
ADAPTEDCLOSURE Length underlying_NameRef
OBJECTVAR Length underlying_NameRef
OBJECTCLASS Length underlying_NameRef
SIGNED Length original_NameRef resultSig_NameRef paramSig_NameRef*
Expand Down Expand Up @@ -253,6 +254,7 @@ object TastyFormat {
final val DIRECT = 31
final val FIELD = 32
final val EXTMETH = 33
final val ADAPTEDCLOSURE = 34
final val OBJECTVAR = 39
final val OBJECTCLASS = 40

Expand Down Expand Up @@ -471,6 +473,7 @@ object TastyFormat {
case DIRECT => "DIRECT"
case FIELD => "FIELD"
case EXTMETH => "EXTMETH"
case ADAPTEDCLOSURE => "ADAPTEDCLOSURE"
case OBJECTVAR => "OBJECTVAR"
case OBJECTCLASS => "OBJECTCLASS"

Expand Down
106 changes: 69 additions & 37 deletions compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import core.Types._
import core.Names._
import core.StdNames._
import core.NameOps._
import core.NameKinds.AdaptedClosureName
import core.Decorators._
import core.Constants._
import core.Definitions._
Expand Down Expand Up @@ -565,54 +566,85 @@ object Erasure {
super.typedDefDef(ddef1, sym)
}

/** After erasure, we may have to replace the closure method by a bridge.
* LambdaMetaFactory handles this automatically for most types, but we have
* to deal with boxing and unboxing of value classes ourselves.
*/
override def typedClosure(tree: untpd.Closure, pt: Type)(implicit ctx: Context) = {
val xxl = defn.isXXLFunctionClass(tree.typeOpt.typeSymbol)
var implClosure @ Closure(_, meth, _) = super.typedClosure(tree, pt)
if (xxl) implClosure = cpy.Closure(implClosure)(tpt = TypeTree(defn.FunctionXXLType))
implClosure.tpe match {
case SAMType(sam) =>
val implType = meth.tpe.widen
val implType = meth.tpe.widen.asInstanceOf[MethodType]

val List(implParamTypes) = implType.paramInfoss
val implParamTypes = implType.paramInfos
val List(samParamTypes) = sam.info.paramInfoss
val implResultType = implType.resultType
val samResultType = sam.info.resultType

// Given a value class V with an underlying type U, the following code:
// val f: Function1[V, V] = x => ...
// results in the creation of a closure and a method:
// def $anonfun(v1: V): V = ...
// val f: Function1[V, V] = closure($anonfun)
// After [[Erasure]] this method will look like:
// def $anonfun(v1: ErasedValueType(V, U)): ErasedValueType(V, U) = ...
// And after [[ElimErasedValueType]] it will look like:
// def $anonfun(v1: U): U = ...
// This method does not implement the SAM of Function1[V, V] anymore and
// needs to be replaced by a bridge:
// def $anonfun$2(v1: V): V = new V($anonfun(v1.underlying))
// val f: Function1 = closure($anonfun$2)
// In general, a bridge is needed when the signature of the closure method after
// Erasure contains an ErasedValueType but the corresponding type in the functional
// interface is not an ErasedValueType.
val bridgeNeeded =
(implResultType :: implParamTypes, samResultType :: samParamTypes).zipped.exists(
(implType, samType) => implType.isErasedValueType && !samType.isErasedValueType
)

if (bridgeNeeded) {
val bridge = ctx.newSymbol(ctx.owner, nme.ANON_FUN, Flags.Synthetic | Flags.Method, sam.info)
val bridgeCtx = ctx.withOwner(bridge)
Closure(bridge, bridgeParamss => {
implicit val ctx = bridgeCtx

val List(bridgeParams) = bridgeParamss
val rhs = Apply(meth, (bridgeParams, implParamTypes).zipped.map(adapt(_, _)))
adapt(rhs, sam.info.resultType)
})
// The following code:
//
// val f: Function1[Int, Any] = x => ...
//
// results in the creation of a closure and a method in the typer:
//
// def $anonfun(x: Int): Any = ...
// val f: Function1[Int, Any] = closure($anonfun)
//
// Notice that `$anonfun` takes a primitive as argument, but the single abstract method
// of `Function1` after erasure is:
//
// def apply(x: Object): Object
//
// which takes a reference as argument. Hence, some form of adaptation is required.
//
// If we do nothing, the LambdaMetaFactory bootstrap method will
// automatically do the adaptation. Unfortunately, the result does not
// implement the expected Scala semantics: null should be "unboxed" to
// the default value of the value class, but LMF will throw a
// NullPointerException instead. LMF is also not capable of doing
// adaptation for derived value classes.
//
// Thus, we need to replace the closure method by a bridge method that
// forwards to the original closure method with appropriate
// boxing/unboxing. For our example above, this would be:
//
// def $anonfun1(x: Object): Object = $anonfun(BoxesRunTime.unboxToInt(x))
// val f: Function1 = closure($anonfun1)
//
// In general, a bridge is needed when, after Erasure:
// - one of the parameter type of the closure method is a non-reference type,
// and the corresponding type in the SAM is a reference type
// - or the result type of the closure method is an erased value type
// and the result type in the SAM isn't
// However, the following exception exists: If the SAM is replaced by
// JFunction*mc* in [[FunctionalInterfaces]], no bridge is needed: the
// SAM contains default methods to handle adaptation
//
// See test cases lambda-*.scala and t8017/ for concrete examples.

def isReferenceType(tp: Type) = !tp.isPrimitiveValueType && !tp.isErasedValueType

if (!defn.isSpecializableFunction(implClosure.tpe.widen.classSymbol.asClass, implParamTypes, implResultType)) {
val paramAdaptationNeeded =
(implParamTypes, samParamTypes).zipped.exists((implType, samType) =>
!isReferenceType(implType) && isReferenceType(samType))
val resultAdaptationNeeded =
implResultType.isErasedValueType && !samResultType.isErasedValueType

if (paramAdaptationNeeded || resultAdaptationNeeded) {
val bridgeType =
if (paramAdaptationNeeded) {
if (resultAdaptationNeeded) sam.info
else implType.derivedLambdaType(paramInfos = samParamTypes)
} else implType.derivedLambdaType(resType = samResultType)
val bridge = ctx.newSymbol(ctx.owner, AdaptedClosureName(meth.symbol.name.asTermName), Flags.Synthetic | Flags.Method, bridgeType)
val bridgeCtx = ctx.withOwner(bridge)
Closure(bridge, bridgeParamss => {
implicit val ctx = bridgeCtx

val List(bridgeParams) = bridgeParamss
val rhs = Apply(meth, (bridgeParams, implParamTypes).zipped.map(adapt(_, _)))
adapt(rhs, bridgeType.resultType)
}, targetType = implClosure.tpt.tpe)
} else implClosure
} else implClosure
case _ =>
implClosure
Expand Down
66 changes: 17 additions & 49 deletions compiler/src/dotty/tools/dotc/transform/FunctionalInterfaces.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,58 +26,26 @@ class FunctionalInterfaces extends MiniPhaseTransform {

def phaseName: String = "functionalInterfaces"

private var allowedReturnTypes: Set[Symbol] = _ // moved here to make it explicit what specializations are generated
private var allowedArgumentTypes: Set[Symbol] = _
val maxArgsCount = 2

def shouldSpecialize(m: MethodType)(implicit ctx: Context) =
(m.paramInfos.size <= maxArgsCount) &&
m.paramInfos.forall(x => allowedArgumentTypes.contains(x.typeSymbol)) &&
allowedReturnTypes.contains(m.resultType.typeSymbol)

val functionName = "JFunction".toTermName
val functionPackage = "scala.compat.java8.".toTermName

override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context): TreeTransform = {
allowedReturnTypes = Set(defn.UnitClass,
defn.BooleanClass,
defn.IntClass,
defn.FloatClass,
defn.LongClass,
defn.DoubleClass,
/* only for Function0: */ defn.ByteClass,
defn.ShortClass,
defn.CharClass)

allowedArgumentTypes = Set(defn.IntClass,
defn.LongClass,
defn.DoubleClass,
/* only for Function1: */ defn.FloatClass)

this
}

override def transformClosure(tree: Closure)(implicit ctx: Context, info: TransformerInfo): Tree = {
tree.tpt match {
case EmptyTree =>
val m = tree.meth.tpe.widen.asInstanceOf[MethodType]

if (shouldSpecialize(m)) {
val functionSymbol = tree.tpe.widenDealias.classSymbol
val names = ctx.atPhase(ctx.erasurePhase) {
implicit ctx => functionSymbol.typeParams.map(_.name)
}
val interfaceName = (functionName ++ m.paramInfos.length.toString).specializedFor(m.paramInfos ::: m.resultType :: Nil, names, Nil, Nil)

// symbols loaded from classpath aren't defined in periods earlier than when they where loaded
val interface = ctx.withPhase(ctx.typerPhase).getClassIfDefined(functionPackage ++ interfaceName)
if (interface.exists) {
val tpt = tpd.TypeTree(interface.asType.appliedRef)
tpd.Closure(tree.env, tree.meth, tpt)
} else tree
} else tree
case _ =>
tree
}
val cls = tree.tpe.widen.classSymbol.asClass

val implType = tree.meth.tpe.widen
val List(implParamTypes) = implType.paramInfoss
val implResultType = implType.resultType

if (defn.isSpecializableFunction(cls, implParamTypes, implResultType)) {
val names = ctx.atPhase(ctx.erasurePhase) {
implicit ctx => cls.typeParams.map(_.name)
}
val interfaceName = (functionName ++ implParamTypes.length.toString).specializedFor(implParamTypes ::: implResultType :: Nil, names, Nil, Nil)

// symbols loaded from classpath aren't defined in periods earlier than when they where loaded
val interface = ctx.withPhase(ctx.typerPhase).requiredClass(functionPackage ++ interfaceName)
val tpt = tpd.TypeTree(interface.asType.appliedRef)
tpd.Closure(tree.env, tree.meth, tpt)
} else tree
}
}
16 changes: 8 additions & 8 deletions library/src/scala/compat/java8/JFunction0.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@ public interface JFunction0<R> extends scala.Function0<R> {
apply();
}
default byte apply$mcB$sp() {
return (Byte) apply();
return scala.runtime.BoxesRunTime.unboxToByte(apply());
}
default short apply$mcS$sp() {
return (Short) apply();
return scala.runtime.BoxesRunTime.unboxToShort(apply());
}
default int apply$mcI$sp() {
return (Integer) apply();
return scala.runtime.BoxesRunTime.unboxToInt(apply());
}
default long apply$mcJ$sp() {
return (Long) apply();
return scala.runtime.BoxesRunTime.unboxToLong(apply());
}
default char apply$mcC$sp() {
return (Character) apply();
return scala.runtime.BoxesRunTime.unboxToChar(apply());
}
default float apply$mcF$sp() {
return (Float) apply();
return scala.runtime.BoxesRunTime.unboxToFloat(apply());
}
default double apply$mcD$sp() {
return (Double) apply();
return scala.runtime.BoxesRunTime.unboxToDouble(apply());
}
default boolean apply$mcZ$sp() {
return (Boolean) apply();
return scala.runtime.BoxesRunTime.unboxToBoolean(apply());
}
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcDD$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcDD$sp extends JFunction1 {
abstract double apply$mcDD$sp(double v1);

default Object apply(Object t) { return (Double) apply$mcDD$sp((Double) t); }
default Object apply(Object t) { return (Double) apply$mcDD$sp(scala.runtime.BoxesRunTime.unboxToDouble(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcDF$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcDF$sp extends JFunction1 {
abstract double apply$mcDF$sp(float v1);

default Object apply(Object t) { return (Double) apply$mcDF$sp((Float) t); }
default Object apply(Object t) { return (Double) apply$mcDF$sp(scala.runtime.BoxesRunTime.unboxToFloat(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcDI$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcDI$sp extends JFunction1 {
abstract double apply$mcDI$sp(int v1);

default Object apply(Object t) { return (Double) apply$mcDI$sp((Integer) t); }
default Object apply(Object t) { return (Double) apply$mcDI$sp(scala.runtime.BoxesRunTime.unboxToInt(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcDJ$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcDJ$sp extends JFunction1 {
abstract double apply$mcDJ$sp(long v1);

default Object apply(Object t) { return (Double) apply$mcDJ$sp((Long) t); }
default Object apply(Object t) { return (Double) apply$mcDJ$sp(scala.runtime.BoxesRunTime.unboxToLong(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcFD$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcFD$sp extends JFunction1 {
abstract float apply$mcFD$sp(double v1);

default Object apply(Object t) { return (Float) apply$mcFD$sp((Double) t); }
default Object apply(Object t) { return (Float) apply$mcFD$sp(scala.runtime.BoxesRunTime.unboxToDouble(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcFF$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcFF$sp extends JFunction1 {
abstract float apply$mcFF$sp(float v1);

default Object apply(Object t) { return (Float) apply$mcFF$sp((Float) t); }
default Object apply(Object t) { return (Float) apply$mcFF$sp(scala.runtime.BoxesRunTime.unboxToFloat(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcFI$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcFI$sp extends JFunction1 {
abstract float apply$mcFI$sp(int v1);

default Object apply(Object t) { return (Float) apply$mcFI$sp((Integer) t); }
default Object apply(Object t) { return (Float) apply$mcFI$sp(scala.runtime.BoxesRunTime.unboxToInt(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcFJ$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcFJ$sp extends JFunction1 {
abstract float apply$mcFJ$sp(long v1);

default Object apply(Object t) { return (Float) apply$mcFJ$sp((Long) t); }
default Object apply(Object t) { return (Float) apply$mcFJ$sp(scala.runtime.BoxesRunTime.unboxToLong(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcID$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcID$sp extends JFunction1 {
abstract int apply$mcID$sp(double v1);

default Object apply(Object t) { return (Integer) apply$mcID$sp((Double) t); }
default Object apply(Object t) { return (Integer) apply$mcID$sp(scala.runtime.BoxesRunTime.unboxToDouble(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcIF$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcIF$sp extends JFunction1 {
abstract int apply$mcIF$sp(float v1);

default Object apply(Object t) { return (Integer) apply$mcIF$sp((Float) t); }
default Object apply(Object t) { return (Integer) apply$mcIF$sp(scala.runtime.BoxesRunTime.unboxToFloat(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcII$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcII$sp extends JFunction1 {
abstract int apply$mcII$sp(int v1);

default Object apply(Object t) { return (Integer) apply$mcII$sp((Integer) t); }
default Object apply(Object t) { return (Integer) apply$mcII$sp(scala.runtime.BoxesRunTime.unboxToInt(t)); }
}
2 changes: 1 addition & 1 deletion library/src/scala/compat/java8/JFunction1$mcIJ$sp.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
public interface JFunction1$mcIJ$sp extends JFunction1 {
abstract int apply$mcIJ$sp(long v1);

default Object apply(Object t) { return (Integer) apply$mcIJ$sp((Long) t); }
default Object apply(Object t) { return (Integer) apply$mcIJ$sp(scala.runtime.BoxesRunTime.unboxToLong(t)); }
}
Loading