Skip to content

Commit d994307

Browse files
committed
Ensure invariant refinement for classes extended case classes
1 parent 49f9dc6 commit d994307

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,15 @@ object Inferencing {
175175
* `DynamicScrutineeType`?
176176
*
177177
* - If `DynamicScrutineeType` refines the type parameters of `StaticScrutineeType`
178-
* in the same way as `PatternType`, the subtype test `PatternType <:< StaticScrutineeType`
179-
* tells us all we need to know.
178+
* in the same way as `PatternType` ("invariant refinement"), the subtype test
179+
* `PatternType <:< StaticScrutineeType` tells us all we need to know.
180180
* - Otherwise, if variant refinement is a possibility we can only make predictions
181181
* about invariant parameters of `StaticScrutineeType`. Hence we do a subtype test
182182
* where `PatternType <: widenVariantParams(StaticScrutineeType)`, where `widenVariantParams`
183183
* replaces all type argument of variant parameters with empty bounds.
184+
*
185+
* Invariant refinement can be assumed if `PatternType`'s class(es) are final or
186+
* case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`).
184187
*/
185188
def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context) = {
186189
def refinementIsInvariant(tp: Type): Boolean = tp match {

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,27 @@ object RefChecks {
641641
}
642642
}
643643

644+
/** Check that inheriting a case class does not constitute a variant refinement
645+
* of a base type of the case class. It is because of this restriction that we
646+
* can assume invariant refinement for case classes in `constrainPatternType`.
647+
*/
648+
def checkCaseClassInheritanceInvariant() = {
649+
for (caseCls <- clazz.info.baseClasses.tail.find(_.is(Case)))
650+
for (bc <- caseCls.info.baseClasses.tail)
651+
if (bc.typeParams.exists(_.paramVariance != 0)) {
652+
val caseBT = self.baseType(caseCls)
653+
val thisBT = self.baseType(bc)
654+
val combinedBT = caseBT.baseType(bc)
655+
if (!(thisBT =:= combinedBT))
656+
ctx.errorOrMigrationWarning(
657+
em"""illegal inheritance: $clazz inherits case $caseCls
658+
|but the two have different base type instances for $bc.
659+
|
660+
| Basetype for $clazz: $thisBT
661+
| Basetype via $caseCls: $combinedBT""", clazz.pos)
662+
}
663+
}
664+
644665
checkNoAbstractMembers()
645666
if (abstractErrors.isEmpty)
646667
checkNoAbstractDecls(clazz)
@@ -649,6 +670,7 @@ object RefChecks {
649670
ctx.error(abstractErrorMessage, clazz.pos)
650671

651672
checkMemberTypesOK()
673+
checkCaseClassInheritanceInvariant()
652674
} else if (clazz.is(Trait) && !(clazz derivesFrom defn.AnyValClass)) {
653675
// For non-AnyVal classes, prevent abstract methods in interfaces that override
654676
// final members in Object; see #4431

tests/neg/i3989b.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object Test extends App {
2+
trait A[+X]
3+
case class B[+X](val x: X) extends A[X]
4+
class C[+X](x: Any) extends B[Any](x) with A[X] // error
5+
def f(a: A[Int]): Int = a match {
6+
case a: B[_] => a.x
7+
case _ => 0
8+
}
9+
f(new C[Int]("foo"))
10+
}

0 commit comments

Comments
 (0)