Skip to content

Commit 71abc3f

Browse files
committed
Rework how GADT constraints are inferred
GADT constraints are now inferred with an intersection-inspired algorithm. Inferencing.constrainPatternType was moved to PatternTypeConstrainer to organize the code better.
1 parent 3bcaf1d commit 71abc3f

13 files changed

+340
-85
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
package dotty.tools
2+
package dotc
3+
package core
4+
5+
import Decorators._
6+
import Symbols._
7+
import Types._
8+
import Flags._
9+
import dotty.tools.dotc.reporting.trace
10+
import config.Printers._
11+
12+
trait PatternTypeConstrainer { self: TypeComparer =>
13+
14+
/** Derive type and GADT constraints that necessarily follow from a pattern with the given type matching
15+
* a scrutinee of the given type.
16+
*
17+
* We have the following situation in case of a (dynamic) pattern match:
18+
*
19+
* StaticScrutineeType PatternType
20+
* \ /
21+
* DynamicScrutineeType
22+
*
23+
* In simple cases, it must hold that `PatternType <: StaticScrutineeType`:
24+
*
25+
* StaticScrutineeType
26+
* | \
27+
* | PatternType
28+
* | /
29+
* DynamicScrutineeType
30+
*
31+
* A good example of a situation where the above must hold is when static scrutinee type is the root of an enum,
32+
* and the pattern is an unapply of a case class, or a case object literal (of that enum).
33+
*
34+
* In slightly more complex cases, we may need to upcast `StaticScrutineeType`:
35+
*
36+
* SharedPatternScrutineeSuperType
37+
* / \
38+
* StaticScrutineeType PatternType
39+
* \ /
40+
* DynamicScrutineeType
41+
*
42+
* This may be the case if the scrutinee is a singleton type or a path-dependent type. It is also the case
43+
* for the following definitions:
44+
*
45+
* trait Expr[T]
46+
* trait IntExpr extends Expr[T]
47+
* trait Const[T] extends Expr[T]
48+
*
49+
* StaticScrutineeType = Const[T]
50+
* PatternType = IntExpr
51+
*
52+
* Union and intersection types are an additional complication - if either scrutinee or pattern are a union type,
53+
* then the above relationships only need to hold for the "leaves" of the types.
54+
*
55+
* Finally, if pattern type contains hk-types applied to concrete types (as opposed to type variables),
56+
* or either scrutinee or pattern type contain type member refinements, the above relationships do not need
57+
* to hold at all. Consider (where `T1`, `T2` are unrelated traits):
58+
*
59+
* StaticScrutineeType = { type T <: T1 }
60+
* PatternType = { type T <: T2 }
61+
*
62+
* In the above situation, DynamicScrutineeType can equal { type T = T1 & T2 }, but there is no useful relationship
63+
* between StaticScrutineeType and PatternType (nor any of their subcomponents). Similarly:
64+
*
65+
* StaticScrutineeType = Option[T1]
66+
* PatternType = Some[T2]
67+
*
68+
* Again, DynamicScrutineeType may equal Some[T1 & T2], and there's no useful relationship between the static
69+
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
70+
* in which case the subtyping relationship "heals" the type.
71+
*/
72+
def constrainPatternType(pat: Type, scrut: Type): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
73+
74+
def classesMayBeCompatible: Boolean = {
75+
import Flags._
76+
val patClassSym = pat.widenSingleton.classSymbol
77+
val scrutClassSym = scrut.widenSingleton.classSymbol
78+
!patClassSym.exists || !scrutClassSym.exists || {
79+
if (patClassSym.is(Final)) patClassSym.derivesFrom(scrutClassSym)
80+
else if (scrutClassSym.is(Final)) scrutClassSym.derivesFrom(patClassSym)
81+
else if (!patClassSym.is(Flags.Trait) && !scrutClassSym.is(Flags.Trait))
82+
patClassSym.derivesFrom(scrutClassSym) || scrutClassSym.derivesFrom(patClassSym)
83+
else true
84+
}
85+
}
86+
87+
def stripRefinement(tp: Type): Type = tp match {
88+
case tp: RefinedOrRecType => stripRefinement(tp.parent)
89+
case tp => tp
90+
}
91+
92+
def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) {
93+
val upcasted: Type = scrut match {
94+
case scrut: TypeRef if scrut.symbol.isClass =>
95+
// we do not infer constraints following from all parents for performance reasons
96+
// in principle however, if `A extends B, C`, then `A` can be treated as `B & C`
97+
scrut.firstParent
98+
case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass =>
99+
val patClassSym = pat.classSymbol
100+
// as above, we do not consider all parents for performance reasons
101+
def firstParentSharedWithPat(tp: Type, tpClassSym: ClassSymbol): Symbol = {
102+
var parents = tpClassSym.info.parents
103+
parents match {
104+
case first :: rest =>
105+
if (first.classSymbol == defn.ObjectClass) parents = rest
106+
case _ => ;
107+
}
108+
parents match {
109+
case first :: _ =>
110+
val firstClassSym = first.classSymbol.asClass
111+
val res = if (patClassSym.derivesFrom(firstClassSym)) firstClassSym
112+
else firstParentSharedWithPat(first, firstClassSym)
113+
res
114+
case _ => NoSymbol
115+
}
116+
}
117+
val sym = firstParentSharedWithPat(tycon, tycon.symbol.asClass)
118+
if (sym.exists) scrut.baseType(sym) else NoType
119+
case scrut: TypeProxy => scrut.superType
120+
case _ => NoType
121+
}
122+
if (upcasted.exists)
123+
constrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted)
124+
else true
125+
}
126+
127+
scrut.dealias match {
128+
case OrType(scrut1, scrut2) =>
129+
either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
130+
case AndType(scrut1, scrut2) =>
131+
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
132+
case scrut: RefinedOrRecType =>
133+
constrainPatternType(pat, stripRefinement(scrut))
134+
case scrut => pat.dealias match {
135+
case OrType(pat1, pat2) =>
136+
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
137+
case AndType(pat1, pat2) =>
138+
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
139+
case scrut: RefinedOrRecType =>
140+
constrainPatternType(stripRefinement(scrut), pat)
141+
case pat =>
142+
constrainSimplePatternType(pat, scrut) || classesMayBeCompatible && constrainUpcasted(scrut)
143+
}
144+
}
145+
}
146+
147+
/** Constrain "simple" patterns (see `constrainPatternType`).
148+
*
149+
* This function attempts to modify pattern and scrutinee type s.t. the pattern must be a subtype of the scrutinee,
150+
* or otherwise it cannot possibly match. In order to do that, we:
151+
*
152+
* 1. Rely on `constrainPatternType` to break the actual scrutinee/pattern types into subcomponents
153+
* 2. Widen type parameters of scrutinee type that are not invariantly refined (see below) by the pattern type.
154+
* 3. Wrap the pattern type in a skolem to avoid overconstraining top-level abstract types in scrutinee type
155+
* 4. Check that `WidenedScrutineeType <: NarrowedPatternType`
156+
*
157+
* Importantly, note that the pattern type may contain type variables.
158+
*
159+
* ## Invariant refinement
160+
* Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
161+
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`. This is violated if `A` is variant:
162+
*
163+
* trait C[+A]
164+
* trait D[+B](val b: B) extends C[B]
165+
* trait E extends D[Any](0) with C[String]
166+
*
167+
* `E` is a counter-example to the above - if `e: E`, then `e: C[String]` and `e` is instance of `D`, but
168+
* it is false that `e: D[String]`! This is a problem if we're constraining a pattern like the below:
169+
*
170+
* def foo[T](c: C[T]): T = c match {
171+
* case d: D[t] => d.b
172+
* }
173+
*
174+
* It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`.
175+
* Note, however, that if `D` was a final class, we *could* rely on that relationship.
176+
* To support typical case classes, we also assume that this relationship holds for them and their parent traits.
177+
* This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those
178+
* case classes without also appropriately extending the relevant case class
179+
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
180+
*/
181+
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type): Boolean = {
182+
def refinementIsInvariant(tp: Type): Boolean = tp match {
183+
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
184+
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
185+
case _ => false
186+
}
187+
188+
def widenVariantParams = new TypeMap {
189+
def apply(tp: Type) = mapOver(tp) match {
190+
case tp @ AppliedType(tycon, args) =>
191+
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
192+
if (tparam.paramVariance != 0) TypeBounds.empty else arg
193+
)
194+
tp.derivedAppliedType(tycon, args1)
195+
case tp =>
196+
tp
197+
}
198+
}
199+
200+
val widePt = if (ctx.scala2Mode || refinementIsInvariant(patternTp)) scrutineeTp else widenVariantParams(scrutineeTp)
201+
val narrowTp = SkolemType(patternTp)
202+
trace(i"constraining simple pattern type $narrowTp <:< $widePt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
203+
isSubType(narrowTp, widePt)
204+
}
205+
}
206+
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ object AbsentContext {
2525

2626
/** Provides methods to compare types.
2727
*/
28-
class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
28+
class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] with PatternTypeConstrainer {
2929
import TypeComparer._
3030
implicit def ctx(implicit nc: AbsentContext): Context = initctx
3131

@@ -1227,7 +1227,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
12271227
* @see [[sufficientEither]] for the normal case
12281228
* @see [[necessaryEither]] for the GADTFlexible case
12291229
*/
1230-
private def either(op1: => Boolean, op2: => Boolean): Boolean =
1230+
protected def either(op1: => Boolean, op2: => Boolean): Boolean =
12311231
if (ctx.mode.is(Mode.GADTflexible)) necessaryEither(op1, op2) else sufficientEither(op1, op2)
12321232

12331233
/** Returns true iff the result of evaluating either `op1` or `op2` is true,

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,21 +1084,6 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
10841084

10851085
def fromScala2x = unapplyFn.symbol.exists && (unapplyFn.symbol.owner is Scala2x)
10861086

1087-
/** Is `subtp` a subtype of `tp` or of some generalization of `tp`?
1088-
* The generalizations of a type T are the smallest set G such that
1089-
*
1090-
* - T is in G
1091-
* - If a typeref R in G represents a class or trait, R's superclass is in G.
1092-
* - If a type proxy P is not a reference to a class, P's supertype is in G
1093-
*/
1094-
def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean =
1095-
if (constrainPatternType(subtp, tp)) true
1096-
else tp match {
1097-
case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent)
1098-
case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType)
1099-
case _ => false
1100-
}
1101-
11021087
unapplyFn.tpe.widen match {
11031088
case mt: MethodType if mt.paramInfos.length == 1 =>
11041089
val unapplyArgType = mt.paramInfos.head
@@ -1108,17 +1093,15 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
11081093
unapp.println(i"case 1 $unapplyArgType ${ctx.typerState.constraint}")
11091094
fullyDefinedType(unapplyArgType, "pattern selector", tree.span)
11101095
selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
1111-
} else if (isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))) {
1096+
} else {
1097+
// note that we simply ignore whether constraining actually succeeded or not
1098+
// in theory, constraining should only fail if the pattern cannot possibly match
1099+
// however, during exhaustivity checks, we perform a strictly better check
1100+
ctx.addMode(Mode.GADTflexible).typeComparer.constrainPatternType(unapplyArgType, selType)
11121101
val patternBound = maximizeType(unapplyArgType, tree.span, fromScala2x)
11131102
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
11141103
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")
11151104
unapplyArgType
1116-
} else {
1117-
unapp.println("Neither sub nor super")
1118-
unapp.println(TypeComparer.explained(implicit ctx => unapplyArgType <:< selType))
1119-
errorType(
1120-
ex"Pattern type $unapplyArgType is neither a subtype nor a supertype of selector type $selType",
1121-
tree.sourcePos)
11221105
}
11231106
val dummyArg = dummyTreeOfType(ownType)
11241107
val unapplyApp = typedExpr(untpd.TypedSplice(Apply(unapplyFn, dummyArg :: Nil)))

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -180,66 +180,6 @@ object Inferencing {
180180
def isSkolemFree(tp: Type)(implicit ctx: Context): Boolean =
181181
!tp.existsPart(_.isInstanceOf[SkolemType])
182182

183-
/** Derive information about a pattern type by comparing it with some variant of the
184-
* static scrutinee type. We have the following situation in case of a (dynamic) pattern match:
185-
*
186-
* StaticScrutineeType PatternType
187-
* \ /
188-
* DynamicScrutineeType
189-
*
190-
* If `PatternType` is not a subtype of `StaticScrutineeType, there's no information to be gained.
191-
* Now let's say we can prove that `PatternType <: StaticScrutineeType`.
192-
*
193-
* StaticScrutineeType
194-
* | \
195-
* | \
196-
* | \
197-
* | PatternType
198-
* | /
199-
* DynamicScrutineeType
200-
*
201-
* What can we say about the relationship of parameter types between `PatternType` and
202-
* `DynamicScrutineeType`?
203-
*
204-
* - If `DynamicScrutineeType` refines the type parameters of `StaticScrutineeType`
205-
* in the same way as `PatternType` ("invariant refinement"), the subtype test
206-
* `PatternType <:< StaticScrutineeType` tells us all we need to know.
207-
* - Otherwise, if variant refinement is a possibility we can only make predictions
208-
* about invariant parameters of `StaticScrutineeType`. Hence we do a subtype test
209-
* where `PatternType <: widenVariantParams(StaticScrutineeType)`, where `widenVariantParams`
210-
* replaces all type argument of variant parameters with empty bounds.
211-
*
212-
* Invariant refinement can be assumed if `PatternType`'s class(es) are final or
213-
* case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`).
214-
*/
215-
def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = {
216-
def refinementIsInvariant(tp: Type): Boolean = tp match {
217-
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
218-
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
219-
case tp: AndType => refinementIsInvariant(tp.tp1) && refinementIsInvariant(tp.tp2)
220-
case tp: OrType => refinementIsInvariant(tp.tp1) && refinementIsInvariant(tp.tp2)
221-
case _ => false
222-
}
223-
224-
def widenVariantParams = new TypeMap {
225-
def apply(tp: Type) = mapOver(tp) match {
226-
case tp @ AppliedType(tycon, args) =>
227-
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
228-
if (tparam.paramVariance != 0) TypeBounds.empty else arg
229-
)
230-
tp.derivedAppliedType(tycon, args1)
231-
case tp =>
232-
tp
233-
}
234-
}
235-
236-
val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt)
237-
val narrowTp = SkolemType(tp)
238-
trace(i"constraining pattern type $narrowTp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
239-
narrowTp <:< widePt
240-
}
241-
}
242-
243183
/** The list of uninstantiated type variables bound by some prefix of type `T` which
244184
* occur in at least one formal parameter type of a prefix application.
245185
* Considered prefixes are:

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ class Typer extends Namer
604604
def handlePattern: Tree = {
605605
val tpt1 = typedTpt
606606
if (!ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef)
607-
constrainPatternType(tpt1.tpe, pt)(ctx.addMode(Mode.GADTflexible))
607+
ctx.addMode(Mode.GADTflexible).typeComparer.constrainPatternType(tpt1.tpe, pt)
608608
// special case for an abstract type that comes with a class tag
609609
tryWithClassTag(ascription(tpt1, isWildcard = true), pt)
610610
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
object Test {
2+
trait Expr[+T]
3+
trait IntExpr extends Expr[Int]
4+
class Const[+T] extends Expr[T]
5+
final class Fin
6+
7+
def foo1[T](x: Unit | Const[T]): T = x match {
8+
case _: IntExpr => 0 // error
9+
}
10+
11+
def bar1[T](x: Const[T]): T = x match {
12+
case _: (Unit | IntExpr) => 0 // error
13+
}
14+
15+
def foo2[T](x: Fin | Const[T]): T = x match {
16+
case _: IntExpr => 0 // error
17+
}
18+
19+
def bar2[T](x: Const[T]): T = x match {
20+
case _: (Fin | IntExpr) => 0 // error
21+
}
22+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
trait Test {
2+
type A
3+
4+
enum Foo[X, Y] {
5+
case StrStr() extends Foo[String, String]
6+
case IntInt() extends Foo[Int, Int]
7+
}
8+
9+
def foo[T, U](f: Foo[A, T] | Foo[String, U]): Unit =
10+
f match { case Foo.StrStr() =>
11+
val t: T = "" // error
12+
val u: U = "" // error
13+
}
14+
}

0 commit comments

Comments
 (0)