Skip to content

Commit 8816aaf

Browse files
committed
improve type inference for parameters of higher-kinded types
1 parent 9101116 commit 8816aaf

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import Phases.{gettersPhase, elimByNamePhase}
88
import StdNames.nme
99
import TypeOps.refineUsingParent
1010
import collection.mutable
11+
import annotation.tailrec
1112
import util.Stats
1213
import util.NoSourcePosition
1314
import config.Config
@@ -182,16 +183,37 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
182183
try op finally comparedTypeLambdas = saved
183184

184185
protected def isSubType(tp1: Type, tp2: Type, a: ApproxState): Boolean = {
186+
inline def followAlias[T](inline tp: Type)(inline default: T)(inline f: (TypeProxy, Symbol) => T): T =
187+
tp.stripAnnots.stripTypeVar match
188+
case tp: (AppliedType | TypeRef) => f(tp, tp.typeSymbol)
189+
case _ => default
190+
191+
@tailrec def dealias(tp: Type, syms: Set[Symbol]): Type =
192+
followAlias(tp)(NoType) { (tp, sym) =>
193+
if syms contains sym then tp
194+
else if sym.isAliasType then dealias(tp.superType, syms)
195+
else NoType
196+
}
197+
198+
@tailrec def aliasedSymbols(tp: Type, result: Set[Symbol] = Set.empty): Set[Symbol] =
199+
followAlias(tp)(result) { (tp, sym) =>
200+
if sym.isAliasType then aliasedSymbols(tp.superType, result + sym)
201+
else if sym.exists && (sym ne AnyClass) then result + sym
202+
else result
203+
}
204+
205+
val tp1dealiased = dealias(tp1, aliasedSymbols(tp2)) orElse tp1
206+
185207
val savedApprox = approx
186208
val savedLeftRoot = leftRoot
187209
if (a == ApproxState.Fresh) {
188210
this.approx = ApproxState.None
189-
this.leftRoot = tp1
211+
this.leftRoot = tp1dealiased
190212
}
191213
else this.approx = a
192-
try recur(tp1, tp2)
214+
try recur(tp1dealiased, tp2)
193215
catch {
194-
case ex: Throwable => handleRecursive("subtype", i"$tp1 <:< $tp2", ex, weight = 2)
216+
case ex: Throwable => handleRecursive("subtype", i"$tp1dealiased <:< $tp2", ex, weight = 2)
195217
}
196218
finally {
197219
this.approx = savedApprox
@@ -1026,7 +1048,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
10261048
def isMatchingApply(tp1: Type): Boolean = tp1.widen match {
10271049
case tp1 @ AppliedType(tycon1, args1) =>
10281050
// We intentionally do not automatically dealias `tycon1` or `tycon2` here.
1029-
// `TypeApplications#appliedTo` already takes care of dealiasing type
1051+
// `isSubType` already takes care of dealiasing type
10301052
// constructors when this can be done without affecting type
10311053
// inference, doing it here would not only prevent code from compiling
10321054
// but could also result in the wrong thing being inferred later, for example

tests/run/hk-alias-unification.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,24 @@ trait ErasedFoo[FT]
1010
object Test {
1111
type Foo[F[_], T] = ErasedFoo[F[T]]
1212
type Foo2[F[_], T] = Foo[F, T]
13+
type Foo3[T, F[_]] = Foo[F, T]
1314

1415
def mkFoo[F[_], T](implicit gen: Bla[T]): Foo[F, T] = new Foo[F, T] {}
1516
def mkFoo2[F[_], T](implicit gen: Bla[T]): Foo2[F, T] = new Foo2[F, T] {}
17+
def mkFoo3[F[_], T](implicit gen: Bla[T]): Foo3[T, F] = new Foo3[T, F] {}
1618

1719
def main(args: Array[String]): Unit = {
18-
val a: Foo[[X] =>> (X, String), Int] = mkFoo
19-
val b: Foo2[[X] =>> (X, String), Int] = mkFoo
20-
val c: Foo[[X] =>> (X, String), Int] = mkFoo2
20+
val a1: Foo[[X] =>> (X, String), Int] = mkFoo
21+
val b1: Foo2[[X] =>> (X, String), Int] = mkFoo
22+
val c1: Foo3[Int, [X] =>> (X, String)] = mkFoo
23+
24+
val a2: Foo[[X] =>> (X, String), Int] = mkFoo2
25+
val b2: Foo2[[X] =>> (X, String), Int] = mkFoo2
26+
val c2: Foo3[Int, [X] =>> (X, String)] = mkFoo2
27+
28+
val a3: Foo[[X] =>> (X, String), Int] = mkFoo3
29+
val b3: Foo2[[X] =>> (X, String), Int] = mkFoo3
30+
val c3: Foo3[Int, [X] =>> (X, String)] = mkFoo3
2131
}
2232
}
2333

0 commit comments

Comments
 (0)