@@ -3,18 +3,26 @@ package dotty.tools.dotc.transform
3
3
import dotty .tools .dotc .ast .TreeTypeMap
4
4
import dotty .tools .dotc .ast .tpd ._
5
5
import dotty .tools .dotc .core .Contexts .Context
6
- import dotty .tools .dotc .core .Flags
6
+ import dotty .tools .dotc .core .{ Symbols , Flags }
7
7
import dotty .tools .dotc .core .Types ._
8
8
import dotty .tools .dotc .transform .TreeTransforms .{TransformerInfo , MiniPhaseTransform }
9
9
import dotty .tools .dotc .core .Decorators ._
10
- import scala .collection .mutable .{ ListBuffer , ArrayBuffer }
10
+ import scala .collection .mutable
11
11
12
12
class TypeSpecializer extends MiniPhaseTransform {
13
13
14
- override def phaseName = " Type Specializer "
14
+ override def phaseName = " specialize "
15
15
16
16
final val maxTparamsToSpecialize = 2
17
-
17
+
18
+ private val specializationRequests : mutable.HashMap [Symbol , List [List [Type ]]] = mutable.HashMap .empty
19
+
20
+ def registerSpecializationRequest (method : Symbol )(arguments : List [Type ])(implicit ctx : Context ) = {
21
+ assert(ctx.phaseId <= this .period.phaseId)
22
+ val prev = specializationRequests.getOrElse(method, List .empty)
23
+ specializationRequests.put(method, arguments :: prev)
24
+ }
25
+
18
26
private final def specialisedTypes (implicit ctx : Context ) =
19
27
Map (ctx.definitions.ByteType -> " $mcB$sp" ,
20
28
ctx.definitions.BooleanType -> " $mcZ$sp" ,
@@ -25,24 +33,30 @@ class TypeSpecializer extends MiniPhaseTransform {
25
33
ctx.definitions.DoubleType -> " $mcD$sp" ,
26
34
ctx.definitions.CharType -> " $mcC$sp" ,
27
35
ctx.definitions.UnitType -> " $mcV$sp" )
28
-
36
+
37
+ def shouldSpecializeForAll (sym : Symbols .Symbol )(implicit ctx : Context ): Boolean = {
38
+ // either -Yspecialize:all is given, or sym has @specialize annotation
39
+ sym.denot.hasAnnotation(ctx.definitions.specializedAnnot) || (ctx.settings.Yspecialize .value == " all" )
40
+ }
41
+
42
+ def shouldSpecializeForSome (sym : Symbol )(implicit ctx : Context ): List [List [Type ]] = {
43
+ specializationRequests.getOrElse(sym, Nil )
44
+ }
45
+
46
+
47
+
48
+
29
49
override def transformDefDef (tree : DefDef )(implicit ctx : Context , info : TransformerInfo ): Tree = {
30
50
31
- def rewireType (tpe : Type ) = tpe match {
32
- case tpe : TermRef => tpe.widen
33
- case _ => tpe
34
- }
35
-
36
51
tree.tpe.widen match {
37
52
38
53
case poly : PolyType if ! (tree.symbol.isPrimaryConstructor
39
- || (tree.symbol is Flags .Label )
40
- || (tree.tparams.length > maxTparamsToSpecialize)) => {
54
+ || (tree.symbol is Flags .Label )) => {
41
55
val origTParams = tree.tparams.map(_.symbol)
42
56
val origVParams = tree.vparamss.flatten.map(_.symbol)
43
57
println(s " specializing ${tree.symbol} for Tparams: ${origTParams.length}" )
44
58
45
- def specialize (instatiations : collection.mutable. ListBuffer [Type ], names : collection.mutable. ArrayBuffer [String ]): Tree = {
59
+ def specialize (instatiations : List [Type ], names : List [String ]): Tree = {
46
60
47
61
val newSym = ctx.newSymbol(tree.symbol.owner, (tree.name + names.mkString).toTermName, tree.symbol.flags | Flags .Synthetic , poly.instantiate(instatiations.toList))
48
62
polyDefDef(newSym, { tparams => vparams => {
@@ -58,29 +72,25 @@ class TypeSpecializer extends MiniPhaseTransform {
58
72
})
59
73
}
60
74
61
- def generateSpecializations (remainingTParams : List [TypeDef ])
62
- (instatiated : ArrayBuffer [ TypeDef ], instatiations : ListBuffer [Type ],
63
- names : ArrayBuffer [String ]): Iterable [Tree ] = {
75
+ def generateSpecializations (remainingTParams : List [TypeDef ], remainingBounds : List [ TypeBounds ] )
76
+ (instatiations : List [Type ],
77
+ names : List [String ]): Iterable [Tree ] = {
64
78
if (remainingTParams.nonEmpty) {
65
79
val typeToSpecialize = remainingTParams.head
66
- specialisedTypes.flatMap { tpnme =>
80
+ val bounds = remainingBounds.head
81
+ specialisedTypes.filter{ tpnme =>
82
+ bounds.contains(tpnme._1)
83
+ }.flatMap { tpnme =>
67
84
val tpe = tpnme._1
68
85
val nme = tpnme._2
69
- instatiated.+= (typeToSpecialize)
70
- instatiations.+= (tpe)
71
- names.+= (nme)
72
- val r = generateSpecializations(remainingTParams.tail)(instatiated, instatiations, names)
73
- instatiated.drop(1 )
74
- instatiations.drop(1 )
75
- names.drop(1 )
76
- r
86
+ generateSpecializations(remainingTParams.tail, remainingBounds.tail)(tpe :: instatiations, nme :: names)
77
87
}
78
88
} else
79
- List (specialize(instatiations, names))
89
+ List (specialize(instatiations.reverse , names.reverse ))
80
90
}
81
91
82
92
83
- Thicket (tree :: generateSpecializations(tree.tparams)( ArrayBuffer .empty, ListBuffer .empty, ArrayBuffer .empty).toList)
93
+ Thicket (tree :: generateSpecializations(tree.tparams, poly.paramBounds)( List .empty, List .empty).toList)
84
94
}
85
95
case _ => tree
86
96
}
0 commit comments