Skip to content

Change pretyping of argument closure #2018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,21 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
case _ => false
}

def isFunctionWithUnknownParamType(tree: Tree) = tree match {
/** Is this a function literal (either a lambda or a case-block) with an
* unknown parameter type?
*
* @param idx If a non-negative value is given, only the specific parameter
* at that index is tested, otherwise all parameters are tested.
*/
def isFunctionWithUnknownParamType(tree: Tree, idx: Int = -1): Boolean = tree match {
case Function(args, _) =>
args.exists {
val hasUnknownParamType: Tree => Boolean = {
case ValDef(_, tpt, _) => tpt.isEmpty
case _ => false
}
if (idx >= 0) hasUnknownParamType(args(idx)) else args.exists(hasUnknownParamType)
case Match(EmptyTree, _) =>
true
case _ => false
}

Expand Down
57 changes: 30 additions & 27 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
def typeShape(tree: untpd.Tree): Type = tree match {
case untpd.Function(args, body) =>
defn.FunctionOf(args map Function.const(defn.AnyType), typeShape(body))
case Match(EmptyTree, _) =>
defn.PartialFunctionType.appliedTo(defn.AnyType :: defn.NothingType :: Nil)
case _ =>
defn.NothingType
}
Expand Down Expand Up @@ -1271,7 +1273,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
alts filter (alt => sizeFits(alt, alt.widen))

def narrowByShapes(alts: List[TermRef]): List[TermRef] = {
if (normArgs exists (_.isInstanceOf[untpd.Function]))
if (normArgs exists (untpd.isFunctionWithUnknownParamType(_)))
if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType)
else narrowByTypes(alts, normArgs map typeShape, resultType)
else
Expand Down Expand Up @@ -1351,33 +1353,34 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
case ValDef(_, tpt, _) => tpt.isEmpty
case _ => false
}
arg match {
case arg: untpd.Function if arg.args.exists(isUnknownParamType) =>
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
val formalsForArg: List[Type] = altFormals.map(_.head)
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
// (p_1_1, ..., p_m_1) => r_1
// ...
// (p_1_n, ..., p_m_n) => r_n
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
formalsForArg.map(defn.FunctionOf.unapply)
if (decomposedFormalsForArg.forall(_.isDefined)) {
val formalParamTypessForArg: List[List[Type]] =
decomposedFormalsForArg.map(_.get._1)
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
// Given definitions above, for i = 1,...,m,
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
// If all p_i_k's are the same, assume the type as formal parameter
// type of the i'th parameter of the closure.
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
else WildcardType)
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
overload.println(i"pretype arg $arg with expected type $commonFormal")
pt.typedArg(arg, commonFormal)
}
if (untpd.isFunctionWithUnknownParamType(arg)) {
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
val formalsForArg: List[Type] = altFormals.map(_.head)
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
// (p_1_1, ..., p_m_1) => r_1
// ...
// (p_1_n, ..., p_m_n) => r_n
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
formalsForArg.map(defn.FunctionOf.unapply)
if (decomposedFormalsForArg.forall(_.isDefined)) {
val formalParamTypessForArg: List[List[Type]] =
decomposedFormalsForArg.map(_.get._1)
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
val commonParamTypes = formalParamTypessForArg
.transpose
.zipWithIndex
.map {
case (ps, idx) =>
if (untpd.isFunctionWithUnknownParamType(arg, idx))
ps.reduceLeft(_ | _)
else
WildcardType
}
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
overload.println(i"pretype arg $arg with expected type $commonFormal")
pt.typedArg(arg, commonFormal)
}
case _ =>
}
}
recur(altFormals.map(_.tail), args1)
case _ =>
Expand Down
41 changes: 41 additions & 0 deletions tests/pos/inferOverloaded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
class MySeq[T] {
def map1[U](f: T => U): MySeq[U] = new MySeq[U]
def map2[U](f: T => U): MySeq[U] = new MySeq[U]
}

class MyMap[A, B] extends MySeq[(A, B)] {
def map1[C](f: (A, B) => C): MySeq[C] = new MySeq[C]
def map1[C, D](f: (A, B) => (C, D)): MyMap[C, D] = new MyMap[C, D]
def map1[C, D](f: ((A, B)) => (C, D)): MyMap[C, D] = new MyMap[C, D]

def foo(f: Function2[Int, Int, Int]): Unit = ()
def foo[R](pf: PartialFunction[(A, B), R]): MySeq[R] = new MySeq[R]
}

object Test {
val m = new MyMap[Int, String]

// This one already worked because it is not overloaded:
m.map2 { case (k, v) => k - 1 }

// These already worked because preSelectOverloaded eliminated the non-applicable overload:
m.map1(t => t._1)
m.map1((kInFunction, vInFunction) => kInFunction - 1)
val r1 = m.map1(t => (t._1, 42.0))
val r1t: MyMap[Int, Double] = r1

// These worked because the argument types are known for overload resolution:
m.map1({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int])
m.map2({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int])

// These ones did not work before:
m.map1 { case (k, v) => k }
val r = m.map1 { case (k, v) => (k, k*10) }
val rt: MyMap[Int, Int] = r
m.foo { case (k, v) => k - 1 }

// Used to be ambiguous but overload resolution now favors PartialFunction
def h[R](pf: Function2[Int, String, R]): Unit = ()
def h[R](pf: PartialFunction[(Double, Double), R]): Unit = ()
h { case (a: Double, b: Double) => 42: Int }
}