@@ -26,58 +26,26 @@ class FunctionalInterfaces extends MiniPhaseTransform {
26
26
27
27
def phaseName : String = " functionalInterfaces"
28
28
29
- private var allowedReturnTypes : Set [Symbol ] = _ // moved here to make it explicit what specializations are generated
30
- private var allowedArgumentTypes : Set [Symbol ] = _
31
- val maxArgsCount = 2
32
-
33
- def shouldSpecialize (m : MethodType )(implicit ctx : Context ) =
34
- (m.paramInfos.size <= maxArgsCount) &&
35
- m.paramInfos.forall(x => allowedArgumentTypes.contains(x.typeSymbol)) &&
36
- allowedReturnTypes.contains(m.resultType.typeSymbol)
37
-
38
29
val functionName = " JFunction" .toTermName
39
30
val functionPackage = " scala.compat.java8." .toTermName
40
31
41
- override def prepareForUnit (tree : tpd.Tree )(implicit ctx : Context ): TreeTransform = {
42
- allowedReturnTypes = Set (defn.UnitClass ,
43
- defn.BooleanClass ,
44
- defn.IntClass ,
45
- defn.FloatClass ,
46
- defn.LongClass ,
47
- defn.DoubleClass ,
48
- /* only for Function0: */ defn.ByteClass ,
49
- defn.ShortClass ,
50
- defn.CharClass )
51
-
52
- allowedArgumentTypes = Set (defn.IntClass ,
53
- defn.LongClass ,
54
- defn.DoubleClass ,
55
- /* only for Function1: */ defn.FloatClass )
56
-
57
- this
58
- }
59
-
60
32
override def transformClosure (tree : Closure )(implicit ctx : Context , info : TransformerInfo ): Tree = {
61
- tree.tpt match {
62
- case EmptyTree =>
63
- val m = tree.meth.tpe.widen.asInstanceOf [MethodType ]
64
-
65
- if (shouldSpecialize(m)) {
66
- val functionSymbol = tree.tpe.widenDealias.classSymbol
67
- val names = ctx.atPhase(ctx.erasurePhase) {
68
- implicit ctx => functionSymbol.typeParams.map(_.name)
69
- }
70
- val interfaceName = (functionName ++ m.paramInfos.length.toString).specializedFor(m.paramInfos ::: m.resultType :: Nil , names, Nil , Nil )
71
-
72
- // symbols loaded from classpath aren't defined in periods earlier than when they where loaded
73
- val interface = ctx.withPhase(ctx.typerPhase).getClassIfDefined(functionPackage ++ interfaceName)
74
- if (interface.exists) {
75
- val tpt = tpd.TypeTree (interface.asType.appliedRef)
76
- tpd.Closure (tree.env, tree.meth, tpt)
77
- } else tree
78
- } else tree
79
- case _ =>
80
- tree
81
- }
33
+ val cls = tree.tpe.widen.classSymbol.asClass
34
+
35
+ val implType = tree.meth.tpe.widen
36
+ val List (implParamTypes) = implType.paramInfoss
37
+ val implResultType = implType.resultType
38
+
39
+ if (defn.isSpecializableFunction(cls, implParamTypes, implResultType)) {
40
+ val names = ctx.atPhase(ctx.erasurePhase) {
41
+ implicit ctx => cls.typeParams.map(_.name)
42
+ }
43
+ val interfaceName = (functionName ++ implParamTypes.length.toString).specializedFor(implParamTypes ::: implResultType :: Nil , names, Nil , Nil )
44
+
45
+ // symbols loaded from classpath aren't defined in periods earlier than when they where loaded
46
+ val interface = ctx.withPhase(ctx.typerPhase).requiredClass(functionPackage ++ interfaceName)
47
+ val tpt = tpd.TypeTree (interface.asType.appliedRef)
48
+ tpd.Closure (tree.env, tree.meth, tpt)
49
+ } else tree
82
50
}
83
51
}
0 commit comments