diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index c2517145c935..24be0530ac5b 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1336,9 +1336,9 @@ object desugar { Function(param :: Nil, Block(vdefs, body)) } - def makeContextualFunction(formals: List[Type], body: Tree, isErased: Boolean)(implicit ctx: Context): Tree = { + def makeContextualFunction(formals: List[Tree], body: Tree, isErased: Boolean)(implicit ctx: Context): Function = { val mods = if (isErased) Given | Erased else Given - val params = makeImplicitParameters(formals.map(TypeTree), mods) + val params = makeImplicitParameters(formals, mods) FunctionWithMods(params, body, Modifiers(mods)) } diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 657ce9fd0e94..7f4c93d8542e 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -98,6 +98,11 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] => case _ => tree } + def stripAnnotated(tree: Tree): Tree = tree match { + case Annotated(arg, _) => arg + case _ => tree + } + /** The number of arguments in an application */ def numArgs(tree: Tree): Int = unsplice(tree) match { case Apply(fn, args) => numArgs(fn) + args.length diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index eeb65cc5727a..4c2ce2ed9ee0 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -598,7 +598,7 @@ object ProtoTypes { */ private def wildApprox(tp: Type, theMap: WildApproxMap, seen: Set[TypeParamRef], internal: Set[TypeLambda])(using Context): Type = tp match { case tp: NamedType => // default case, inlined for speed - val isPatternBoundTypeRef = tp.isInstanceOf[TypeRef] && tp.symbol.is(Flags.Case) && !tp.symbol.isClass + val isPatternBoundTypeRef = tp.isInstanceOf[TypeRef] && tp.symbol.isPatternBound if (isPatternBoundTypeRef) WildcardType(tp.underlying.bounds) else if (tp.symbol.isStatic || (tp.prefix `eq` NoPrefix)) tp else tp.derivedSelect(wildApprox(tp.prefix, theMap, seen, internal)) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 8ed14c2fd32b..5e56352c739f 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1049,15 +1049,35 @@ class Typer extends Namer */ var paramIndex = Map[Name, Int]() - /** If function is of the form + /** Infer parameter type from the body of the function + * + * 1. If function is of the form + * * (x1, ..., xN) => f(... x1, ..., XN, ...) + * * where each `xi` occurs exactly once in the argument list of `f` (in * any order), the type of `f`, otherwise NoType. + * + * 2. If the function is of the form + * + * (using x1, ..., xN) => f + * + * where `f` is a contextual function type of the form `(T1, ..., TN) ?=> T`, + * then `xi` takes the type `Ti`. + * * Updates `fnBody` and `paramIndex` as a side effect. * @post: If result exists, `paramIndex` is defined for the name of * every parameter in `params`. */ - lazy val calleeType: Type = fnBody match { + lazy val calleeType: Type = untpd.stripAnnotated(fnBody) match { + case ident: untpd.Ident if isContextual => + val ident1 = typedIdent(ident, WildcardType) + val tp = ident1.tpe.widen + if defn.isContextFunctionType(tp) && params.size == defn.functionArity(tp) then + paramIndex = params.map(_.name).zipWithIndex.toMap + fnBody = untpd.TypedSplice(ident1) + tp.select(nme.apply) + else NoType case app @ Apply(expr, args) => paramIndex = { for (param <- params; idx <- paramIndices(param, args)) @@ -2450,7 +2470,34 @@ class Typer extends Namer protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { val defn.FunctionOf(formals, _, true, _) = pt.dropDependentRefinement - val ifun = desugar.makeContextualFunction(formals, tree, defn.isErasedFunctionType(pt)) + + // The getter of default parameters may reach here. + // Given the code below + // + // class Foo[A](run: A ?=> Int) { + // def foo[T](f: T ?=> Int = run) = () + // } + // + // it desugars to + // + // class Foo[A](run: A ?=> Int) { + // def foo$default$1[T] = run + // def foo[T](f: T ?=> Int = run) = () + // } + // + // The expected type for checking `run` in `foo$default$1` is + // + // ?=> Int + // + // see tests/pos/i7778b.scala + + val paramTypes = { + val hasWildcard = formals.exists(_.isInstanceOf[WildcardType]) + if hasWildcard then formals.map(_ => untpd.TypeTree()) + else formals.map(untpd.TypeTree) + } + + val ifun = desugar.makeContextualFunction(paramTypes, tree, defn.isErasedFunctionType(pt)) typr.println(i"make contextual function $tree / $pt ---> $ifun") typed(ifun, pt) } diff --git a/tests/pos/i7778.scala b/tests/pos/i7778.scala new file mode 100644 index 000000000000..cefed289c819 --- /dev/null +++ b/tests/pos/i7778.scala @@ -0,0 +1,12 @@ +object Example extends App { + final case class Foo[A](run: A ?=> Int) +} + +object Example2 extends App { + final case class Foo[A, B](run: (A, B) ?=> Int) +} + + +object Example3 extends App { + final case class Foo[A, B](run: () ?=> Int) +} diff --git a/tests/pos/i7778b.scala b/tests/pos/i7778b.scala new file mode 100644 index 000000000000..eaea85e5cf3b --- /dev/null +++ b/tests/pos/i7778b.scala @@ -0,0 +1,4 @@ +class Foo[A](run: A ?=> Int) { + def foo[T](f: T ?=> Int = run) = () +} +