Skip to content

Commit e8d1b19

Browse files
committed
Strengthen overloading resolution to deal with extension methods
If resolving with the first parameter list yields a draw, and there are further argument lists following the first one, take these into account as well in order to arrive at a single best alternative. This is particularly useful for extension methods, where we might well have several overloaded extension methods with the same first parameter list.
1 parent 6dbf71d commit e8d1b19

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,16 +1632,50 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
16321632
}
16331633
else compat
16341634
}
1635+
1636+
/** For each candidate `C`, a proxy termref paired with `C`.
1637+
* The proxy termref has as symbol a copy of the original candidate symbol,
1638+
* with an info that strips the first value parameter list away.
1639+
* @param argTypes The types of the arguments of the FunProto `pt`.
1640+
*/
1641+
def advanceCandidates(argTypes: List[Type]): List[(TermRef, TermRef)] = {
1642+
def strippedType(tp: Type): Type = tp match {
1643+
case tp: PolyType =>
1644+
val rt = strippedType(tp.resultType)
1645+
if (rt.exists) tp.derivedLambdaType(resType = rt) else rt
1646+
case tp: MethodType =>
1647+
tp.instantiate(argTypes)
1648+
case _ =>
1649+
NoType
1650+
}
1651+
def cloneCandidate(cand: TermRef): List[(TermRef, TermRef)] = {
1652+
val strippedInfo = strippedType(cand.widen)
1653+
if (strippedInfo.exists) {
1654+
val sym = cand.symbol.asTerm.copy(info = strippedInfo)
1655+
(TermRef(cand.prefix, sym), cand) :: Nil
1656+
}
1657+
else Nil
1658+
}
1659+
overload.println(i"look at more params: ${candidates.head.symbol}: ${candidates.map(_.widen)}%, % with $pt, [$targs%, %]")
1660+
candidates.flatMap(cloneCandidate)
1661+
}
1662+
16351663
val found = narrowMostSpecific(candidates)
16361664
if (found.length <= 1) found
1637-
else {
1638-
val noDefaults = alts.filter(!_.symbol.hasDefaultParams)
1639-
if (noDefaults.length == 1) noDefaults // return unique alternative without default parameters if it exists
1640-
else {
1641-
val deepPt = pt.deepenProto
1642-
if (deepPt ne pt) resolveOverloaded(alts, deepPt, targs)
1643-
else alts
1644-
}
1665+
else pt match {
1666+
case pt @ FunProto(_, resType: FunProto) =>
1667+
// try to narrow further with snd argument list
1668+
val advanced = advanceCandidates(pt.typedArgs.tpes)
1669+
resolveOverloaded(advanced.map(_._1), resType, Nil) // resolve with candidates where first params are stripped
1670+
.map(advanced.toMap) // map surviving result(s) back to original candidates
1671+
case _ =>
1672+
val noDefaults = alts.filter(!_.symbol.hasDefaultParams)
1673+
if (noDefaults.length == 1) noDefaults // return unique alternative without default parameters if it exists
1674+
else {
1675+
val deepPt = pt.deepenProto
1676+
if (deepPt ne pt) resolveOverloaded(alts, deepPt, targs)
1677+
else alts
1678+
}
16451679
}
16461680
}
16471681

tests/run/extmethod-overload.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
object Test extends App {
2+
// warmup
3+
def f(x: Int)(y: Int) = y
4+
def f(x: Int)(y: String) = y.length
5+
assert(f(1)(2) == 2)
6+
assert(f(1)("two") == 3)
7+
8+
def g[T](x: T)(y: Int) = y
9+
def g[T](x: T)(y: String) = y.length
10+
assert(g[Int](1)(2) == 2)
11+
assert(g[Int](1)("two") == 3)
12+
assert(g(1)(2) == 2)
13+
assert(g(1)("two") == 3)
14+
15+
def h[T](x: T)(y: T)(z: Int) = z
16+
def h[T](x: T)(y: T)(z: String) = z.length
17+
assert(h[Int](1)(1)(2) == 2)
18+
assert(h[Int](1)(1)("two") == 3)
19+
assert(h(1)(1)(2) == 2)
20+
assert(h(1)(1)("two") == 3)
21+
22+
implied Foo {
23+
def (x: Int) |+| (y: Int) = x + y
24+
def (x: Int) |+| (y: String) = x + y.length
25+
26+
def (xs: List[T]) +++ [T] (ys: List[T]): List[T] = xs ++ ys ++ ys
27+
def (xs: List[T]) +++ [T] (ys: Iterator[T]): List[T] = xs ++ ys ++ ys
28+
}
29+
30+
assert((1 |+| 2) == 3)
31+
assert((1 |+| "2") == 2)
32+
33+
val xs = List(1, 2)
34+
assert((xs +++ xs).length == 6)
35+
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
36+
}

0 commit comments

Comments
 (0)