Skip to content

A simple way to check exhaustivity of GADTs #3510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,38 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
else text
}

/** Whether the counterexample is satisfiable. The space is flattened and non-empty. */
def satisfiable(sp: Space): Boolean = {
def impossible: Nothing = throw new AssertionError("`satisfiable` only accepts flattened space.")

def genConstraint(space: Space): List[(Type, Type)] = space match {
case Prod(tp, unappTp, unappSym, ss, _) =>
val tps = signature(unappTp, unappSym, ss.length)
ss.zip(tps).flatMap {
case (sp : Prod, tp) => sp.tp -> tp :: genConstraint(sp)
case (Typ(tp1, _), tp2) => tp1 -> tp2 :: Nil
case _ => impossible
}
case Typ(_, _) => Nil
case _ => impossible
}

def checkConstraint(constrs: List[(Type, Type)])(implicit ctx: Context): Boolean = {
val tvarMap = collection.mutable.Map.empty[Symbol, TypeVar]
val typeParamMap = new TypeMap() {
override def apply(tp: Type): Type = tp match {
case tref: TypeRef if tref.symbol.is(TypeParam) =>
tvarMap.getOrElseUpdate(tref.symbol, newTypeVar(tref.underlying.bounds))
case tp => mapOver(tp)
}
}

constrs.forall { case (tp1, tp2) => typeParamMap(tp1) <:< typeParamMap(tp2) }
}

checkConstraint(genConstraint(sp))(ctx.fresh.setNewTyperState())
}

/** Display spaces */
def show(s: Space): String = {
def params(tp: Type): List[Type] = tp.classSymbol.primaryConstructor.info.firstParamTypes
Expand Down Expand Up @@ -775,6 +807,16 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
res
}

/** Whehter counter-examples should be further checked? True for GADTs. */
def shouldCheckExamples(tp: Type): Boolean = {
new TypeAccumulator[Boolean] {
override def apply(b: Boolean, tp: Type): Boolean = tp match {
case tref: TypeRef if tref.symbol.is(TypeParam) && variance != 1 => true
case tp => b || foldOver(b, tp)
}
}.apply(false, tp)
}

def checkExhaustivity(_match: Match): Unit = {
val Match(sel, cases) = _match
val selTyp = sel.tpe.widen.dealias
Expand All @@ -785,10 +827,15 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
debug.println(s"${x.pat.show} ====> ${show(space)}")
space
}).reduce((a, b) => Or(List(a, b)))
val uncovered = simplify(minus(Typ(selTyp, true), patternSpace), aggressive = true)

if (uncovered != Empty)
ctx.warning(PatternMatchExhaustivity(show(uncovered)), sel.pos)
val checkGADTSAT = shouldCheckExamples(selTyp)

val uncovered =
flatten(simplify(minus(Typ(selTyp, true), patternSpace), aggressive = true))
.filter(s => s != Empty && (!checkGADTSAT || satisfiable(s)))

if (uncovered.nonEmpty)
ctx.warning(PatternMatchExhaustivity(show(Or(uncovered))), sel.pos)
}

def checkRedundancy(_match: Match): Unit = {
Expand Down
10 changes: 10 additions & 0 deletions tests/patmat/3454.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object O {
sealed trait Expr[T]
case class BExpr(bool: Boolean) extends Expr[Boolean]
case class IExpr(int: Int) extends Expr[Int]

def join[T](e1: Expr[T], e2: Expr[T]): Expr[T] = (e1, e2) match {
case (IExpr(i1), IExpr(i2)) => IExpr(i1 + i2)
case (BExpr(b1), BExpr(b2)) => BExpr(b1 & b2)
}
}
2 changes: 1 addition & 1 deletion tests/patmat/exhausting.check
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
32: Pattern Match Exhaustivity: List(_, _: _*)
39: Pattern Match Exhaustivity: Bar3
44: Pattern Match Exhaustivity: (Bar2, Bar2)
53: Pattern Match Exhaustivity: (Bar2, Bar2), (Bar2, Bar1), (Bar1, Bar3), (Bar1, Bar2)
50: Pattern Match Exhaustivity: (Bar2, Bar2)
3 changes: 0 additions & 3 deletions tests/patmat/exhausting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ object Test {
case (Bar2, Bar3) => ()
case (Bar3, _) => ()
}
// fails for: (Bar1, Bar2)
// fails for: (Bar1, Bar3)
// fails for: (Bar2, Bar1)
// fails for: (Bar2, Bar2)
def fail5[T](xx: (Foo[T], Foo[T])) = xx match {
case (Bar1, Bar1) => ()
Expand Down
23 changes: 23 additions & 0 deletions tests/patmat/gadt-basic.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
object O1 {
sealed trait Expr[T]
case class BExpr(bool: Boolean) extends Expr[Boolean]
case class IExpr(int: Int) extends Expr[Int]

def join[T](e1: Expr[T], e2: Expr[T]): Expr[T] = (e1, e2) match {
case (IExpr(i1), IExpr(i2)) => IExpr(i1 + i2)
case (BExpr(b1), BExpr(b2)) => BExpr(b1 & b2)
}
}

object O2 {
sealed trait GADT[A, B]
case object IntString extends GADT[Int, String]
case object IntFloat extends GADT[Int, Float]
case object FloatFloat extends GADT[Float, Float]

def m[A, B](g1: GADT[A, B], g2: GADT[A, B]) = (g1, g2) match {
case (IntString, IntString) => ;
case (IntFloat, IntFloat) => ;
case (FloatFloat, FloatFloat) => ;
}
}
1 change: 1 addition & 0 deletions tests/patmat/gadt-covariant.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6: Pattern Match Exhaustivity: (BExpr(_), IExpr(_)), (IExpr(_), BExpr(_))
10 changes: 10 additions & 0 deletions tests/patmat/gadt-covariant.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object O {
sealed trait Expr[+T]
case class IExpr(x: Int) extends Expr[Int]
case class BExpr(b: Boolean) extends Expr[Boolean]

def foo[T](x: Expr[T], y: Expr[T]) = (x, y) match {
case (IExpr(_), IExpr(_)) => true
case (BExpr(_), BExpr(_)) => false
}
}
10 changes: 10 additions & 0 deletions tests/patmat/gadt-invariant.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object O {
sealed trait Expr[T]
case class IExpr(x: Int) extends Expr[Int]
case class BExpr(b: Boolean) extends Expr[Boolean]

def foo[T](x: Expr[T], y: Expr[T]) = (x, y) match {
case (IExpr(_), IExpr(_)) => true
case (BExpr(_), BExpr(_)) => false
}
}
1 change: 1 addition & 0 deletions tests/patmat/gadt-nontrivial.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
7: Pattern Match Exhaustivity: (IntExpr(_), AddExpr(_, _))
12 changes: 12 additions & 0 deletions tests/patmat/gadt-nontrivial.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
object O {
sealed trait Expr[T]
case class BoolExpr(v: Boolean) extends Expr[Boolean]
case class IntExpr(v: Int) extends Expr[Int]
case class AddExpr(e1: Expr[Int], e2: Expr[Int]) extends Expr[Int]

def join[T](e1: Expr[T], e2: Expr[T]): Expr[T] = (e1, e2) match {
case (BoolExpr(b1), BoolExpr(b2)) => BoolExpr(b1 && b2)
case (IntExpr(i1), IntExpr(i2)) => IntExpr(i1 + i2)
case (AddExpr(ei1, ei2), ie) => join(join(ei1, ei2), ie)
}
}
1 change: 1 addition & 0 deletions tests/patmat/gadt-nontrivial2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
13:
16 changes: 16 additions & 0 deletions tests/patmat/gadt-nontrivial2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
object O {
sealed trait Nat
type Zero = Nat
sealed trait Succ[N <: Nat] extends Nat

sealed trait NVec[N <: Nat, +A]
case object NEmpty extends NVec[Zero, Nothing]
case class NCons[N <: Nat, +A](head: A, tail: NVec[N, A]) extends NVec[Succ[N], A]

def nzip[N <: Nat, A, B](v1: NVec[N, A], v2: NVec[N, B]): NVec[N, (A, B)] =
(v1, v2) match {
case (NEmpty, NEmpty) => NEmpty
case (NCons(a, atail), NCons(b, btail)) =>
NCons((a, b), nzip(atail, btail))
}
}
File renamed without changes.
23 changes: 23 additions & 0 deletions tests/patmat/t9926.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
object Model {
sealed trait Field[T]
case object StringField extends Field[String]
case object BoolField extends Field[Boolean]

sealed trait Value[T]
case class Literal(v: String) extends Value[String]
case class Bool(v: Boolean) extends Value[Boolean]

sealed trait Expression[T]
case class Equality[T](field: Field[T], value: Value[T]) extends Expression[T]

def interpret[T](expr: Expression[T]): Int = expr match {
case Equality(StringField, Literal(v)) => 1
case Equality(BoolField, Bool(v)) => 2
}

// T - T(s1, s2, ...)
// if T contains type parameter, get its dimension types.
//
// actively decompose its dimensions, and then do subtype checking to
// see if the type paramter can be instantiated.
}