Skip to content

Commit 292e56f

Browse files
authored
Eliminate class hierarchy in GadtConstraint (#16194)
2 parents 602ed35 + e6a9ffd commit 292e56f

File tree

5 files changed

+61
-124
lines changed

5 files changed

+61
-124
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ object Contexts {
814814
.updated(notNullInfosLoc, Nil)
815815
.updated(compilationUnitLoc, NoCompilationUnit)
816816
searchHistory = new SearchRoot
817-
gadt = EmptyGadtConstraint
817+
gadt = GadtConstraint.empty
818818
}
819819

820820
@sharable object NoContext extends Context((null: ContextBase | Null).uncheckedNN) {

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 53 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -10,77 +10,25 @@ import util.{SimpleIdentitySet, SimpleIdentityMap}
1010
import collection.mutable
1111
import printing._
1212

13-
import scala.annotation.internal.sharable
13+
object GadtConstraint:
14+
def apply(): GadtConstraint = empty
15+
def empty: GadtConstraint =
16+
new ProperGadtConstraint(OrderingConstraint.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, false)
1417

1518
/** Represents GADT constraints currently in scope */
16-
sealed abstract class GadtConstraint extends Showable {
17-
/** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */
18-
def bounds(sym: Symbol)(using Context): TypeBounds | Null
19-
20-
/** Full bounds of `sym`, including TypeRefs to other lower/upper symbols.
21-
*
22-
* @note this performs subtype checks between ordered symbols.
23-
* Using this in isSubType can lead to infinite recursion. Consider `bounds` instead.
24-
*/
25-
def fullBounds(sym: Symbol)(using Context): TypeBounds | Null
26-
27-
/** Is `sym1` ordered to be less than `sym2`? */
28-
def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean
29-
30-
/** Add symbols to constraint, correctly handling inter-dependencies.
31-
*
32-
* @see [[ConstraintHandling.addToConstraint]]
33-
*/
34-
def addToConstraint(syms: List[Symbol])(using Context): Boolean
35-
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)
36-
37-
/** Further constrain a symbol already present in the constraint. */
38-
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean
39-
40-
/** Is the symbol registered in the constraint?
41-
*
42-
* @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]].
43-
*/
44-
def contains(sym: Symbol)(using Context): Boolean
45-
46-
/** GADT constraint narrows bounds of at least one variable */
47-
def isNarrowing: Boolean
48-
49-
/** See [[ConstraintHandling.approximation]] */
50-
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type
51-
52-
def symbols: List[Symbol]
53-
54-
def fresh: GadtConstraint
55-
56-
/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
57-
def restore(other: GadtConstraint): Unit
58-
59-
/** Provides more information than toText, by showing the underlying Constraint details. */
60-
def debugBoundsDescription(using Context): String
61-
}
62-
63-
final class ProperGadtConstraint private(
19+
sealed trait GadtConstraint (
6420
private var myConstraint: Constraint,
6521
private var mapping: SimpleIdentityMap[Symbol, TypeVar],
6622
private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
6723
private var wasConstrained: Boolean
68-
) extends GadtConstraint with ConstraintHandling {
69-
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
24+
) extends Showable {
25+
this: ConstraintHandling =>
7026

71-
def this() = this(
72-
myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty),
73-
mapping = SimpleIdentityMap.empty,
74-
reverseMapping = SimpleIdentityMap.empty,
75-
wasConstrained = false
76-
)
27+
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
7728

7829
/** Exposes ConstraintHandling.subsumes */
7930
def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = {
80-
def extractConstraint(g: GadtConstraint) = g match {
81-
case s: ProperGadtConstraint => s.constraint
82-
case EmptyGadtConstraint => OrderingConstraint.empty
83-
}
31+
def extractConstraint(g: GadtConstraint) = g.constraint
8432
subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
8533
}
8634

@@ -89,7 +37,12 @@ final class ProperGadtConstraint private(
8937
// the case where they're valid, so no approximating is needed.
9038
rawBound
9139

92-
override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
40+
/** Add symbols to constraint, correctly handling inter-dependencies.
41+
*
42+
* @see [[ConstraintHandling.addToConstraint]]
43+
*/
44+
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)
45+
def addToConstraint(params: List[Symbol])(using Context): Boolean = {
9346
import NameKinds.DepParamName
9447

9548
val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
@@ -138,7 +91,8 @@ final class ProperGadtConstraint private(
13891
.showing(i"added to constraint: [$poly1] $params%, % gadt = $this", gadts)
13992
}
14093

141-
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
94+
/** Further constrain a symbol already present in the constraint. */
95+
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
14296
@annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match {
14397
case tv: TypeVar =>
14498
val inst = constraint.instType(tv)
@@ -179,10 +133,16 @@ final class ProperGadtConstraint private(
179133
result
180134
}
181135

182-
override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean =
136+
/** Is `sym1` ordered to be less than `sym2`? */
137+
def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean =
183138
constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin)
184139

185-
override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null =
140+
/** Full bounds of `sym`, including TypeRefs to other lower/upper symbols.
141+
*
142+
* @note this performs subtype checks between ordered symbols.
143+
* Using this in isSubType can lead to infinite recursion. Consider `bounds` instead.
144+
*/
145+
def fullBounds(sym: Symbol)(using Context): TypeBounds | Null =
186146
mapping(sym) match {
187147
case null => null
188148
// TODO: Improve flow typing so that ascription becomes redundant
@@ -191,7 +151,8 @@ final class ProperGadtConstraint private(
191151
// .ensuring(containsNoInternalTypes(_))
192152
}
193153

194-
override def bounds(sym: Symbol)(using Context): TypeBounds | Null =
154+
/** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */
155+
def bounds(sym: Symbol)(using Context): TypeBounds | Null =
195156
mapping(sym) match {
196157
case null => null
197158
// TODO: Improve flow typing so that ascription becomes redundant
@@ -202,11 +163,17 @@ final class ProperGadtConstraint private(
202163
//.ensuring(containsNoInternalTypes(_))
203164
}
204165

205-
override def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null
166+
/** Is the symbol registered in the constraint?
167+
*
168+
* @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]].
169+
*/
170+
def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null
206171

172+
/** GADT constraint narrows bounds of at least one variable */
207173
def isNarrowing: Boolean = wasConstrained
208174

209-
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
175+
/** See [[ConstraintHandling.approximation]] */
176+
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type = {
210177
val res =
211178
approximation(tvarOrError(sym).origin, fromBelow, maxLevel) match
212179
case tpr: TypeParamRef =>
@@ -220,23 +187,16 @@ final class ProperGadtConstraint private(
220187
res
221188
}
222189

223-
override def symbols: List[Symbol] = mapping.keys
224-
225-
override def fresh: GadtConstraint = new ProperGadtConstraint(
226-
myConstraint,
227-
mapping,
228-
reverseMapping,
229-
wasConstrained
230-
)
231-
232-
def restore(other: GadtConstraint): Unit = other match {
233-
case other: ProperGadtConstraint =>
234-
this.myConstraint = other.myConstraint
235-
this.mapping = other.mapping
236-
this.reverseMapping = other.reverseMapping
237-
this.wasConstrained = other.wasConstrained
238-
case _ => ;
239-
}
190+
def symbols: List[Symbol] = mapping.keys
191+
192+
def fresh: GadtConstraint = new ProperGadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained)
193+
194+
/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
195+
def restore(other: GadtConstraint): Unit =
196+
this.myConstraint = other.myConstraint
197+
this.mapping = other.mapping
198+
this.reverseMapping = other.reverseMapping
199+
this.wasConstrained = other.wasConstrained
240200

241201
// ---- Protected/internal -----------------------------------------------
242202

@@ -294,30 +254,13 @@ final class ProperGadtConstraint private(
294254

295255
override def toText(printer: Printer): Texts.Text = printer.toText(this)
296256

297-
override def debugBoundsDescription(using Context): String = i"$this\n$constraint"
257+
/** Provides more information than toText, by showing the underlying Constraint details. */
258+
def debugBoundsDescription(using Context): String = i"$this\n$constraint"
298259
}
299260

300-
@sharable object EmptyGadtConstraint extends GadtConstraint {
301-
override def bounds(sym: Symbol)(using Context): TypeBounds | Null = null
302-
override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = null
303-
304-
override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess")
305-
306-
override def isNarrowing: Boolean = false
307-
308-
override def contains(sym: Symbol)(using Context) = false
309-
310-
override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
311-
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")
312-
313-
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
314-
315-
override def symbols: List[Symbol] = Nil
316-
317-
override def fresh = new ProperGadtConstraint
318-
override def restore(other: GadtConstraint): Unit =
319-
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")
320-
321-
override def toText(printer: Printer): Texts.Text = printer.toText(this)
322-
override def debugBoundsDescription(using Context): String = i"$this"
323-
}
261+
private class ProperGadtConstraint (
262+
myConstraint: Constraint,
263+
mapping: SimpleIdentityMap[Symbol, TypeVar],
264+
reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
265+
wasConstrained: Boolean,
266+
) extends ConstraintHandling with GadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained)

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1848,11 +1848,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
18481848
val preGadt = ctx.gadt.fresh
18491849

18501850
def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean =
1851-
subsumes(left, right, preConstraint) && preGadt.match
1852-
case preGadt: ProperGadtConstraint =>
1853-
preGadt.subsumes(leftGadt, rightGadt, preGadt)
1854-
case _ =>
1855-
true
1851+
subsumes(left, right, preConstraint) && preGadt.subsumes(leftGadt, rightGadt, preGadt)
18561852

18571853
if op1 then
18581854
val op1Constraint = constraint

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -694,13 +694,11 @@ class PlainPrinter(_ctx: Context) extends Printer {
694694
finally
695695
ctx.typerState.constraint = savedConstraint
696696

697-
def toText(g: GadtConstraint): Text = g match
698-
case EmptyGadtConstraint => "EmptyGadtConstraint"
699-
case g: ProperGadtConstraint =>
700-
val deps = for sym <- g.symbols yield
701-
val bound = g.fullBounds(sym).nn
702-
(typeText(toText(sym.typeRef)) ~ toText(bound)).close
703-
("GadtConstraint(" ~ Text(deps, ", ") ~ ")").close
697+
def toText(g: GadtConstraint): Text =
698+
val deps = for sym <- g.symbols yield
699+
val bound = g.fullBounds(sym).nn
700+
(typeText(toText(sym.typeRef)) ~ toText(bound)).close
701+
("GadtConstraint(" ~ Text(deps, ", ") ~ ")").close
704702

705703
def plain: PlainPrinter = this
706704

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3777,7 +3777,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
37773777
adaptToSubType(wtp)
37783778
case CompareResult.OKwithGADTUsed
37793779
if pt.isValueType
3780-
&& !inContext(ctx.fresh.setGadt(EmptyGadtConstraint)) {
3780+
&& !inContext(ctx.fresh.setGadt(GadtConstraint.empty)) {
37813781
val res = (tree.tpe.widenExpr frozen_<:< pt)
37823782
if res then
37833783
// we overshot; a cast is not needed, after all.

0 commit comments

Comments
 (0)