Skip to content

Commit 177bd36

Browse files
committed
Allow adding multiple symbols to GadtConstraint simultaneously
The added symbols can have inter-dependencies in their bounds.
1 parent 1171db4 commit 177bd36

File tree

7 files changed

+117
-93
lines changed

7 files changed

+117
-93
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 83 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ sealed abstract class GadtConstraint extends Showable {
2727
/** Is `sym1` ordered to be less than `sym2`? */
2828
def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean
2929

30-
def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit
30+
/** Add symbols to constraint, preserving the underlying bounds and handling inter-dependencies. */
31+
def addToConstraint(syms: List[Symbol])(implicit ctx: Context): Boolean
32+
def addToConstraint(sym: Symbol)(implicit ctx: Context): Boolean = addToConstraint(sym :: Nil)
33+
34+
/** Further constrain a symbol already present in the constraint. */
3135
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean
3236

3337
/** Is the symbol registered in the constraint?
@@ -72,7 +76,54 @@ final class ProperGadtConstraint private(
7276
subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
7377
}
7478

75-
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym)
79+
override def addToConstraint(params: List[Symbol])(implicit ctx: Context): Boolean = {
80+
import NameKinds.DepParamName
81+
82+
val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
83+
pt => params.map { param =>
84+
// replace the symbols in bound type `tp` which are in dependent positions
85+
// with their internal TypeParamRefs
86+
def substDependentSyms(tp: Type, isUpper: Boolean)(implicit ctx: Context): Type = {
87+
def loop(tp: Type) = substDependentSyms(tp, isUpper)
88+
tp match {
89+
case tp @ AndType(tp1, tp2) if !isUpper =>
90+
tp.derivedAndType(loop(tp1), loop(tp2))
91+
case tp @ OrType(tp1, tp2) if isUpper =>
92+
tp.derivedOrType(loop(tp1), loop(tp2))
93+
case tp: NamedType =>
94+
params.indexOf(tp.symbol) match {
95+
case -1 =>
96+
mapping(tp.symbol) match {
97+
case tv: TypeVar => tv.origin
98+
case null => tp
99+
}
100+
case i => pt.paramRefs(i)
101+
}
102+
case tp => tp
103+
}
104+
}
105+
106+
val tb = param.info.bounds
107+
tb.derivedTypeBounds(
108+
lo = substDependentSyms(tb.lo, isUpper = false),
109+
hi = substDependentSyms(tb.hi, isUpper = true)
110+
)
111+
},
112+
pt => defn.AnyType
113+
)
114+
115+
val tvars = (params, poly1.paramRefs).zipped.map { (sym, paramRef) =>
116+
val tv = new TypeVar(paramRef, creatorState = null)
117+
mapping = mapping.updated(sym, tv)
118+
reverseMapping = reverseMapping.updated(tv.origin, sym)
119+
tv
120+
}
121+
122+
// the replaced symbols will be stripped off the bounds by `addToConstraint` and used as orderings
123+
addToConstraint(poly1, tvars).reporting({ _ =>
124+
i"added to constraint: $params%, %\n$debugBoundsDescription"
125+
}, gadts)
126+
}
76127

77128
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = {
78129
@annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match {
@@ -82,16 +133,17 @@ final class ProperGadtConstraint private(
82133
case _ => tp
83134
}
84135

85-
val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match {
136+
val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match {
86137
case tv: TypeVar => tv
87138
case inst =>
88139
gadts.println(i"instantiated: $sym -> $inst")
89140
return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst)
90141
}
91142

92143
val internalizedBound = bound match {
93-
case nt: NamedType if contains(nt.symbol) =>
94-
stripInternalTypeVar(tvar(nt.symbol))
144+
case nt: NamedType =>
145+
val ntTvar = mapping(nt.symbol)
146+
if (ntTvar ne null) stripInternalTypeVar(ntTvar) else bound
95147
case _ => bound
96148
}
97149
(
@@ -119,20 +171,22 @@ final class ProperGadtConstraint private(
119171
if (isUpper) addUpperBound(symTvar.origin, bound1)
120172
else addLowerBound(symTvar.origin, bound1)
121173
}
122-
).reporting({ res =>
174+
).reporting({ res =>
123175
val descr = if (isUpper) "upper" else "lower"
124176
val op = if (isUpper) "<:" else ">:"
125-
i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )"
177+
i"adding $descr bound $sym $op $bound = $res"
126178
}, gadts)
127179
}
128180

129181
override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean =
130-
constraint.isLess(tvar(sym1).origin, tvar(sym2).origin)
182+
constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin)
131183

132184
override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds =
133185
mapping(sym) match {
134186
case null => null
135-
case tv => fullBounds(tv.origin)
187+
case tv =>
188+
fullBounds(tv.origin)
189+
.ensuring(containsNoInternalTypes(_))
136190
}
137191

138192
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = {
@@ -145,14 +199,16 @@ final class ProperGadtConstraint private(
145199
TypeAlias(reverseMapping(tpr).typeRef)
146200
case tb => tb
147201
}
148-
retrieveBounds//.reporting({ res => i"gadt bounds $sym: $res" }, gadts)
202+
retrieveBounds
203+
//.reporting({ res => i"gadt bounds $sym: $res" }, gadts)
204+
.ensuring(containsNoInternalTypes(_))
149205
}
150206
}
151207

152208
override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null
153209

154210
override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = {
155-
val res = approximation(tvar(sym).origin, fromBelow = fromBelow)
211+
val res = approximation(tvarOrError(sym).origin, fromBelow = fromBelow)
156212
gadts.println(i"approximating $sym ~> $res")
157213
res
158214
}
@@ -207,36 +263,21 @@ final class ProperGadtConstraint private(
207263
case null => param
208264
}
209265

210-
private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = {
211-
mapping(sym) match {
212-
case tv: TypeVar =>
213-
tv
214-
case null =>
215-
val res = {
216-
import NameKinds.DepParamName
217-
// For symbols standing for HK types, we need to preserve the kind information
218-
// (see also usage of adaptHKvariances above)
219-
// Ideally we'd always preserve the bounds,
220-
// but first we need an equivalent of ConstraintHandling#addConstraint
221-
// TODO: implement the above
222-
val initialBounds = sym.info match {
223-
case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb
224-
case _ => TypeBounds.empty
225-
}
226-
// avoid registering the TypeVar with TyperState / TyperState#constraint
227-
// - we don't want TyperState instantiating these TypeVars
228-
// - we don't want TypeComparer constraining these TypeVars
229-
val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)(
230-
pt => initialBounds :: Nil,
231-
pt => defn.AnyType)
232-
new TypeVar(poly.paramRefs.head, creatorState = null)
233-
}
234-
gadts.println(i"GADTMap: created tvar $sym -> $res")
235-
constraint = constraint.add(res.origin.binder, res :: Nil)
236-
mapping = mapping.updated(sym, res)
237-
reverseMapping = reverseMapping.updated(res.origin, sym)
238-
res
239-
}
266+
private[this] def tvarOrError(sym: Symbol)(implicit ctx: Context): TypeVar =
267+
mapping(sym).ensuring(_ ne null, i"not a constrainable symbol: $sym")
268+
269+
private[this] def containsNoInternalTypes(
270+
tp: Type,
271+
acc: TypeAccumulator[Boolean] = null
272+
)(implicit ctx: Context): Boolean = tp match {
273+
case tpr: TypeParamRef => !reverseMapping.contains(tpr)
274+
case tv: TypeVar => !reverseMapping.contains(tv.origin)
275+
case tp =>
276+
(if (acc ne null) acc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp)
277+
}
278+
279+
private[this] class ContainsNoInternalTypesAccumulator(implicit ctx: Context) extends TypeAccumulator[Boolean] {
280+
override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp)
240281
}
241282

242283
// ---- Debug ------------------------------------------------------------
@@ -266,7 +307,7 @@ final class ProperGadtConstraint private(
266307

267308
override def contains(sym: Symbol)(implicit ctx: Context) = false
268309

269-
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGadtConstraint.addEmptyBounds")
310+
override def addToConstraint(params: List[Symbol])(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
270311
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addBound")
271312

272313
override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGadtConstraint.approximation")

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,10 @@ trait Symbols { this: Context =>
209209
modFlags | PackageCreationFlags, clsFlags | PackageCreationFlags,
210210
Nil, decls)
211211

212-
/** Define a new symbol associated with a Bind or pattern wildcard and
213-
* make it gadt narrowable.
214-
*/
215-
def newPatternBoundSymbol(name: Name, info: Type, span: Span): Symbol = {
212+
/** Define a new symbol associated with a Bind or pattern wildcard and, by default, make it gadt narrowable. */
213+
def newPatternBoundSymbol(name: Name, info: Type, span: Span, addToGadt: Boolean = true): Symbol = {
216214
val sym = newSymbol(owner, name, Case, info, coord = span)
217-
if (name.isTypeName) {
218-
val bounds = info.bounds
219-
gadt.addBound(sym, bounds.lo, isUpper = false)
220-
gadt.addBound(sym, bounds.hi, isUpper = true)
221-
}
215+
if (addToGadt && name.isTypeName) gadt.addToConstraint(sym)
222216
sym
223217
}
224218

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object.
387387
val bound1 = massage(bound)
388388
if (bound1 ne bound) {
389389
if (checkCtx eq ctx) checkCtx = ctx.fresh.setFreshGADTBounds
390-
if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addEmptyBounds(sym)
390+
if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addToConstraint(sym)
391391
checkCtx.gadt.addBound(sym, bound1, fromBelow)
392392
typr.println("install GADT bound $bound1 for when checking F-bounded $sym")
393393
}

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,13 +284,17 @@ object Inferencing {
284284
if (bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) || fromScala2x)
285285
tvar.instantiate(fromBelow = false)
286286
else {
287-
val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span)
287+
// since the symbols we're creating may have inter-dependencies in their bounds,
288+
// we add them to the GADT constraint later, simultaneously
289+
val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span, addToGadt = false)
288290
tvar.instantiateWith(wildCard.typeRef)
289291
patternBound += wildCard
290292
}
291293
}
292294
}
293-
patternBound.toList
295+
val res = patternBound.toList
296+
if (res.nonEmpty) ctx.gadt.addToConstraint(res)
297+
res
294298
}
295299

296300
type VarianceMap = SimpleIdentityMap[TypeVar, Integer]

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -821,11 +821,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
821821
}
822822

823823
def registerAsGadtSyms(typeBinds: TypeBindsMap)(implicit ctx: Context): Unit =
824-
typeBinds.foreachBinding { case (sym, _) =>
825-
val TypeBounds(lo, hi) = sym.info.bounds
826-
ctx.gadt.addBound(sym, lo, isUpper = false)
827-
ctx.gadt.addBound(sym, hi, isUpper = true)
828-
}
824+
if (typeBinds.size > 0) ctx.gadt.addToConstraint(typeBinds.keys)
829825

830826
pat match {
831827
case Typed(pat1, tpt) =>

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,12 +1338,11 @@ class Namer { typer: Typer =>
13381338
var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType)
13391339
if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody)
13401340
if (typeParams.nonEmpty) {
1341+
// we'll be typing an expression from a polymorphic definition's body,
1342+
// so we must allow constraining its type parameters
1343+
// compare with typedDefDef, see tests/pos/gadt-inference.scala
13411344
rhsCtx.setFreshGADTBounds
1342-
typeParams.foreach { tdef =>
1343-
val TypeBounds(lo, hi) = tdef.info.bounds
1344-
rhsCtx.gadt.addBound(tdef, lo, isUpper = false)
1345-
rhsCtx.gadt.addBound(tdef, hi, isUpper = true)
1346-
}
1345+
rhsCtx.gadt.addToConstraint(typeParams)
13471346
}
13481347
def rhsType = typedAheadExpr(mdef.rhs, (inherited orElse rhsProto).widenExpr)(rhsCtx).tpe
13491348

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,38 +1508,28 @@ class Typer extends Namer
15081508
if (sym is ImplicitOrImplied) checkImplicitConversionDefOK(sym)
15091509
val tpt1 = checkSimpleKinded(typedType(tpt))
15101510

1511-
val rhsCtx: Context = {
1512-
var _result: FreshContext = null
1513-
def resultCtx(): FreshContext = {
1514-
if (_result == null) _result = ctx.fresh
1515-
_result
1516-
}
1517-
1518-
if (tparams1.nonEmpty) {
1519-
resultCtx().setFreshGADTBounds
1520-
if (!sym.isConstructor) {
1521-
// if we're _not_ in a constructor, allow constraining type parameters
1522-
tparams1.foreach { tdef =>
1523-
val tb @ TypeBounds(lo, hi) = tdef.symbol.info.bounds
1524-
resultCtx().gadt.addBound(tdef.symbol, lo, isUpper = false)
1525-
resultCtx().gadt.addBound(tdef.symbol, hi, isUpper = true)
1526-
}
1527-
} else if (!sym.isPrimaryConstructor) {
1528-
// otherwise, for secondary constructors we need a context that "knows"
1529-
// that their type parameters are aliases of the class type parameters.
1530-
// See pos/i941.scala
1531-
(tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) =>
1532-
val tr = tparam.typeRef
1533-
resultCtx().gadt.addBound(tdef.symbol, tr, isUpper = false)
1534-
resultCtx().gadt.addBound(tdef.symbol, tr, isUpper = true)
1535-
}
1511+
val rhsCtx = ctx.fresh
1512+
if (tparams1.nonEmpty) {
1513+
rhsCtx.setFreshGADTBounds
1514+
if (!sym.isConstructor) {
1515+
// we're typing a polymorphic definition's body,
1516+
// so we allow constraining all of its type parameters
1517+
// constructors are an exception as we don't allow constraining type params of classes
1518+
rhsCtx.gadt.addToConstraint(tparams1.map(_.symbol))
1519+
} else if (!sym.isPrimaryConstructor) {
1520+
// otherwise, for secondary constructors we need a context that "knows"
1521+
// that their type parameters are aliases of the class type parameters.
1522+
// See pos/i941.scala
1523+
rhsCtx.gadt.addToConstraint(tparams1.map(_.symbol))
1524+
(tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) =>
1525+
val tr = tparam.typeRef
1526+
rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false)
1527+
rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true)
15361528
}
15371529
}
1538-
1539-
if (sym.isInlineMethod) resultCtx().addMode(Mode.InlineableBody)
1540-
1541-
if (_result ne null) _result else ctx
15421530
}
1531+
1532+
if (sym.isInlineMethod) rhsCtx.addMode(Mode.InlineableBody)
15431533
val rhs1 = typedExpr(ddef.rhs, tpt1.tpe.widenExpr)(rhsCtx)
15441534

15451535
if (sym.isInlineMethod) {

0 commit comments

Comments
 (0)