Skip to content

Type parameter clause inference for lambdas (and method references via eta-expansion) #18169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object Feature:
val pureFunctions = experimental("pureFunctions")
val captureChecking = experimental("captureChecking")
val into = experimental("into")
val typeClauseInference = experimental("typeClauseInference")

val globalOnlyImports: Set[TermName] = Set(pureFunctions, captureChecking)

Expand Down
49 changes: 49 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,32 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedFunctionValue(tree: untpd.Function, pt: Type)(using Context): Tree = {
val untpd.Function(params: List[untpd.ValDef] @unchecked, _) = tree: @unchecked

if Feature.enabled(Feature.typeClauseInference) then
// If the expected type is a polymorphic function type:
//
// [S_1, ..., S_m] => (T_1, ..., T_n) => R
// (where each S_i might have type bounds)
//
// and we are typing a lambda of the form:
//
// (x_1, ..., x_n) => e
// (where each x_i might have a type ascription)
//
// then continue with an inferred type parameter clause:
//
// [S'_1, ..., S'_m] => (x_1, ..., x_n) => e
// (where each S'_i is fresh and has the bounds of S_i after substituting S_j by S'_j for all j)
pt match
case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType))
if params.lengthCompare(mt.paramNames) == 0 =>
val tparams = poly.paramNames.lazyZip(poly.paramInfos).map: (name, info) =>
untpd.TypeDef(UniqueName.fresh(name), untpd.InLambdaTypeTree(isResult = false, (tsyms, vsyms) =>
info.substParams(poly, tsyms.map(_.typeRef))
)).withFlags(SyntheticParam)
.withSpan(tree.span.startPos)
return typed(untpd.PolyFunction(tparams, tree), pt)
case _ =>

val (isContextual, isDefinedErased) = tree match {
case tree: untpd.FunctionWithMods => (tree.mods.is(Given), tree.erasedParams)
case _ => (false, tree.args.map(_ => false))
Expand Down Expand Up @@ -4318,6 +4344,29 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if isApplyProxy(tree) then newExpr
else if pt.isInstanceOf[PolyProto] then tree
else
if Feature.enabled(Feature.typeClauseInference) then
// If `tree` is a polymorphic method reference and the expected
// type is a polymorphic function, perform a monomorphic
// eta-expansion of the method reference.
// For example, this means that
//
// (1, 2.0).map(Option.apply)
//
// will expand to:
//
// (1, 2.0).map(x => Option.apply(x))
//
// A type parameter clause for the lambda will subsequently be
// inferred (from its expected type) in typedFunctionValue.
pt match
case defn.PolyFunctionOf(_: PolyType) =>
poly.resultType match
case mt: MethodType =>
val expanded = etaExpand(tree, mt, mt.paramInfos.length)
return simplify(typed(expanded, pt), pt, locked)
case _ =>
case _ =>
end if
var typeArgs = tree match
case Select(qual, nme.CONSTRUCTOR) => qual.tpe.widenDealias.argTypesLo.map(TypeTree(_))
case _ => Nil
Expand Down
3 changes: 3 additions & 0 deletions library/src/scala/runtime/stdLibPatches/language.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ object language:
*/
@compileTimeOnly("`into` can only be used at compile time in import statements")
object into

@compileTimeOnly("`typeClauseInference` can only be used at compile time in import statements")
object typeClauseInference
end experimental

/** The deprecated object contains features that are no longer officially suypported in Scala.
Expand Down
7 changes: 7 additions & 0 deletions tests/neg/typeClauseInference.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import scala.language.experimental.typeClauseInference

val notInScopeInferred: [T] => T => T = x => (x: T) // error

def bar[A]: A => A = x => x
val barf1: [T] => T => T = bar(_) // ok
val barf2: [T] => T => T = bar // error, unlike in the original SIP-49.
20 changes: 20 additions & 0 deletions tests/pos/typeClauseInference.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import language.experimental.typeClauseInference

class Test:
def test: Unit =
val it1: [T] => T => T = x => x
val it2: [T] => (T, Int) => T = (x, y: Int) => x
val it3: [T, S <: List[T]] => (T, S) => List[T] = (x, y) => x :: y
val tuple1: (String, String) = (1, 2.0).map[[_] =>> String](_.toString)
val tuple2: (List[Int], List[Double]) = (1, 2.0).map(List(_))

// Eta-expansion
val e1: [T] => T => Option[T] = Option.apply
val tuple3: (Option[Int], Option[Double]) = (1, 2.0).map(Option.apply)

// Eta-expansion that wouldn't work with the original SIP-49
def pair[S, T](x: S, y: T): (S, T) = (x, y)
val f5: [T] => (Int, T) => (Int, T) = pair
val f6: [T] => (T, Int) => (T, Int) = pair
def id[T](x: T): T = x
val f7: [S] => List[S] => List[S] = id