Skip to content

Commit c72e062

Browse files
committed
Eliminate class hierarchy in GadtConstraint
1 parent c9c95d4 commit c72e062

File tree

5 files changed

+51
-115
lines changed

5 files changed

+51
-115
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ object Contexts {
814814
.updated(notNullInfosLoc, Nil)
815815
.updated(compilationUnitLoc, NoCompilationUnit)
816816
searchHistory = new SearchRoot
817-
gadt = EmptyGadtConstraint
817+
gadt = GadtConstraint.empty
818818
}
819819

820820
@sharable object NoContext extends Context((null: ContextBase | Null).uncheckedNN) {

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 43 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -12,60 +12,17 @@ import printing._
1212

1313
import scala.annotation.internal.sharable
1414

15-
/** Represents GADT constraints currently in scope */
16-
sealed abstract class GadtConstraint extends Showable {
17-
/** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */
18-
def bounds(sym: Symbol)(using Context): TypeBounds | Null
19-
20-
/** Full bounds of `sym`, including TypeRefs to other lower/upper symbols.
21-
*
22-
* @note this performs subtype checks between ordered symbols.
23-
* Using this in isSubType can lead to infinite recursion. Consider `bounds` instead.
24-
*/
25-
def fullBounds(sym: Symbol)(using Context): TypeBounds | Null
26-
27-
/** Is `sym1` ordered to be less than `sym2`? */
28-
def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean
29-
30-
/** Add symbols to constraint, correctly handling inter-dependencies.
31-
*
32-
* @see [[ConstraintHandling.addToConstraint]]
33-
*/
34-
def addToConstraint(syms: List[Symbol])(using Context): Boolean
35-
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)
36-
37-
/** Further constrain a symbol already present in the constraint. */
38-
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean
15+
object GadtConstraint:
16+
@sharable val empty =
17+
new GadtConstraint(OrderingConstraint.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, false)
3918

40-
/** Is the symbol registered in the constraint?
41-
*
42-
* @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]].
43-
*/
44-
def contains(sym: Symbol)(using Context): Boolean
45-
46-
/** GADT constraint narrows bounds of at least one variable */
47-
def isNarrowing: Boolean
48-
49-
/** See [[ConstraintHandling.approximation]] */
50-
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type
51-
52-
def symbols: List[Symbol]
53-
54-
def fresh: GadtConstraint
55-
56-
/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
57-
def restore(other: GadtConstraint): Unit
58-
59-
/** Provides more information than toText, by showing the underlying Constraint details. */
60-
def debugBoundsDescription(using Context): String
61-
}
62-
63-
final class ProperGadtConstraint private(
19+
/** Represents GADT constraints currently in scope */
20+
final class GadtConstraint private(
6421
private var myConstraint: Constraint,
6522
private var mapping: SimpleIdentityMap[Symbol, TypeVar],
6623
private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
6724
private var wasConstrained: Boolean
68-
) extends GadtConstraint with ConstraintHandling {
25+
) extends ConstraintHandling with Showable {
6926
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
7027

7128
def this() = this(
@@ -77,10 +34,7 @@ final class ProperGadtConstraint private(
7734

7835
/** Exposes ConstraintHandling.subsumes */
7936
def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = {
80-
def extractConstraint(g: GadtConstraint) = g match {
81-
case s: ProperGadtConstraint => s.constraint
82-
case EmptyGadtConstraint => OrderingConstraint.empty
83-
}
37+
def extractConstraint(g: GadtConstraint) = g.constraint
8438
subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre))
8539
}
8640

@@ -89,7 +43,12 @@ final class ProperGadtConstraint private(
8943
// the case where they're valid, so no approximating is needed.
9044
rawBound
9145

92-
override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
46+
/** Add symbols to constraint, correctly handling inter-dependencies.
47+
*
48+
* @see [[ConstraintHandling.addToConstraint]]
49+
*/
50+
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)
51+
def addToConstraint(params: List[Symbol])(using Context): Boolean = {
9352
import NameKinds.DepParamName
9453

9554
val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
@@ -138,7 +97,8 @@ final class ProperGadtConstraint private(
13897
.showing(i"added to constraint: [$poly1] $params%, % gadt = $this", gadts)
13998
}
14099

141-
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
100+
/** Further constrain a symbol already present in the constraint. */
101+
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
142102
@annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match {
143103
case tv: TypeVar =>
144104
val inst = constraint.instType(tv)
@@ -179,10 +139,16 @@ final class ProperGadtConstraint private(
179139
result
180140
}
181141

182-
override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean =
142+
/** Is `sym1` ordered to be less than `sym2`? */
143+
def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean =
183144
constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin)
184145

185-
override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null =
146+
/** Full bounds of `sym`, including TypeRefs to other lower/upper symbols.
147+
*
148+
* @note this performs subtype checks between ordered symbols.
149+
* Using this in isSubType can lead to infinite recursion. Consider `bounds` instead.
150+
*/
151+
def fullBounds(sym: Symbol)(using Context): TypeBounds | Null =
186152
mapping(sym) match {
187153
case null => null
188154
// TODO: Improve flow typing so that ascription becomes redundant
@@ -191,7 +157,8 @@ final class ProperGadtConstraint private(
191157
// .ensuring(containsNoInternalTypes(_))
192158
}
193159

194-
override def bounds(sym: Symbol)(using Context): TypeBounds | Null =
160+
/** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */
161+
def bounds(sym: Symbol)(using Context): TypeBounds | Null =
195162
mapping(sym) match {
196163
case null => null
197164
// TODO: Improve flow typing so that ascription becomes redundant
@@ -202,11 +169,17 @@ final class ProperGadtConstraint private(
202169
//.ensuring(containsNoInternalTypes(_))
203170
}
204171

205-
override def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null
172+
/** Is the symbol registered in the constraint?
173+
*
174+
* @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]].
175+
*/
176+
def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null
206177

178+
/** GADT constraint narrows bounds of at least one variable */
207179
def isNarrowing: Boolean = wasConstrained
208180

209-
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = {
181+
/** See [[ConstraintHandling.approximation]] */
182+
def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int = Int.MaxValue)(using Context): Type = {
210183
val res =
211184
approximation(tvarOrError(sym).origin, fromBelow, maxLevel) match
212185
case tpr: TypeParamRef =>
@@ -220,23 +193,16 @@ final class ProperGadtConstraint private(
220193
res
221194
}
222195

223-
override def symbols: List[Symbol] = mapping.keys
196+
def symbols: List[Symbol] = mapping.keys
224197

225-
override def fresh: GadtConstraint = new ProperGadtConstraint(
226-
myConstraint,
227-
mapping,
228-
reverseMapping,
229-
wasConstrained
230-
)
198+
def fresh: GadtConstraint = new GadtConstraint(myConstraint, mapping, reverseMapping, wasConstrained)
231199

232-
def restore(other: GadtConstraint): Unit = other match {
233-
case other: ProperGadtConstraint =>
234-
this.myConstraint = other.myConstraint
235-
this.mapping = other.mapping
236-
this.reverseMapping = other.reverseMapping
237-
this.wasConstrained = other.wasConstrained
238-
case _ => ;
239-
}
200+
/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
201+
def restore(other: GadtConstraint): Unit =
202+
this.myConstraint = other.myConstraint
203+
this.mapping = other.mapping
204+
this.reverseMapping = other.reverseMapping
205+
this.wasConstrained = other.wasConstrained
240206

241207
// ---- Protected/internal -----------------------------------------------
242208

@@ -294,30 +260,6 @@ final class ProperGadtConstraint private(
294260

295261
override def toText(printer: Printer): Texts.Text = printer.toText(this)
296262

297-
override def debugBoundsDescription(using Context): String = i"$this\n$constraint"
298-
}
299-
300-
@sharable object EmptyGadtConstraint extends GadtConstraint {
301-
override def bounds(sym: Symbol)(using Context): TypeBounds | Null = null
302-
override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = null
303-
304-
override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess")
305-
306-
override def isNarrowing: Boolean = false
307-
308-
override def contains(sym: Symbol)(using Context) = false
309-
310-
override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
311-
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")
312-
313-
override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
314-
315-
override def symbols: List[Symbol] = Nil
316-
317-
override def fresh = new ProperGadtConstraint
318-
override def restore(other: GadtConstraint): Unit =
319-
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")
320-
321-
override def toText(printer: Printer): Texts.Text = printer.toText(this)
322-
override def debugBoundsDescription(using Context): String = i"$this"
263+
/** Provides more information than toText, by showing the underlying Constraint details. */
264+
def debugBoundsDescription(using Context): String = i"$this\n$constraint"
323265
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,11 +1830,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
18301830
val preGadt = ctx.gadt.fresh
18311831

18321832
def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean =
1833-
subsumes(left, right, preConstraint) && preGadt.match
1834-
case preGadt: ProperGadtConstraint =>
1835-
preGadt.subsumes(leftGadt, rightGadt, preGadt)
1836-
case _ =>
1837-
true
1833+
subsumes(left, right, preConstraint) && preGadt.subsumes(leftGadt, rightGadt, preGadt)
18381834

18391835
if op1 then
18401836
val op1Constraint = constraint

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -693,13 +693,11 @@ class PlainPrinter(_ctx: Context) extends Printer {
693693
finally
694694
ctx.typerState.constraint = savedConstraint
695695

696-
def toText(g: GadtConstraint): Text = g match
697-
case EmptyGadtConstraint => "EmptyGadtConstraint"
698-
case g: ProperGadtConstraint =>
699-
val deps = for sym <- g.symbols yield
700-
val bound = g.fullBounds(sym).nn
701-
(typeText(toText(sym.typeRef)) ~ toText(bound)).close
702-
("GadtConstraint(" ~ Text(deps, ", ") ~ ")").close
696+
def toText(g: GadtConstraint): Text =
697+
val deps = for sym <- g.symbols yield
698+
val bound = g.fullBounds(sym).nn
699+
(typeText(toText(sym.typeRef)) ~ toText(bound)).close
700+
("GadtConstraint(" ~ Text(deps, ", ") ~ ")").close
703701

704702
def plain: PlainPrinter = this
705703

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3774,7 +3774,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
37743774
adaptToSubType(wtp)
37753775
case CompareResult.OKwithGADTUsed
37763776
if pt.isValueType
3777-
&& !inContext(ctx.fresh.setGadt(EmptyGadtConstraint)) {
3777+
&& !inContext(ctx.fresh.setGadt(GadtConstraint.empty)) {
37783778
val res = (tree.tpe.widenExpr frozen_<:< pt)
37793779
if res then
37803780
// we overshot; a cast is not needed, after all.

0 commit comments

Comments
 (0)