Skip to content

Commit 7e919b6

Browse files
authored
Merge pull request #3510 from dotty-staging/simple-gadt-check
A simple way to check exhaustivity of GADTs
2 parents 60d04ca + cd5a1d1 commit 7e919b6

14 files changed

+158
-7
lines changed

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,38 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
708708
else text
709709
}
710710

711+
/** Whether the counterexample is satisfiable. The space is flattened and non-empty. */
712+
def satisfiable(sp: Space): Boolean = {
713+
def impossible: Nothing = throw new AssertionError("`satisfiable` only accepts flattened space.")
714+
715+
def genConstraint(space: Space): List[(Type, Type)] = space match {
716+
case Prod(tp, unappTp, unappSym, ss, _) =>
717+
val tps = signature(unappTp, unappSym, ss.length)
718+
ss.zip(tps).flatMap {
719+
case (sp : Prod, tp) => sp.tp -> tp :: genConstraint(sp)
720+
case (Typ(tp1, _), tp2) => tp1 -> tp2 :: Nil
721+
case _ => impossible
722+
}
723+
case Typ(_, _) => Nil
724+
case _ => impossible
725+
}
726+
727+
def checkConstraint(constrs: List[(Type, Type)])(implicit ctx: Context): Boolean = {
728+
val tvarMap = collection.mutable.Map.empty[Symbol, TypeVar]
729+
val typeParamMap = new TypeMap() {
730+
override def apply(tp: Type): Type = tp match {
731+
case tref: TypeRef if tref.symbol.is(TypeParam) =>
732+
tvarMap.getOrElseUpdate(tref.symbol, newTypeVar(tref.underlying.bounds))
733+
case tp => mapOver(tp)
734+
}
735+
}
736+
737+
constrs.forall { case (tp1, tp2) => typeParamMap(tp1) <:< typeParamMap(tp2) }
738+
}
739+
740+
checkConstraint(genConstraint(sp))(ctx.fresh.setNewTyperState())
741+
}
742+
711743
/** Display spaces */
712744
def show(s: Space): String = {
713745
def params(tp: Type): List[Type] = tp.classSymbol.primaryConstructor.info.firstParamTypes
@@ -775,6 +807,16 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
775807
res
776808
}
777809

810+
/** Whehter counter-examples should be further checked? True for GADTs. */
811+
def shouldCheckExamples(tp: Type): Boolean = {
812+
new TypeAccumulator[Boolean] {
813+
override def apply(b: Boolean, tp: Type): Boolean = tp match {
814+
case tref: TypeRef if tref.symbol.is(TypeParam) && variance != 1 => true
815+
case tp => b || foldOver(b, tp)
816+
}
817+
}.apply(false, tp)
818+
}
819+
778820
def checkExhaustivity(_match: Match): Unit = {
779821
val Match(sel, cases) = _match
780822
val selTyp = sel.tpe.widen.dealias
@@ -785,10 +827,15 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
785827
debug.println(s"${x.pat.show} ====> ${show(space)}")
786828
space
787829
}).reduce((a, b) => Or(List(a, b)))
788-
val uncovered = simplify(minus(Typ(selTyp, true), patternSpace), aggressive = true)
789830

790-
if (uncovered != Empty)
791-
ctx.warning(PatternMatchExhaustivity(show(uncovered)), sel.pos)
831+
val checkGADTSAT = shouldCheckExamples(selTyp)
832+
833+
val uncovered =
834+
flatten(simplify(minus(Typ(selTyp, true), patternSpace), aggressive = true))
835+
.filter(s => s != Empty && (!checkGADTSAT || satisfiable(s)))
836+
837+
if (uncovered.nonEmpty)
838+
ctx.warning(PatternMatchExhaustivity(show(Or(uncovered))), sel.pos)
792839
}
793840

794841
def checkRedundancy(_match: Match): Unit = {

tests/patmat/3454.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object O {
2+
sealed trait Expr[T]
3+
case class BExpr(bool: Boolean) extends Expr[Boolean]
4+
case class IExpr(int: Int) extends Expr[Int]
5+
6+
def join[T](e1: Expr[T], e2: Expr[T]): Expr[T] = (e1, e2) match {
7+
case (IExpr(i1), IExpr(i2)) => IExpr(i1 + i2)
8+
case (BExpr(b1), BExpr(b2)) => BExpr(b1 & b2)
9+
}
10+
}

tests/patmat/exhausting.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
32: Pattern Match Exhaustivity: List(_, _: _*)
44
39: Pattern Match Exhaustivity: Bar3
55
44: Pattern Match Exhaustivity: (Bar2, Bar2)
6-
53: Pattern Match Exhaustivity: (Bar2, Bar2), (Bar2, Bar1), (Bar1, Bar3), (Bar1, Bar2)
6+
50: Pattern Match Exhaustivity: (Bar2, Bar2)

tests/patmat/exhausting.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ object Test {
4646
case (Bar2, Bar3) => ()
4747
case (Bar3, _) => ()
4848
}
49-
// fails for: (Bar1, Bar2)
50-
// fails for: (Bar1, Bar3)
51-
// fails for: (Bar2, Bar1)
5249
// fails for: (Bar2, Bar2)
5350
def fail5[T](xx: (Foo[T], Foo[T])) = xx match {
5451
case (Bar1, Bar1) => ()

tests/patmat/gadt-basic.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
object O1 {
2+
sealed trait Expr[T]
3+
case class BExpr(bool: Boolean) extends Expr[Boolean]
4+
case class IExpr(int: Int) extends Expr[Int]
5+
6+
def join[T](e1: Expr[T], e2: Expr[T]): Expr[T] = (e1, e2) match {
7+
case (IExpr(i1), IExpr(i2)) => IExpr(i1 + i2)
8+
case (BExpr(b1), BExpr(b2)) => BExpr(b1 & b2)
9+
}
10+
}
11+
12+
object O2 {
13+
sealed trait GADT[A, B]
14+
case object IntString extends GADT[Int, String]
15+
case object IntFloat extends GADT[Int, Float]
16+
case object FloatFloat extends GADT[Float, Float]
17+
18+
def m[A, B](g1: GADT[A, B], g2: GADT[A, B]) = (g1, g2) match {
19+
case (IntString, IntString) => ;
20+
case (IntFloat, IntFloat) => ;
21+
case (FloatFloat, FloatFloat) => ;
22+
}
23+
}

tests/patmat/gadt-covariant.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
6: Pattern Match Exhaustivity: (BExpr(_), IExpr(_)), (IExpr(_), BExpr(_))

tests/patmat/gadt-covariant.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object O {
2+
sealed trait Expr[+T]
3+
case class IExpr(x: Int) extends Expr[Int]
4+
case class BExpr(b: Boolean) extends Expr[Boolean]
5+
6+
def foo[T](x: Expr[T], y: Expr[T]) = (x, y) match {
7+
case (IExpr(_), IExpr(_)) => true
8+
case (BExpr(_), BExpr(_)) => false
9+
}
10+
}

tests/patmat/gadt-invariant.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object O {
2+
sealed trait Expr[T]
3+
case class IExpr(x: Int) extends Expr[Int]
4+
case class BExpr(b: Boolean) extends Expr[Boolean]
5+
6+
def foo[T](x: Expr[T], y: Expr[T]) = (x, y) match {
7+
case (IExpr(_), IExpr(_)) => true
8+
case (BExpr(_), BExpr(_)) => false
9+
}
10+
}

tests/patmat/gadt-nontrivial.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
7: Pattern Match Exhaustivity: (IntExpr(_), AddExpr(_, _))

tests/patmat/gadt-nontrivial.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
object O {
2+
sealed trait Expr[T]
3+
case class BoolExpr(v: Boolean) extends Expr[Boolean]
4+
case class IntExpr(v: Int) extends Expr[Int]
5+
case class AddExpr(e1: Expr[Int], e2: Expr[Int]) extends Expr[Int]
6+
7+
def join[T](e1: Expr[T], e2: Expr[T]): Expr[T] = (e1, e2) match {
8+
case (BoolExpr(b1), BoolExpr(b2)) => BoolExpr(b1 && b2)
9+
case (IntExpr(i1), IntExpr(i2)) => IntExpr(i1 + i2)
10+
case (AddExpr(ei1, ei2), ie) => join(join(ei1, ei2), ie)
11+
}
12+
}

tests/patmat/gadt-nontrivial2.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
13:

tests/patmat/gadt-nontrivial2.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
object O {
2+
sealed trait Nat
3+
type Zero = Nat
4+
sealed trait Succ[N <: Nat] extends Nat
5+
6+
sealed trait NVec[N <: Nat, +A]
7+
case object NEmpty extends NVec[Zero, Nothing]
8+
case class NCons[N <: Nat, +A](head: A, tail: NVec[N, A]) extends NVec[Succ[N], A]
9+
10+
def nzip[N <: Nat, A, B](v1: NVec[N, A], v2: NVec[N, B]): NVec[N, (A, B)] =
11+
(v1, v2) match {
12+
case (NEmpty, NEmpty) => NEmpty
13+
case (NCons(a, atail), NCons(b, btail)) =>
14+
NCons((a, b), nzip(atail, btail))
15+
}
16+
}
File renamed without changes.

tests/patmat/t9926.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
object Model {
2+
sealed trait Field[T]
3+
case object StringField extends Field[String]
4+
case object BoolField extends Field[Boolean]
5+
6+
sealed trait Value[T]
7+
case class Literal(v: String) extends Value[String]
8+
case class Bool(v: Boolean) extends Value[Boolean]
9+
10+
sealed trait Expression[T]
11+
case class Equality[T](field: Field[T], value: Value[T]) extends Expression[T]
12+
13+
def interpret[T](expr: Expression[T]): Int = expr match {
14+
case Equality(StringField, Literal(v)) => 1
15+
case Equality(BoolField, Bool(v)) => 2
16+
}
17+
18+
// T - T(s1, s2, ...)
19+
// if T contains type parameter, get its dimension types.
20+
//
21+
// actively decompose its dimensions, and then do subtype checking to
22+
// see if the type paramter can be instantiated.
23+
}

0 commit comments

Comments
 (0)