From ddd986b07dcf6ce7eca408a53d689b8a886ada7d Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Mon, 27 Nov 2017 19:13:35 +0100 Subject: [PATCH] Better function return type inference If the function prototype is a type variable, its upper bound might contain useful information for inferring the return type of the function. Before this PR, the added testcases failed because `() => 4` was typed as `() => Int` and `() => new Inv` as `() => Inv[Nothing]`, even though the expected types of the functions give enough information to correctly infer them. --- compiler/src/dotty/tools/dotc/typer/Typer.scala | 4 +++- tests/pos/functions1.scala | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index ee824534d6d0..3c256edc513e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -697,7 +697,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit case _: WildcardType => untpd.TypeTree() case _ => untpd.TypeTree(tp) } - pt match { + pt.stripTypeVar match { case _ if defn.isNonDepFunctionType(pt) => // if expected parameter type(s) are wildcards, approximate from below. // if expected result type is a wildcard, approximate from above. @@ -711,6 +711,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))) else typeTree(restpe)) + case tp: TypeParamRef => + decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity) case _ => (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) } diff --git a/tests/pos/functions1.scala b/tests/pos/functions1.scala index 04b90d80e442..285daed6f79d 100644 --- a/tests/pos/functions1.scala +++ b/tests/pos/functions1.scala @@ -33,4 +33,11 @@ object Functions { val z: Spore[String, String] = x => x + x val z2: Spore2[String, String] = x => x + x } + + object retType { + val a: List[() => 4] = List(() => 4) + + class Inv[T] + val b: List[() => Inv[Int]] = List(() => new Inv) + } }