Skip to content

Add overloading support for case-closures #2015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
case ValDef(_, tpt, _) => tpt.isEmpty
case _ => false
}
case Match(EmptyTree, _) =>
true
case _ => false
}

Expand Down
54 changes: 27 additions & 27 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
def typeShape(tree: untpd.Tree): Type = tree match {
case untpd.Function(args, body) =>
defn.FunctionOf(args map Function.const(defn.AnyType), typeShape(body))
case Match(EmptyTree, _) =>
defn.PartialFunctionType.appliedTo(defn.AnyType :: defn.NothingType :: Nil)
case _ =>
defn.NothingType
}
Expand Down Expand Up @@ -1271,7 +1273,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
alts filter (alt => sizeFits(alt, alt.widen))

def narrowByShapes(alts: List[TermRef]): List[TermRef] = {
if (normArgs exists (_.isInstanceOf[untpd.Function]))
if (normArgs exists untpd.isFunctionWithUnknownParamType)
if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType)
else narrowByTypes(alts, normArgs map typeShape, resultType)
else
Expand Down Expand Up @@ -1351,33 +1353,31 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
case ValDef(_, tpt, _) => tpt.isEmpty
case _ => false
}
arg match {
case arg: untpd.Function if arg.args.exists(isUnknownParamType) =>
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
val formalsForArg: List[Type] = altFormals.map(_.head)
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
// (p_1_1, ..., p_m_1) => r_1
// ...
// (p_1_n, ..., p_m_n) => r_n
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
formalsForArg.map(defn.FunctionOf.unapply)
if (decomposedFormalsForArg.forall(_.isDefined)) {
val formalParamTypessForArg: List[List[Type]] =
decomposedFormalsForArg.map(_.get._1)
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
// Given definitions above, for i = 1,...,m,
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
// If all p_i_k's are the same, assume the type as formal parameter
// type of the i'th parameter of the closure.
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
else WildcardType)
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
overload.println(i"pretype arg $arg with expected type $commonFormal")
pt.typedArg(arg, commonFormal)
}
if (untpd.isFunctionWithUnknownParamType(arg)) {
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
val formalsForArg: List[Type] = altFormals.map(_.head)
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
// (p_1_1, ..., p_m_1) => r_1
// ...
// (p_1_n, ..., p_m_n) => r_n
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
formalsForArg.map(defn.FunctionOf.unapply)
if (decomposedFormalsForArg.forall(_.isDefined)) {
val formalParamTypessForArg: List[List[Type]] =
decomposedFormalsForArg.map(_.get._1)
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
// Given definitions above, for i = 1,...,m,
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
// If all p_i_k's are the same, assume the type as formal parameter
// type of the i'th parameter of the closure.
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
else WildcardType)
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
overload.println(i"pretype arg $arg with expected type $commonFormal")
pt.typedArg(arg, commonFormal)
}
case _ =>
}
}
recur(altFormals.map(_.tail), args1)
case _ =>
Expand Down
41 changes: 41 additions & 0 deletions tests/pos/inferOverloaded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
class MySeq[T] {
def map1[U](f: T => U): MySeq[U] = new MySeq[U]
def map2[U](f: T => U): MySeq[U] = new MySeq[U]
}

class MyMap[A, B] extends MySeq[(A, B)] {
def map1[C](f: (A, B) => C): MySeq[C] = new MySeq[C]
def map1[C, D](f: (A, B) => (C, D)): MyMap[C, D] = new MyMap[C, D]
def map1[C, D](f: ((A, B)) => (C, D)): MyMap[C, D] = new MyMap[C, D]

def foo(f: Function2[Int, Int, Int]): Unit = ()
def foo[R](pf: PartialFunction[(A, B), R]): MySeq[R] = new MySeq[R]
}

object Test {
val m = new MyMap[Int, String]

// This one already worked because it is not overloaded:
m.map2 { case (k, v) => k - 1 }

// These already worked because preSelectOverloaded eliminated the non-applicable overload:
m.map1(t => t._1)
m.map1((kInFunction, vInFunction) => kInFunction - 1)
val r1 = m.map1(t => (t._1, 42.0))
val r1t: MyMap[Int, Double] = r1

// These worked because the argument types are known for overload resolution:
m.map1({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int])
m.map2({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int])

// These ones did not work before:
m.map1 { case (k, v) => k }
val r = m.map1 { case (k, v) => (k, k*10) }
val rt: MyMap[Int, Int] = r
m.foo { case (k, v) => k - 1 }

// Used to be ambiguous but overload resolution now favors PartialFunction
def h[R](pf: Function2[Int, String, R]): Unit = ()
def h[R](pf: PartialFunction[(Double, Double), R]): Unit = ()
h { case (a: Double, b: Double) => 42: Int }
}