Skip to content

Commit b2ce6b4

Browse files
authored
Merge pull request #9075 from dotty-staging/fix-7778
Fix #7778: infer parameter type for contextual functions
2 parents c9a4196 + c084742 commit b2ce6b4

File tree

6 files changed

+74
-6
lines changed

6 files changed

+74
-6
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,9 +1336,9 @@ object desugar {
13361336
Function(param :: Nil, Block(vdefs, body))
13371337
}
13381338

1339-
def makeContextualFunction(formals: List[Type], body: Tree, isErased: Boolean)(implicit ctx: Context): Tree = {
1339+
def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(implicit ctx: Context): Function = {
13401340
val mods = if (isErased) Given | Erased else Given
1341-
val params = makeImplicitParameters(formals.map(TypeTree), mods)
1341+
val params = makeImplicitParameters(formals, mods)
13421342
FunctionWithMods(params, body, Modifiers(mods))
13431343
}
13441344

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
9898
case _ => tree
9999
}
100100

101+
def stripAnnotated(tree: Tree): Tree = tree match {
102+
case Annotated(arg, _) => arg
103+
case _ => tree
104+
}
105+
101106
/** The number of arguments in an application */
102107
def numArgs(tree: Tree): Int = unsplice(tree) match {
103108
case Apply(fn, args) => numArgs(fn) + args.length

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ object ProtoTypes {
598598
*/
599599
private def wildApprox(tp: Type, theMap: WildApproxMap, seen: Set[TypeParamRef], internal: Set[TypeLambda])(using Context): Type = tp match {
600600
case tp: NamedType => // default case, inlined for speed
601-
val isPatternBoundTypeRef = tp.isInstanceOf[TypeRef] && tp.symbol.is(Flags.Case) && !tp.symbol.isClass
601+
val isPatternBoundTypeRef = tp.isInstanceOf[TypeRef] && tp.symbol.isPatternBound
602602
if (isPatternBoundTypeRef) WildcardType(tp.underlying.bounds)
603603
else if (tp.symbol.isStatic || (tp.prefix `eq` NoPrefix)) tp
604604
else tp.derivedSelect(wildApprox(tp.prefix, theMap, seen, internal))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,15 +1049,35 @@ class Typer extends Namer
10491049
*/
10501050
var paramIndex = Map[Name, Int]()
10511051

1052-
/** If function is of the form
1052+
/** Infer parameter type from the body of the function
1053+
*
1054+
* 1. If function is of the form
1055+
*
10531056
* (x1, ..., xN) => f(... x1, ..., XN, ...)
1057+
*
10541058
* where each `xi` occurs exactly once in the argument list of `f` (in
10551059
* any order), the type of `f`, otherwise NoType.
1060+
*
1061+
* 2. If the function is of the form
1062+
*
1063+
* (using x1, ..., xN) => f
1064+
*
1065+
* where `f` is a contextual function type of the form `(T1, ..., TN) ?=> T`,
1066+
* then `xi` takes the type `Ti`.
1067+
*
10561068
* Updates `fnBody` and `paramIndex` as a side effect.
10571069
* @post: If result exists, `paramIndex` is defined for the name of
10581070
* every parameter in `params`.
10591071
*/
1060-
lazy val calleeType: Type = fnBody match {
1072+
lazy val calleeType: Type = untpd.stripAnnotated(fnBody) match {
1073+
case ident: untpd.Ident if isContextual =>
1074+
val ident1 = typedIdent(ident, WildcardType)
1075+
val tp = ident1.tpe.widen
1076+
if defn.isContextFunctionType(tp) && params.size == defn.functionArity(tp) then
1077+
paramIndex = params.map(_.name).zipWithIndex.toMap
1078+
fnBody = untpd.TypedSplice(ident1)
1079+
tp.select(nme.apply)
1080+
else NoType
10611081
case app @ Apply(expr, args) =>
10621082
paramIndex = {
10631083
for (param <- params; idx <- paramIndices(param, args))
@@ -2450,7 +2470,34 @@ class Typer extends Namer
24502470

24512471
protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
24522472
val defn.FunctionOf(formals, _, true, _) = pt.dropDependentRefinement
2453-
val ifun = desugar.makeContextualFunction(formals, tree, defn.isErasedFunctionType(pt))
2473+
2474+
// The getter of default parameters may reach here.
2475+
// Given the code below
2476+
//
2477+
// class Foo[A](run: A ?=> Int) {
2478+
// def foo[T](f: T ?=> Int = run) = ()
2479+
// }
2480+
//
2481+
// it desugars to
2482+
//
2483+
// class Foo[A](run: A ?=> Int) {
2484+
// def foo$default$1[T] = run
2485+
// def foo[T](f: T ?=> Int = run) = ()
2486+
// }
2487+
//
2488+
// The expected type for checking `run` in `foo$default$1` is
2489+
//
2490+
// <?> ?=> Int
2491+
//
2492+
// see tests/pos/i7778b.scala
2493+
2494+
val paramTypes = {
2495+
val hasWildcard = formals.exists(_.isInstanceOf[WildcardType])
2496+
if hasWildcard then formals.map(_ => untpd.TypeTree())
2497+
else formals.map(untpd.TypeTree)
2498+
}
2499+
2500+
val ifun = desugar.makeContextualFunction(paramTypes, tree, defn.isErasedFunctionType(pt))
24542501
typr.println(i"make contextual function $tree / $pt ---> $ifun")
24552502
typed(ifun, pt)
24562503
}

tests/pos/i7778.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
object Example extends App {
2+
final case class Foo[A](run: A ?=> Int)
3+
}
4+
5+
object Example2 extends App {
6+
final case class Foo[A, B](run: (A, B) ?=> Int)
7+
}
8+
9+
10+
object Example3 extends App {
11+
final case class Foo[A, B](run: () ?=> Int)
12+
}

tests/pos/i7778b.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class Foo[A](run: A ?=> Int) {
2+
def foo[T](f: T ?=> Int = run) = ()
3+
}
4+

0 commit comments

Comments
 (0)