@@ -27,7 +27,11 @@ sealed abstract class GadtConstraint extends Showable {
27
27
/** Is `sym1` ordered to be less than `sym2`? */
28
28
def isLess (sym1 : Symbol , sym2 : Symbol )(implicit ctx : Context ): Boolean
29
29
30
- def addEmptyBounds (sym : Symbol )(implicit ctx : Context ): Unit
30
+ /** Add symbols to constraint, preserving the underlying bounds and handling inter-dependencies. */
31
+ def addToConstraint (syms : List [Symbol ])(implicit ctx : Context ): Boolean
32
+ def addToConstraint (sym : Symbol )(implicit ctx : Context ): Boolean = addToConstraint(sym :: Nil )
33
+
34
+ /** Further constrain a symbol already present in the constraint. */
31
35
def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean
32
36
33
37
/** Is the symbol registered in the constraint?
@@ -72,7 +76,54 @@ final class ProperGadtConstraint private(
72
76
subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
73
77
}
74
78
75
- override def addEmptyBounds (sym : Symbol )(implicit ctx : Context ): Unit = tvar(sym)
79
+ override def addToConstraint (params : List [Symbol ])(implicit ctx : Context ): Boolean = {
80
+ import NameKinds .DepParamName
81
+
82
+ val poly1 = PolyType (params.map { sym => DepParamName .fresh(sym.name.toTypeName) })(
83
+ pt => params.map { param =>
84
+ // replace the symbols in bound type `tp` which are in dependent positions
85
+ // with their internal TypeParamRefs
86
+ def substDependentSyms (tp : Type , isUpper : Boolean )(implicit ctx : Context ): Type = {
87
+ def loop (tp : Type ) = substDependentSyms(tp, isUpper)
88
+ tp match {
89
+ case tp @ AndType (tp1, tp2) if ! isUpper =>
90
+ tp.derivedAndType(loop(tp1), loop(tp2))
91
+ case tp @ OrType (tp1, tp2) if isUpper =>
92
+ tp.derivedOrType(loop(tp1), loop(tp2))
93
+ case tp : NamedType =>
94
+ params.indexOf(tp.symbol) match {
95
+ case - 1 =>
96
+ mapping(tp.symbol) match {
97
+ case tv : TypeVar => tv.origin
98
+ case null => tp
99
+ }
100
+ case i => pt.paramRefs(i)
101
+ }
102
+ case tp => tp
103
+ }
104
+ }
105
+
106
+ val tb = param.info.bounds
107
+ tb.derivedTypeBounds(
108
+ lo = substDependentSyms(tb.lo, isUpper = false ),
109
+ hi = substDependentSyms(tb.hi, isUpper = true )
110
+ )
111
+ },
112
+ pt => defn.AnyType
113
+ )
114
+
115
+ val tvars = (params, poly1.paramRefs).zipped.map { (sym, paramRef) =>
116
+ val tv = new TypeVar (paramRef, creatorState = null )
117
+ mapping = mapping.updated(sym, tv)
118
+ reverseMapping = reverseMapping.updated(tv.origin, sym)
119
+ tv
120
+ }
121
+
122
+ // the replaced symbols will be stripped off the bounds by `addToConstraint` and used as orderings
123
+ addToConstraint(poly1, tvars).reporting({ _ =>
124
+ i " added to constraint: $params%, % \n $debugBoundsDescription"
125
+ }, gadts)
126
+ }
76
127
77
128
override def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean = {
78
129
@ annotation.tailrec def stripInternalTypeVar (tp : Type ): Type = tp match {
@@ -82,16 +133,17 @@ final class ProperGadtConstraint private(
82
133
case _ => tp
83
134
}
84
135
85
- val symTvar : TypeVar = stripInternalTypeVar(tvar (sym)) match {
136
+ val symTvar : TypeVar = stripInternalTypeVar(tvarOrError (sym)) match {
86
137
case tv : TypeVar => tv
87
138
case inst =>
88
139
gadts.println(i " instantiated: $sym -> $inst" )
89
140
return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst)
90
141
}
91
142
92
143
val internalizedBound = bound match {
93
- case nt : NamedType if contains(nt.symbol) =>
94
- stripInternalTypeVar(tvar(nt.symbol))
144
+ case nt : NamedType =>
145
+ val ntTvar = mapping(nt.symbol)
146
+ if (ntTvar ne null ) stripInternalTypeVar(ntTvar) else bound
95
147
case _ => bound
96
148
}
97
149
(
@@ -119,20 +171,22 @@ final class ProperGadtConstraint private(
119
171
if (isUpper) addUpperBound(symTvar.origin, bound1)
120
172
else addLowerBound(symTvar.origin, bound1)
121
173
}
122
- ).reporting({ res =>
174
+ ).reporting({ res =>
123
175
val descr = if (isUpper) " upper" else " lower"
124
176
val op = if (isUpper) " <:" else " >:"
125
- i " adding $descr bound $sym $op $bound = $res\t ( $symTvar $op $internalizedBound ) "
177
+ i " adding $descr bound $sym $op $bound = $res"
126
178
}, gadts)
127
179
}
128
180
129
181
override def isLess (sym1 : Symbol , sym2 : Symbol )(implicit ctx : Context ): Boolean =
130
- constraint.isLess(tvar (sym1).origin, tvar (sym2).origin)
182
+ constraint.isLess(tvarOrError (sym1).origin, tvarOrError (sym2).origin)
131
183
132
184
override def fullBounds (sym : Symbol )(implicit ctx : Context ): TypeBounds =
133
185
mapping(sym) match {
134
186
case null => null
135
- case tv => fullBounds(tv.origin)
187
+ case tv =>
188
+ fullBounds(tv.origin)
189
+ .ensuring(containsNoInternalTypes(_))
136
190
}
137
191
138
192
override def bounds (sym : Symbol )(implicit ctx : Context ): TypeBounds = {
@@ -145,14 +199,16 @@ final class ProperGadtConstraint private(
145
199
TypeAlias (reverseMapping(tpr).typeRef)
146
200
case tb => tb
147
201
}
148
- retrieveBounds// .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
202
+ retrieveBounds
203
+ // .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
204
+ .ensuring(containsNoInternalTypes(_))
149
205
}
150
206
}
151
207
152
208
override def contains (sym : Symbol )(implicit ctx : Context ): Boolean = mapping(sym) ne null
153
209
154
210
override def approximation (sym : Symbol , fromBelow : Boolean )(implicit ctx : Context ): Type = {
155
- val res = approximation(tvar (sym).origin, fromBelow = fromBelow)
211
+ val res = approximation(tvarOrError (sym).origin, fromBelow = fromBelow)
156
212
gadts.println(i " approximating $sym ~> $res" )
157
213
res
158
214
}
@@ -207,36 +263,21 @@ final class ProperGadtConstraint private(
207
263
case null => param
208
264
}
209
265
210
- private [this ] def tvar (sym : Symbol )(implicit ctx : Context ): TypeVar = {
211
- mapping(sym) match {
212
- case tv : TypeVar =>
213
- tv
214
- case null =>
215
- val res = {
216
- import NameKinds .DepParamName
217
- // For symbols standing for HK types, we need to preserve the kind information
218
- // (see also usage of adaptHKvariances above)
219
- // Ideally we'd always preserve the bounds,
220
- // but first we need an equivalent of ConstraintHandling#addConstraint
221
- // TODO: implement the above
222
- val initialBounds = sym.info match {
223
- case tb @ TypeBounds (_, hi) if hi.isLambdaSub => tb
224
- case _ => TypeBounds .empty
225
- }
226
- // avoid registering the TypeVar with TyperState / TyperState#constraint
227
- // - we don't want TyperState instantiating these TypeVars
228
- // - we don't want TypeComparer constraining these TypeVars
229
- val poly = PolyType (DepParamName .fresh(sym.name.toTypeName) :: Nil )(
230
- pt => initialBounds :: Nil ,
231
- pt => defn.AnyType )
232
- new TypeVar (poly.paramRefs.head, creatorState = null )
233
- }
234
- gadts.println(i " GADTMap: created tvar $sym -> $res" )
235
- constraint = constraint.add(res.origin.binder, res :: Nil )
236
- mapping = mapping.updated(sym, res)
237
- reverseMapping = reverseMapping.updated(res.origin, sym)
238
- res
239
- }
266
+ private [this ] def tvarOrError (sym : Symbol )(implicit ctx : Context ): TypeVar =
267
+ mapping(sym).ensuring(_ ne null , i " not a constrainable symbol: $sym" )
268
+
269
+ private [this ] def containsNoInternalTypes (
270
+ tp : Type ,
271
+ acc : TypeAccumulator [Boolean ] = null
272
+ )(implicit ctx : Context ): Boolean = tp match {
273
+ case tpr : TypeParamRef => ! reverseMapping.contains(tpr)
274
+ case tv : TypeVar => ! reverseMapping.contains(tv.origin)
275
+ case tp =>
276
+ (if (acc ne null ) acc else new ContainsNoInternalTypesAccumulator ()).foldOver(true , tp)
277
+ }
278
+
279
+ private [this ] class ContainsNoInternalTypesAccumulator (implicit ctx : Context ) extends TypeAccumulator [Boolean ] {
280
+ override def apply (x : Boolean , tp : Type ): Boolean = x && containsNoInternalTypes(tp)
240
281
}
241
282
242
283
// ---- Debug ------------------------------------------------------------
@@ -266,7 +307,7 @@ final class ProperGadtConstraint private(
266
307
267
308
override def contains (sym : Symbol )(implicit ctx : Context ) = false
268
309
269
- override def addEmptyBounds ( sym : Symbol )(implicit ctx : Context ): Unit = unsupported(" EmptyGadtConstraint.addEmptyBounds " )
310
+ override def addToConstraint ( params : List [ Symbol ] )(implicit ctx : Context ): Boolean = unsupported(" EmptyGadtConstraint.addToConstraint " )
270
311
override def addBound (sym : Symbol , bound : Type , isUpper : Boolean )(implicit ctx : Context ): Boolean = unsupported(" EmptyGadtConstraint.addBound" )
271
312
272
313
override def approximation (sym : Symbol , fromBelow : Boolean )(implicit ctx : Context ): Type = unsupported(" EmptyGadtConstraint.approximation" )
0 commit comments