Skip to content

Fix #7778: infer parameter type for contextual functions #9075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1336,9 +1336,9 @@ object desugar {
Function(param :: Nil, Block(vdefs, body))
}

def makeContextualFunction(formals: List[Type], body: Tree, isErased: Boolean)(implicit ctx: Context): Tree = {
def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(implicit ctx: Context): Function = {
val mods = if (isErased) Given | Erased else Given
val params = makeImplicitParameters(formals.map(TypeTree), mods)
val params = makeImplicitParameters(formals, mods)
FunctionWithMods(params, body, Modifiers(mods))
}

Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
case _ => tree
}

def stripAnnotated(tree: Tree): Tree = tree match {
case Annotated(arg, _) => arg
case _ => tree
}

/** The number of arguments in an application */
def numArgs(tree: Tree): Int = unsplice(tree) match {
case Apply(fn, args) => numArgs(fn) + args.length
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ object ProtoTypes {
*/
private def wildApprox(tp: Type, theMap: WildApproxMap, seen: Set[TypeParamRef], internal: Set[TypeLambda])(using Context): Type = tp match {
case tp: NamedType => // default case, inlined for speed
val isPatternBoundTypeRef = tp.isInstanceOf[TypeRef] && tp.symbol.is(Flags.Case) && !tp.symbol.isClass
val isPatternBoundTypeRef = tp.isInstanceOf[TypeRef] && tp.symbol.isPatternBound
if (isPatternBoundTypeRef) WildcardType(tp.underlying.bounds)
else if (tp.symbol.isStatic || (tp.prefix `eq` NoPrefix)) tp
else tp.derivedSelect(wildApprox(tp.prefix, theMap, seen, internal))
Expand Down
53 changes: 50 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1049,15 +1049,35 @@ class Typer extends Namer
*/
var paramIndex = Map[Name, Int]()

/** If function is of the form
/** Infer parameter type from the body of the function
*
* 1. If function is of the form
*
* (x1, ..., xN) => f(... x1, ..., XN, ...)
*
* where each `xi` occurs exactly once in the argument list of `f` (in
* any order), the type of `f`, otherwise NoType.
*
* 2. If the function is of the form
*
* (using x1, ..., xN) => f
*
* where `f` is a contextual function type of the form `(T1, ..., TN) ?=> T`,
* then `xi` takes the type `Ti`.
*
* Updates `fnBody` and `paramIndex` as a side effect.
* @post: If result exists, `paramIndex` is defined for the name of
* every parameter in `params`.
*/
lazy val calleeType: Type = fnBody match {
lazy val calleeType: Type = untpd.stripAnnotated(fnBody) match {
case ident: untpd.Ident if isContextual =>
val ident1 = typedIdent(ident, WildcardType)
val tp = ident1.tpe.widen
if defn.isContextFunctionType(tp) && params.size == defn.functionArity(tp) then
paramIndex = params.map(_.name).zipWithIndex.toMap
fnBody = untpd.TypedSplice(ident1)
tp.select(nme.apply)
else NoType
case app @ Apply(expr, args) =>
paramIndex = {
for (param <- params; idx <- paramIndices(param, args))
Expand Down Expand Up @@ -2450,7 +2470,34 @@ class Typer extends Namer

protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
val defn.FunctionOf(formals, _, true, _) = pt.dropDependentRefinement
val ifun = desugar.makeContextualFunction(formals, tree, defn.isErasedFunctionType(pt))

// The getter of default parameters may reach here.
// Given the code below
//
// class Foo[A](run: A ?=> Int) {
// def foo[T](f: T ?=> Int = run) = ()
// }
//
// it desugars to
//
// class Foo[A](run: A ?=> Int) {
// def foo$default$1[T] = run
// def foo[T](f: T ?=> Int = run) = ()
// }
//
// The expected type for checking `run` in `foo$default$1` is
//
// <?> ?=> Int
//
// see tests/pos/i7778b.scala

val paramTypes = {
val hasWildcard = formals.exists(_.isInstanceOf[WildcardType])
if hasWildcard then formals.map(_ => untpd.TypeTree())
else formals.map(untpd.TypeTree)
}

val ifun = desugar.makeContextualFunction(paramTypes, tree, defn.isErasedFunctionType(pt))
typr.println(i"make contextual function $tree / $pt ---> $ifun")
typed(ifun, pt)
}
Expand Down
12 changes: 12 additions & 0 deletions tests/pos/i7778.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
object Example extends App {
final case class Foo[A](run: A ?=> Int)
}

object Example2 extends App {
final case class Foo[A, B](run: (A, B) ?=> Int)
}


object Example3 extends App {
final case class Foo[A, B](run: () ?=> Int)
}
4 changes: 4 additions & 0 deletions tests/pos/i7778b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Foo[A](run: A ?=> Int) {
def foo[T](f: T ?=> Int = run) = ()
}