Skip to content

Commit 41279ac

Browse files
authored
Honour hard unions in lubbing and param replacing (#18680)
2 parents 8cb4945 + cc55175 commit 41279ac

File tree

5 files changed

+60
-13
lines changed

5 files changed

+60
-13
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ abstract class Constraint extends Showable {
138138
/** The same as this constraint, but with `tv` marked as hard. */
139139
def withHard(tv: TypeVar)(using Context): This
140140

141+
/** Mark toplevel type vars in `tp` as hard. */
142+
def hardenTypeVars(tp: Type)(using Context): This
143+
141144
/** Gives for each instantiated type var that does not yet have its `inst` field
142145
* set, the instance value stored in the constraint. Storing instances in constraints
143146
* is done only in a temporary way for contexts that may be retracted

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,9 +750,18 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
750750
}
751751
if isRemovable(param.binder) then current = current.remove(param.binder)
752752
current.dropDeps(param)
753+
replacedTypeVar match
754+
case replacedTypeVar: TypeVar if isHard(replacedTypeVar) => current = current.hardenTypeVars(replacement)
755+
case _ =>
753756
current.checkWellFormed()
754757
end replace
755758

759+
def hardenTypeVars(tp: Type)(using Context): OrderingConstraint = tp.dealiasKeepRefiningAnnots match
760+
case tp: TypeVar if contains(tp.origin) => withHard(tp)
761+
case tp: TypeParamRef if contains(tp) => hardenTypeVars(typeVarOfParam(tp))
762+
case tp: AndOrType => hardenTypeVars(tp.tp1).hardenTypeVars(tp.tp2)
763+
case _ => this
764+
756765
def remove(pt: TypeLambda)(using Context): This = {
757766
def removeFromOrdering(po: ParamOrdering) = {
758767
def removeFromBoundss(key: TypeLambda, bndss: Array[List[TypeParamRef]]): Array[List[TypeParamRef]] = {

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -501,17 +501,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
501501
false
502502
}
503503

504-
/** Mark toplevel type vars in `tp2` as hard in the current constraint */
505-
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
506-
case tvar: TypeVar if constraint.contains(tvar.origin) =>
507-
constraint = constraint.withHard(tvar)
508-
case tp2: TypeParamRef if constraint.contains(tp2) =>
509-
hardenTypeVars(constraint.typeVarOfParam(tp2))
510-
case tp2: AndOrType =>
511-
hardenTypeVars(tp2.tp1)
512-
hardenTypeVars(tp2.tp2)
513-
case _ =>
514-
515504
val res = widenOK || joinOK
516505
|| recur(tp11, tp2) && recur(tp12, tp2)
517506
|| containsAnd(tp1)
@@ -534,7 +523,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
534523
// is marked so that it converts all soft unions in its lower bound to hard unions
535524
// before it is instantiated. The reason is that the variable's instance type will
536525
// be a supertype of (decomposed and reconstituted) `tp1`.
537-
hardenTypeVars(tp2)
526+
constraint = constraint.hardenTypeVars(tp2)
538527

539528
res
540529

@@ -2375,7 +2364,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
23752364
case Atoms.Range(lo2, hi2) =>
23762365
if hi1.subsetOf(lo2) then return tp2
23772366
if hi2.subsetOf(lo1) then return tp1
2378-
if (hi1 & hi2).isEmpty then return orType(tp1, tp2)
2367+
if (hi1 & hi2).isEmpty then return orType(tp1, tp2, isSoft = isSoft)
23792368
case none =>
23802369
case none =>
23812370
val t1 = mergeIfSuper(tp1, tp2, canConstrain)

tests/pos/i18626.min1.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
sealed trait Animal
2+
object Cat extends Animal
3+
object Dog extends Animal
4+
5+
type Mammal = Cat.type | Dog.type
6+
7+
class Test:
8+
def t1 =
9+
val mammals: List[Mammal] = ???
10+
val result = mammals.head
11+
val mammal: Mammal = result // was: Type Mismatch Error:
12+
// Found: (result : Animal)
13+
// Required: Mammal
14+
()

tests/pos/i18626.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
trait Random[F1[_]]:
2+
def element[T1](list: Seq[T1]): F1[T1] = ???
3+
4+
trait Monad[F2[_]]:
5+
def map[A1, B1](fa: F2[A1])(f: A1 => B1): F2[B1]
6+
7+
object Monad:
8+
extension [F3[_]: Monad, A3](fa: F3[A3])
9+
def map[B3](f: A3 => B3): F3[B3] = ???
10+
11+
sealed trait Animal
12+
object Cat extends Animal
13+
object Dog extends Animal
14+
15+
type Mammal = Cat.type | Dog.type
16+
val mammals: List[Mammal] = ???
17+
18+
class Work[F4[_]](random: Random[F4])(using mf: Monad[F4]):
19+
def result1: F4[Mammal] =
20+
mf.map(fa = random.element(mammals))(a => a)
21+
22+
def result2: F4[Mammal] = Monad.map(random.element(mammals))(a => a)
23+
24+
import Monad.*
25+
26+
def result3: F4[Mammal] = random
27+
.element(mammals)
28+
.map { a =>
29+
a // was: Type Mismatch Error:
30+
// Found: (a : Animal)
31+
// Required: Cat.type | Dog.type
32+
}

0 commit comments

Comments
 (0)