Skip to content

Commit 73b0e75

Browse files
committed
Type parameter clause inference for lambdas
When a lambda is written without a type parameter clause but the expected type is a polymorphic function type, try to adapt the lambda into a polymorphic lambda by inferring an appropriate type parameter clause. This change broke one example in spire which relied on implicit conversions. The fix has been accepted upstream: typelevel/spire#1247
1 parent 7b39cb6 commit 73b0e75

File tree

6 files changed

+45
-1
lines changed

6 files changed

+45
-1
lines changed

compiler/src/dotty/tools/dotc/config/Feature.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ object Feature:
3333
val pureFunctions = experimental("pureFunctions")
3434
val captureChecking = experimental("captureChecking")
3535
val into = experimental("into")
36+
val typeClauseInference = experimental("typeClauseInference")
3637

3738
val globalOnlyImports: Set[TermName] = Set(pureFunctions, captureChecking)
3839

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,6 +1462,32 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
14621462
def typedFunctionValue(tree: untpd.Function, pt: Type)(using Context): Tree = {
14631463
val untpd.Function(params: List[untpd.ValDef] @unchecked, _) = tree: @unchecked
14641464

1465+
if Feature.enabled(Feature.typeClauseInference) then
1466+
// If the expected type is a polymorphic function type:
1467+
//
1468+
// [S_1, ..., S_m] => (T_1, ..., T_n) => R
1469+
// (where each S_i might have type bounds)
1470+
//
1471+
// and we are typing a lambda of the form:
1472+
//
1473+
// (x_1, ..., x_n) => e
1474+
// (where each x_i might have a type ascription)
1475+
//
1476+
// then continue with an inferred type parameter clause:
1477+
//
1478+
// [S'_1, ..., S'_m] => (x_1, ..., x_n) => e
1479+
// (where each S'_i is fresh and has the bounds of S_i after substituting S_j by S'_j for all j)
1480+
pt match
1481+
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
1482+
if params.lengthCompare(mt.paramNames) == 0 =>
1483+
val tparams = poly.paramNames.lazyZip(poly.paramInfos).map: (name, info) =>
1484+
untpd.TypeDef(UniqueName.fresh(name), untpd.InLambdaTypeTree(isResult = false, (tsyms, vsyms) =>
1485+
info.substParams(poly, tsyms.map(_.typeRef))
1486+
)).withFlags(SyntheticParam)
1487+
.withSpan(tree.span.startPos)
1488+
return typed(untpd.PolyFunction(tparams, tree), pt)
1489+
case _ =>
1490+
14651491
val (isContextual, isDefinedErased) = tree match {
14661492
case tree: untpd.FunctionWithMods => (tree.mods.is(Given), tree.erasedParams)
14671493
case _ => (false, tree.args.map(_ => false))

library/src/scala/runtime/stdLibPatches/language.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ object language:
9898
*/
9999
@compileTimeOnly("`into` can only be used at compile time in import statements")
100100
object into
101+
102+
@compileTimeOnly("`typeClauseInference` can only be used at compile time in import statements")
103+
object typeClauseInference
101104
end experimental
102105

103106
/** The deprecated object contains features that are no longer officially suypported in Scala.

tests/neg/typeClauseInference.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import scala.language.experimental.typeClauseInference
2+
3+
val notInScopeInferred: [T] => T => T = x => (x: T) // error

tests/pos/typeClauseInference.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import language.experimental.typeClauseInference
2+
3+
class Test:
4+
def test: Unit =
5+
val it1: [T] => T => T = x => x
6+
val it2: [T] => (T, Int) => T = (x, y: Int) => x
7+
val it3: [T, S <: List[T]] => (T, S) => List[T] = (x, y) => x :: y
8+
val tuple1: (String, String) = (1, 2.0).map[[_] =>> String](_.toString)
9+
val tuple2: (List[Int], List[Double]) = (1, 2.0).map(List(_))
10+
// Not supported yet, require eta-expansion with a polymorphic expected type
11+
// val tuple3: (List[Int], List[Double]) = (1, 2.0).map(List.apply)

0 commit comments

Comments
 (0)