From 5b2d7bff8e50dc3a8196e9b84ff75e9e66067bbf Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Sun, 15 Nov 2020 18:55:50 +0100 Subject: [PATCH 1/2] Better expected type for arguments of overloaded methods `pretypeArgs` allows arguments of overloaded methods to be typed with a more precise expected type when the formal parameter types of each overload are all compatible function types, but previously this logic only kicked in for arguments which were syntactically known to be functions themselves, which means that it worked when the argument was `foo(_)` or `x => foo(x)`, but not when it was just `foo`. This commit simply removes this restriction. Fixes #10325. --- .../dotty/tools/dotc/typer/Applications.scala | 64 +++++++++---------- tests/pos/i10325.scala | 19 ++++++ 2 files changed, 51 insertions(+), 32 deletions(-) create mode 100644 tests/pos/i10325.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 639e04e25415..bb94fce2f93e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -2026,38 +2026,38 @@ trait Applications extends Compatibility { private def pretypeArgs(alts: List[TermRef], pt: FunProto)(using Context): Unit = { def recur(altFormals: List[List[Type]], args: List[untpd.Tree]): Unit = args match { case arg :: args1 if !altFormals.exists(_.isEmpty) => - untpd.functionWithUnknownParamType(arg) match { - case Some(fn) => - def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head)) - val formalsForArg: List[Type] = altFormals.map(_.head) - def argTypesOfFormal(formal: Type): List[Type] = - formal match { - case defn.FunctionOf(args, result, isImplicit, isErased) => args - case defn.PartialFunctionOf(arg, result) => arg :: Nil - case _ => Nil - } - val formalParamTypessForArg: List[List[Type]] = - formalsForArg.map(argTypesOfFormal) - if (formalParamTypessForArg.forall(_.nonEmpty) && - isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) { - val commonParamTypes = formalParamTypessForArg.transpose.map(ps => - // Given definitions above, for i = 1,...,m, - // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column - // If all p_i_k's are the same, assume the type as formal parameter - // type of the i'th parameter of the closure. - if (isUniform(ps)(_ frozen_=:= _)) ps.head - else WildcardType) - def isPartial = // we should generate a partial function for the arg - fn.isInstanceOf[untpd.Match] && - formalsForArg.exists(_.isRef(defn.PartialFunctionClass)) - val commonFormal = - if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType) - else defn.FunctionOf(commonParamTypes, WildcardType) - overload.println(i"pretype arg $arg with expected type $commonFormal") - if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom))) - withMode(Mode.ImplicitsEnabled)(pt.typedArg(arg, commonFormal)) - } - case None => + def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head)) + val formalsForArg: List[Type] = altFormals.map(_.head) + def argTypesOfFormal(formal: Type): List[Type] = + formal match { + case defn.FunctionOf(args, result, isImplicit, isErased) => args + case defn.PartialFunctionOf(arg, result) => arg :: Nil + case _ => Nil + } + val formalParamTypessForArg: List[List[Type]] = + formalsForArg.map(argTypesOfFormal) + if (formalParamTypessForArg.forall(_.nonEmpty) && + isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) { + val commonParamTypes = formalParamTypessForArg.transpose.map(ps => + // Given definitions above, for i = 1,...,m, + // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column + // If all p_i_k's are the same, assume the type as formal parameter + // type of the i'th parameter of the closure. + if (isUniform(ps)(_ frozen_=:= _)) ps.head + else WildcardType) + /** Should we generate a partial function for the arg ? */ + def isPartial = untpd.functionWithUnknownParamType(arg) match + case Some(fn) => + fn.isInstanceOf[untpd.Match] && + formalsForArg.exists(_.isRef(defn.PartialFunctionClass)) + case None => + false + val commonFormal = + if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType) + else defn.FunctionOf(commonParamTypes, WildcardType) + overload.println(i"pretype arg $arg with expected type $commonFormal") + if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom))) + withMode(Mode.ImplicitsEnabled)(pt.typedArg(arg, commonFormal)) } recur(altFormals.map(_.tail), args1) case _ => diff --git a/tests/pos/i10325.scala b/tests/pos/i10325.scala new file mode 100644 index 000000000000..7668612cc2c8 --- /dev/null +++ b/tests/pos/i10325.scala @@ -0,0 +1,19 @@ +object Test { + def nullToNone[K, V](tuple: (K, V)): (K, Option[V]) = { + val (k, v) = tuple + (k, Option(v)) + } + + def test: Unit = { + val scalaMap: Map[String, String] = Map() + + val a = scalaMap.map(nullToNone) + val a1: Map[String, Option[String]] = a + + val b = scalaMap.map(nullToNone(_)) + val b1: Map[String, Option[String]] = b + + val c = scalaMap.map(x => nullToNone(x)) + val c1: Map[String, Option[String]] = c + } +} From 5e6bfa2f7dd8eb3ab9d36fc9dae12b74e776b758 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 17 Nov 2020 10:16:18 +0100 Subject: [PATCH 2/2] Polish pattern match --- compiler/src/dotty/tools/dotc/typer/Applications.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index bb94fce2f93e..aee490dd4792 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -2047,10 +2047,9 @@ trait Applications extends Compatibility { else WildcardType) /** Should we generate a partial function for the arg ? */ def isPartial = untpd.functionWithUnknownParamType(arg) match - case Some(fn) => - fn.isInstanceOf[untpd.Match] && + case Some(_: untpd.Match) => formalsForArg.exists(_.isRef(defn.PartialFunctionClass)) - case None => + case _ => false val commonFormal = if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)