Skip to content

Commit 0e8587c

Browse files
committed
improve type inference for parameters of higher-kinded types
1 parent 29e4b05 commit 0e8587c

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import Phases.{gettersPhase, elimByNamePhase}
88
import StdNames.nme
99
import TypeOps.refineUsingParent
1010
import collection.mutable
11+
import annotation.tailrec
1112
import util.Stats
1213
import config.Config
1314
import config.Feature.migrateTo3
@@ -178,16 +179,37 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
178179
try op finally comparedTypeLambdas = saved
179180

180181
protected def isSubType(tp1: Type, tp2: Type, a: ApproxState): Boolean = {
182+
inline def followAlias[T](inline tp: Type)(inline default: T)(inline f: (TypeProxy, Symbol) => T): T =
183+
tp.stripAnnots.stripTypeVar match
184+
case tp: (AppliedType | TypeRef) => f(tp, tp.typeSymbol)
185+
case _ => default
186+
187+
@tailrec def dealias(tp: Type, syms: Set[Symbol]): Type =
188+
followAlias(tp)(NoType) { (tp, sym) =>
189+
if syms contains sym then tp
190+
else if sym.isAliasType then dealias(tp.superType, syms)
191+
else NoType
192+
}
193+
194+
@tailrec def aliasedSymbols(tp: Type, result: Set[Symbol] = Set.empty): Set[Symbol] =
195+
followAlias(tp)(Set.empty) { (tp, sym) =>
196+
if sym.isAliasType then aliasedSymbols(tp.superType, result + sym)
197+
else if sym.exists && (sym ne AnyClass) then result + sym
198+
else Set.empty
199+
}
200+
201+
val tp1dealiased = dealias(tp1, aliasedSymbols(tp2)) orElse tp1
202+
181203
val savedApprox = approx
182204
val savedLeftRoot = leftRoot
183205
if (a == ApproxState.Fresh) {
184206
this.approx = ApproxState.None
185-
this.leftRoot = tp1
207+
this.leftRoot = tp1dealiased
186208
}
187209
else this.approx = a
188-
try recur(tp1, tp2)
210+
try recur(tp1dealiased, tp2)
189211
catch {
190-
case ex: Throwable => handleRecursive("subtype", i"$tp1 <:< $tp2", ex, weight = 2)
212+
case ex: Throwable => handleRecursive("subtype", i"$tp1dealiased <:< $tp2", ex, weight = 2)
191213
}
192214
finally {
193215
this.approx = savedApprox
@@ -1009,7 +1031,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
10091031
def isMatchingApply(tp1: Type): Boolean = tp1.widen match {
10101032
case tp1 @ AppliedType(tycon1, args1) =>
10111033
// We intentionally do not automatically dealias `tycon1` or `tycon2` here.
1012-
// `TypeApplications#appliedTo` already takes care of dealiasing type
1034+
// `isSubType` already takes care of dealiasing type
10131035
// constructors when this can be done without affecting type
10141036
// inference, doing it here would not only prevent code from compiling
10151037
// but could also result in the wrong thing being inferred later, for example

tests/run/hk-alias-unification.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,24 @@ trait ErasedFoo[FT]
1010
object Test {
1111
type Foo[F[_], T] = ErasedFoo[F[T]]
1212
type Foo2[F[_], T] = Foo[F, T]
13+
type Foo3[T, F[_]] = Foo[F, T]
1314

1415
def mkFoo[F[_], T](implicit gen: Bla[T]): Foo[F, T] = new Foo[F, T] {}
1516
def mkFoo2[F[_], T](implicit gen: Bla[T]): Foo2[F, T] = new Foo2[F, T] {}
17+
def mkFoo3[F[_], T](implicit gen: Bla[T]): Foo3[T, F] = new Foo3[T, F] {}
1618

1719
def main(args: Array[String]): Unit = {
18-
val a: Foo[[X] =>> (X, String), Int] = mkFoo
19-
val b: Foo2[[X] =>> (X, String), Int] = mkFoo
20-
val c: Foo[[X] =>> (X, String), Int] = mkFoo2
20+
val a1: Foo[[X] =>> (X, String), Int] = mkFoo
21+
val b1: Foo2[[X] =>> (X, String), Int] = mkFoo
22+
val c1: Foo3[Int, [X] =>> (X, String)] = mkFoo
23+
24+
val a2: Foo[[X] =>> (X, String), Int] = mkFoo2
25+
val b2: Foo2[[X] =>> (X, String), Int] = mkFoo2
26+
val c2: Foo3[Int, [X] =>> (X, String)] = mkFoo2
27+
28+
val a3: Foo[[X] =>> (X, String), Int] = mkFoo3
29+
val b3: Foo2[[X] =>> (X, String), Int] = mkFoo3
30+
val c3: Foo3[Int, [X] =>> (X, String)] = mkFoo3
2131
}
2232
}
2333

0 commit comments

Comments
 (0)