Skip to content

Commit f5a1ef2

Browse files
committed
Use Skolems to infer GADT constraints
The rationale for using a Skolem here is: we want to record that there is at least one value that is both of the pattern type and the scrutinee type. All symbols are now considered valid for adding GADT constraints - the rationale is that set of constrainable symbols should be either selected on a per-(sub)pattern basis, or be the same for all matches. Previously, symbols which were only appearing variantly in a scrutinee type could be considered constrainable anyway because of an outer pattern match.
1 parent eac0e53 commit f5a1ef2

File tree

10 files changed

+216
-54
lines changed

10 files changed

+216
-54
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3704,7 +3704,12 @@ object Types {
37043704

37053705
// ----- Skolem types -----------------------------------------------
37063706

3707-
/** A skolem type reference with underlying type `info` */
3707+
/** A skolem type reference with underlying type `info`.
3708+
*
3709+
* For Dotty, a skolem type is a singleton type of some unknown value of type `info`.
3710+
* Note that care is needed when creating them, since not all types need to be inhabited.
3711+
* A skolem is equal to itself and no other type.
3712+
*/
37083713
case class SkolemType(info: Type) extends UncachedProxyType with ValueType with SingletonType {
37093714
override def underlying(implicit ctx: Context): Type = info
37103715
def derivedSkolemType(info: Type)(implicit ctx: Context): SkolemType =

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
10901090
* - If a type proxy P is not a reference to a class, P's supertype is in G
10911091
*/
10921092
def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean =
1093-
if (constrainPatternType(subtp, tp)) true
1093+
if (constrainPatternType(subtp, tp, termPattern = true)) true
10941094
else tp match {
10951095
case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent)
10961096
case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType)

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -153,41 +153,22 @@ object Inferencing {
153153
def isSkolemFree(tp: Type)(implicit ctx: Context): Boolean =
154154
!tp.existsPart(_.isInstanceOf[SkolemType])
155155

156-
/** Derive information about a pattern type by comparing it with some variant of the
157-
* static scrutinee type. We have the following situation in case of a (dynamic) pattern match:
156+
/** Infer constraints that should be in scope for a case body with given pattern and scrutinee types.
158157
*
159-
* StaticScrutineeType PatternType
160-
* \ /
161-
* DynamicScrutineeType
158+
* If `termPattern`, infer constraints from knowing that there exists a value which of both scrutinee
159+
* and pattern types (which is the case for normal pattern matching). If not `termPattern`, instead
160+
* infer constraints from knowing that `tp <: pt`.
162161
*
163-
* If `PatternType` is not a subtype of `StaticScrutineeType, there's no information to be gained.
164-
* Now let's say we can prove that `PatternType <: StaticScrutineeType`.
162+
* If a pattern matches during normal pattern matching, we can be certain that there exists a value
163+
* which is of both scrutinee and pattern types (the value we're matching on). If this value
164+
* was in a variable, say `x`, then we could simply infer constraints from `x.type <: pt`. Since we might
165+
* be matching on an expression as well, we take a skolem of the scrutinee, which is essentially an existential
166+
* singleton type (see [[dotty.tools.dotc.core.Types.SkolemType]]).
165167
*
166-
* StaticScrutineeType
167-
* | \
168-
* | \
169-
* | \
170-
* | PatternType
171-
* | /
172-
* DynamicScrutineeType
173-
*
174-
* What can we say about the relationship of parameter types between `PatternType` and
175-
* `DynamicScrutineeType`?
176-
*
177-
* - If `DynamicScrutineeType` refines the type parameters of `StaticScrutineeType`
178-
* in the same way as `PatternType` ("invariant refinement"), the subtype test
179-
* `PatternType <:< StaticScrutineeType` tells us all we need to know.
180-
* - Otherwise, if variant refinement is a possibility we can only make predictions
181-
* about invariant parameters of `StaticScrutineeType`. Hence we do a subtype test
182-
* where `PatternType <: widenVariantParams(StaticScrutineeType)`, where `widenVariantParams`
183-
* replaces all type argument of variant parameters with empty bounds.
184-
*
185-
* Invariant refinement can be assumed if `PatternType`'s class(es) are final or
186-
* case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`).
187-
*
188-
* TODO: Update so that GADT symbols can be variant, and we special case final class types in patterns
168+
* Note that we need to sometimes widen type parameters of the scrutinee type to avoid unsoundness -
169+
* see i3989c.scala and related issue discussion on Github.
189170
*/
190-
def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = {
171+
def constrainPatternType(tp: Type, pt: Type, termPattern: Boolean)(implicit ctx: Context): Boolean = {
191172
def refinementIsInvariant(tp: Type): Boolean = tp match {
192173
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
193174
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
@@ -209,8 +190,9 @@ object Inferencing {
209190
}
210191

211192
val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt)
212-
trace(i"constraining pattern type $tp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
213-
tp <:< widePt
193+
val narrowTp = if (termPattern) SkolemType(tp) else tp
194+
trace(i"constraining pattern type $narrowTp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
195+
narrowTp <:< widePt
214196
}
215197
}
216198

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ class Typer extends Namer
604604
def handlePattern: Tree = {
605605
val tpt1 = typedTpt
606606
if (!ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef)
607-
constrainPatternType(tpt1.tpe, pt)(ctx.addMode(Mode.GADTflexible))
607+
constrainPatternType(tpt1.tpe, pt, termPattern = true)(ctx.addMode(Mode.GADTflexible))
608608
// special case for an abstract type that comes with a class tag
609609
tryWithClassTag(ascription(tpt1, isWildcard = true), pt)
610610
}
@@ -1104,7 +1104,7 @@ class Typer extends Namer
11041104
def caseRest(implicit ctx: Context) = {
11051105
val pat1 = checkSimpleKinded(typedType(cdef.pat)(ctx.addMode(Mode.Pattern)))
11061106
if (!ctx.isAfterTyper)
1107-
constrainPatternType(pat1.tpe, selType)(ctx.addMode(Mode.GADTflexible))
1107+
constrainPatternType(pat1.tpe, selType, termPattern = false)(ctx.addMode(Mode.GADTflexible))
11081108
val pat2 = indexPattern(cdef).transform(pat1)
11091109
val body1 = typedType(cdef.body, pt)
11101110
assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
object buffer {
2+
object EssaInt {
3+
def unapply(i: Int): Some[Int] = Some(i)
4+
}
5+
6+
case class Inv[T](t: T)
7+
8+
enum EQ[A, B] { case Refl[T]() extends EQ[T, T] }
9+
enum SUB[A, +B] { case Refl[T]() extends SUB[T, T] } // A <: B
10+
11+
def test_eq1[A, B](eq: EQ[A, B], a: A, b: B): B =
12+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
13+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) | Sko(Int)
14+
eq match { case EQ.Refl() => // a = b
15+
val success: A = b
16+
val fail: A = 0 // error
17+
0 // error
18+
}
19+
}
20+
}
21+
22+
def test_eq2[A, B](eq: EQ[A, B], a: A, b: B): B =
23+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
24+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int)
25+
eq match { case EQ.Refl() => // a = b
26+
val success: A = b
27+
val fail: A = 0 // error
28+
0 // error
29+
}
30+
}
31+
}
32+
33+
def test_sub1[A, B](sub: SUB[A, B], a: A, b: B): B =
34+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int)
35+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int)
36+
sub match { case SUB.Refl() => // b >: a
37+
val success: B = a
38+
val fail: A = 0 // error
39+
0 // error
40+
}
41+
}
42+
}
43+
44+
def test_sub2[A, B](sub: SUB[A, B], a: A, b: B): B =
45+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
46+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int)
47+
sub match { case SUB.Refl() => // b >: a
48+
val success: B = a
49+
val fail: A = 0 // error
50+
0 // error
51+
}
52+
}
53+
}
54+
55+
56+
def test_sub_eq[A, B, C](sub: SUB[A|B, C], eqA: EQ[A, 5], eqB: EQ[B, 6]): C =
57+
sub match { case SUB.Refl() => // C >: A | B
58+
eqA match { case EQ.Refl() => // A = 5
59+
eqB match { case EQ.Refl() => // B = 6
60+
val fail1: A = 0 // error
61+
val fail2: B = 0 // error
62+
0 // error
63+
}
64+
}
65+
}
66+
}

tests/neg/int-extractor.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
object Test {
2+
object EssaInt {
3+
def unapply(i: Int): Some[Int] = Some(i)
4+
}
5+
6+
def foo1[T](t: T): T = t match {
7+
case EssaInt(_) =>
8+
0 // error
9+
}
10+
11+
def foo2[T](t: T): T = t match {
12+
case EssaInt(_) => t match {
13+
case EssaInt(_) =>
14+
0 // error
15+
}
16+
}
17+
18+
case class Inv[T](t: T)
19+
20+
def bar1[T](t: T): T = Inv(t) match {
21+
case Inv(EssaInt(_)) =>
22+
0 // error
23+
}
24+
25+
def bar2[T](t: T): T = t match {
26+
case Inv(EssaInt(_)) => t match {
27+
case Inv(EssaInt(_)) =>
28+
0 // error
29+
}
30+
}
31+
}

tests/neg/invariant-gadt.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
object `invariant-gadt` {
2+
case class Invariant[T](value: T)
3+
4+
def unsound0[T](t: T): T = Invariant(t) match {
5+
case Invariant(_: Int) =>
6+
(0: Any) // error
7+
}
8+
9+
def unsound1[T](t: T): T = Invariant(t) match {
10+
case Invariant(_: Int) =>
11+
0 // error
12+
}
13+
14+
def unsound2[T](t: T): T = Invariant(t) match {
15+
case Invariant(value) => value match {
16+
case _: Int =>
17+
0 // error
18+
}
19+
}
20+
21+
def unsoundTwice[T](t: T): T = Invariant(t) match {
22+
case Invariant(_: Int) => Invariant(t) match {
23+
case Invariant(_: Int) =>
24+
0 // error
25+
}
26+
}
27+
}

tests/neg/typeclass-derivation2.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@ object TypeLevel {
111111
* It informs that type `T` has shape `S` and also implements runtime reflection on `T`.
112112
*/
113113
abstract class Shaped[T, S <: Shape] extends Reflected[T]
114+
115+
// substitute for erasedValue that allows precise matching
116+
final abstract class Type[-A, +B]
117+
type Subtype[t] = Type[_, t]
118+
type Supertype[t] = Type[t, _]
119+
type Exactly[t] = Type[t, t]
120+
erased def typeOf[T]: Type[T, T] = ???
114121
}
115122

116123
// An algebraic datatype
@@ -203,7 +210,7 @@ trait Show[T] {
203210
def show(x: T): String
204211
}
205212
object Show {
206-
import scala.compiletime.erasedValue
213+
import scala.compiletime.{erasedValue, error}
207214
import TypeLevel._
208215

209216
inline def tryShow[T](x: T): String = implicit match {
@@ -229,9 +236,14 @@ object Show {
229236
inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String =
230237
inline erasedValue[Alts] match {
231238
case _: (Shape.Case[alt, elems] *: alts1) =>
232-
x match {
233-
case x: `alt` => showCase[T, elems](r, x)
234-
case _ => showCases[T, alts1](r, x)
239+
inline typeOf[alt] match {
240+
case _: Subtype[T] =>
241+
x match {
242+
case x: `alt` => showCase[T, elems](r, x)
243+
case _ => showCases[T, alts1](r, x)
244+
}
245+
case _ =>
246+
error("invalid call to showCases: one of Alts is not a subtype of T")
235247
}
236248
case _: Unit =>
237249
throw new MatchError(x)

tests/pos/precise-pattern-type.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
object `precise-pattern-type` {
2+
class Type {
3+
def isType: Boolean = true
4+
}
5+
6+
class Tree[-T >: Null] {
7+
def tpe: T @annotation.unchecked.uncheckedVariance = ???
8+
}
9+
10+
case class Select[-T >: Null](qual: Tree[T]) extends Tree[T]
11+
12+
def test[T <: Tree[Type]](tree: T) = tree match {
13+
case Select(q) =>
14+
q.tpe.isType
15+
}
16+
}

tests/run/typeclass-derivation2.scala

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,13 @@ object TypeLevel {
113113
* It informs that type `T` has shape `S` and also implements runtime reflection on `T`.
114114
*/
115115
abstract class Shaped[T, S <: Shape] extends Reflected[T]
116+
117+
// substitute for erasedValue that allows precise matching
118+
final abstract class Type[-A, +B]
119+
type Subtype[t] = Type[_, t]
120+
type Supertype[t] = Type[t, _]
121+
type Exactly[t] = Type[t, t]
122+
erased def typeOf[T]: Type[T, T] = ???
116123
}
117124

118125
// An algebraic datatype
@@ -217,7 +224,7 @@ trait Eq[T] {
217224
}
218225

219226
object Eq {
220-
import scala.compiletime.erasedValue
227+
import scala.compiletime.{erasedValue, error}
221228
import TypeLevel._
222229

223230
inline def tryEql[T](x: T, y: T) = implicit match {
@@ -239,8 +246,13 @@ object Eq {
239246
inline def eqlCases[T, Alts <: Tuple](xm: Mirror, ym: Mirror, ordinal: Int, n: Int): Boolean =
240247
inline erasedValue[Alts] match {
241248
case _: (Shape.Case[alt, elems] *: alts1) =>
242-
if (n == ordinal) eqlElems[elems](xm, ym, 0)
243-
else eqlCases[T, alts1](xm, ym, ordinal, n + 1)
249+
inline typeOf[alt] match {
250+
case _: Subtype[T] =>
251+
if (n == ordinal) eqlElems[elems](xm, ym, 0)
252+
else eqlCases[T, alts1](xm, ym, ordinal, n + 1)
253+
case _ =>
254+
error("invalid call to eqlCases: one of Alts is not a subtype of T")
255+
}
244256
case _: Unit =>
245257
false
246258
}
@@ -271,7 +283,7 @@ trait Pickler[T] {
271283
}
272284

273285
object Pickler {
274-
import scala.compiletime.{erasedValue, constValue}
286+
import scala.compiletime.{erasedValue, constValue, error}
275287
import TypeLevel._
276288

277289
def nextInt(buf: mutable.ListBuffer[Int]): Int = try buf.head finally buf.trimStart(1)
@@ -294,12 +306,17 @@ object Pickler {
294306
inline def pickleCases[T, Alts <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], x: T, n: Int): Unit =
295307
inline erasedValue[Alts] match {
296308
case _: (Shape.Case[alt, elems] *: alts1) =>
297-
x match {
298-
case x: `alt` =>
299-
buf += n
300-
pickleCase[T, elems](r, buf, x)
309+
inline typeOf[alt] match {
310+
case _: Subtype[T] =>
311+
x match {
312+
case x: `alt` =>
313+
buf += n
314+
pickleCase[T, elems](r, buf, x)
315+
case _ =>
316+
pickleCases[T, alts1](r, buf, x, n + 1)
317+
}
301318
case _ =>
302-
pickleCases[T, alts1](r, buf, x, n + 1)
319+
error("invalid pickleCases call: one of Alts is not a subtype of T")
303320
}
304321
case _: Unit =>
305322
}
@@ -362,7 +379,7 @@ trait Show[T] {
362379
def show(x: T): String
363380
}
364381
object Show {
365-
import scala.compiletime.erasedValue
382+
import scala.compiletime.{erasedValue, error}
366383
import TypeLevel._
367384

368385
inline def tryShow[T](x: T): String = implicit match {
@@ -388,9 +405,15 @@ object Show {
388405
inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String =
389406
inline erasedValue[Alts] match {
390407
case _: (Shape.Case[alt, elems] *: alts1) =>
391-
x match {
392-
case x: `alt` => showCase[T, elems](r, x)
393-
case _ => showCases[T, alts1](r, x)
408+
inline typeOf[alt] match {
409+
case _: Subtype[T] =>
410+
x match {
411+
case x: `alt` =>
412+
showCase[T, elems](r, x)
413+
case _ => showCases[T, alts1](r, x)
414+
}
415+
case _ =>
416+
error("invalid call to showCases: one of Alts is not a subtype of T")
394417
}
395418
case _: Unit =>
396419
throw new MatchError(x)

0 commit comments

Comments
 (0)