Skip to content

Commit 7685e83

Browse files
committed
Refactor FunctionalInterfaces
This commit does not change the end result of FunctionalInterfaces but makes the code easier to read, and add a `Definitions#isSpecializableFunction` method used in the next commit.
1 parent f795f09 commit 7685e83

File tree

2 files changed

+45
-49
lines changed

2 files changed

+45
-49
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,34 @@ class Definitions {
911911
arity >= 0 && isFunctionClass(sym) && tp.isRef(FunctionType(arity, sym.name.isImplicitFunction).typeSymbol)
912912
}
913913

914+
// Specialized type parameters defined for scala.Function{0,1,2}.
915+
private lazy val Function1SpecializedParams: collection.Set[Type] =
916+
Set(IntType, LongType, FloatType, DoubleType)
917+
private lazy val Function2SpecializedParams: collection.Set[Type] =
918+
Set(IntType, LongType, DoubleType)
919+
private lazy val Function0SpecializedReturns: collection.Set[Type] =
920+
ScalaNumericValueTypeList.toSet[Type] + UnitType + BooleanType
921+
private lazy val Function1SpecializedReturns: collection.Set[Type] =
922+
Set(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
923+
private lazy val Function2SpecializedReturns: collection.Set[Type] =
924+
Function1SpecializedReturns
925+
926+
def isSpecializableFunction(cls: ClassSymbol, paramTypes: List[Type], retType: Type)(implicit ctx: Context) =
927+
isFunctionClass(cls) && (paramTypes match {
928+
case Nil =>
929+
Function0SpecializedReturns.contains(retType)
930+
case List(paramType0) =>
931+
Function1SpecializedParams.contains(paramType0) &&
932+
Function1SpecializedReturns.contains(retType)
933+
case List(paramType0, paramType1) =>
934+
Function2SpecializedParams.contains(paramType0) &&
935+
Function2SpecializedParams.contains(paramType1) &&
936+
Function2SpecializedReturns.contains(retType)
937+
case _ =>
938+
false
939+
})
940+
941+
914942
def functionArity(tp: Type)(implicit ctx: Context) = tp.dealias.argInfos.length - 1
915943

916944
def isImplicitFunctionType(tp: Type)(implicit ctx: Context) =

compiler/src/dotty/tools/dotc/transform/FunctionalInterfaces.scala

Lines changed: 17 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -26,58 +26,26 @@ class FunctionalInterfaces extends MiniPhaseTransform {
2626

2727
def phaseName: String = "functionalInterfaces"
2828

29-
private var allowedReturnTypes: Set[Symbol] = _ // moved here to make it explicit what specializations are generated
30-
private var allowedArgumentTypes: Set[Symbol] = _
31-
val maxArgsCount = 2
32-
33-
def shouldSpecialize(m: MethodType)(implicit ctx: Context) =
34-
(m.paramInfos.size <= maxArgsCount) &&
35-
m.paramInfos.forall(x => allowedArgumentTypes.contains(x.typeSymbol)) &&
36-
allowedReturnTypes.contains(m.resultType.typeSymbol)
37-
3829
val functionName = "JFunction".toTermName
3930
val functionPackage = "scala.compat.java8.".toTermName
4031

41-
override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context): TreeTransform = {
42-
allowedReturnTypes = Set(defn.UnitClass,
43-
defn.BooleanClass,
44-
defn.IntClass,
45-
defn.FloatClass,
46-
defn.LongClass,
47-
defn.DoubleClass,
48-
/* only for Function0: */ defn.ByteClass,
49-
defn.ShortClass,
50-
defn.CharClass)
51-
52-
allowedArgumentTypes = Set(defn.IntClass,
53-
defn.LongClass,
54-
defn.DoubleClass,
55-
/* only for Function1: */ defn.FloatClass)
56-
57-
this
58-
}
59-
6032
override def transformClosure(tree: Closure)(implicit ctx: Context, info: TransformerInfo): Tree = {
61-
tree.tpt match {
62-
case EmptyTree =>
63-
val m = tree.meth.tpe.widen.asInstanceOf[MethodType]
64-
65-
if (shouldSpecialize(m)) {
66-
val functionSymbol = tree.tpe.widenDealias.classSymbol
67-
val names = ctx.atPhase(ctx.erasurePhase) {
68-
implicit ctx => functionSymbol.typeParams.map(_.name)
69-
}
70-
val interfaceName = (functionName ++ m.paramInfos.length.toString).specializedFor(m.paramInfos ::: m.resultType :: Nil, names, Nil, Nil)
71-
72-
// symbols loaded from classpath aren't defined in periods earlier than when they where loaded
73-
val interface = ctx.withPhase(ctx.typerPhase).getClassIfDefined(functionPackage ++ interfaceName)
74-
if (interface.exists) {
75-
val tpt = tpd.TypeTree(interface.asType.appliedRef)
76-
tpd.Closure(tree.env, tree.meth, tpt)
77-
} else tree
78-
} else tree
79-
case _ =>
80-
tree
81-
}
33+
val cls = tree.tpe.widen.classSymbol.asClass
34+
35+
val implType = tree.meth.tpe.widen
36+
val List(implParamTypes) = implType.paramInfoss
37+
val implResultType = implType.resultType
38+
39+
if (defn.isSpecializableFunction(cls, implParamTypes, implResultType)) {
40+
val names = ctx.atPhase(ctx.erasurePhase) {
41+
implicit ctx => cls.typeParams.map(_.name)
42+
}
43+
val interfaceName = (functionName ++ implParamTypes.length.toString).specializedFor(implParamTypes ::: implResultType :: Nil, names, Nil, Nil)
44+
45+
// symbols loaded from classpath aren't defined in periods earlier than when they where loaded
46+
val interface = ctx.withPhase(ctx.typerPhase).requiredClass(functionPackage ++ interfaceName)
47+
val tpt = tpd.TypeTree(interface.asType.appliedRef)
48+
tpd.Closure(tree.env, tree.meth, tpt)
49+
} else tree
8250
}
8351
}

0 commit comments

Comments
 (0)