Skip to content

Commit 75ab141

Browse files
committed
Support inferring dependent polymorphic lambdas from the expected type
Reuse the existing `DependentTypeTree` mechanism already in place for monomorphic lambdas to compute the result type of polymorphic lambdas based on their expected type.
1 parent 5846f0e commit 75ab141

File tree

7 files changed

+36
-16
lines changed

7 files changed

+36
-16
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,15 +1745,9 @@ object desugar {
17451745
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
17461746

17471747
def typeTree(tp: Type) = tp match
1748-
case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1749-
var bail = false
1750-
def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match
1751-
case tp: TypeRef => ref(tp)
1752-
case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name)
1753-
case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_)))
1754-
case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree }
1755-
val mapped = mapper(mt.resultType, topLevel = true)
1756-
if bail then TypeTree() else mapped
1748+
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1749+
untpd.DependentTypeTree((tsyms, vsyms) =>
1750+
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
17571751
case _ => TypeTree()
17581752

17591753
val applyVParams = vargs.asInstanceOf[List[ValDef]]

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
151151
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
152152

153153
/** Short-lived usage in typer, does not need copy/transform/fold infrastructure */
154-
case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree
154+
case class DependentTypeTree(tp: (List[TypeSymbol], List[TermSymbol]) => Type)(implicit @constructorOnly src: SourceFile) extends Tree
155155

156156
@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY)(NoSource) with WithoutTypeOrPos[Untyped] {
157157
override def isEmpty: Boolean = true

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,15 +1681,17 @@ class Namer { typer: Typer =>
16811681
def valOrDefDefSig(mdef: ValOrDefDef, sym: Symbol, paramss: List[List[Symbol]], paramFn: Type => Type)(using Context): Type = {
16821682

16831683
def inferredType = inferredResultType(mdef, sym, paramss, paramFn, WildcardType)
1684-
lazy val termParamss = paramss.collect { case TermSymbols(vparams) => vparams }
16851684

16861685
val tptProto = mdef.tpt match {
16871686
case _: untpd.DerivedTypeTree =>
16881687
WildcardType
16891688
case TypeTree() =>
16901689
checkMembersOK(inferredType, mdef.srcPos)
16911690
case DependentTypeTree(tpFun) =>
1692-
val tpe = tpFun(termParamss.head)
1691+
// A lambda has at most one type parameter list followed by exactly one term parameter list.
1692+
val tpe = (paramss: @unchecked) match
1693+
case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams)
1694+
case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams)
16931695
if (isFullyDefined(tpe, ForceDegree.none)) tpe
16941696
else typedAheadExpr(mdef.rhs, tpe).tpe
16951697
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
@@ -1713,7 +1715,8 @@ class Namer { typer: Typer =>
17131715
// So fixing levels at instantiation avoids the soundness problem but apparently leads
17141716
// to type inference problems since it comes too late.
17151717
if !Config.checkLevelsOnConstraints then
1716-
val hygienicType = TypeOps.avoid(rhsType, termParamss.flatten)
1718+
val termParams = paramss.collect { case TermSymbols(vparams) => vparams }.flatten
1719+
val hygienicType = TypeOps.avoid(rhsType, termParams)
17171720
if (!hygienicType.isValueType || !(hygienicType <:< tpt.tpe))
17181721
report.error(
17191722
em"""return type ${tpt.tpe} of lambda cannot be made hygienic

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,11 +1334,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13341334
(pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
13351335
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
13361336
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
1337-
(formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
1337+
(formals, untpd.DependentTypeTree((_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
13381338
case SAMType(mt @ MethodTpe(_, formals, restpe)) =>
13391339
(formals,
13401340
if (mt.isResultDependent)
1341-
untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))
1341+
untpd.DependentTypeTree((_, syms) => restpe.substParams(mt, syms.map(_.termRef)))
13421342
else
13431343
typeTree(restpe))
13441344
case _ =>
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 ---------------------------------------------
22
1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error
33
| ^
4-
| Found: [T] => (x: Int) => Int
4+
| Found: [T] => (x: Int) => x.type
55
| Required: [T] => (x: T) => x.type
66
|
77
| longer explanation available when compiling with `-explain`

tests/pos/i16756.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
class DependentPoly {
2+
3+
sealed trait Col[V] {
4+
5+
trait Wrapper
6+
val wrapper: Wrapper = ???
7+
}
8+
9+
object Col1 extends Col[Int]
10+
11+
object Col2 extends Col[Double]
12+
13+
val polyFn: [C <: DependentPoly.this.Col[?]] => (x: C) => x.Wrapper =
14+
[C <: Col[?]] => (x: C) => (x.wrapper: x.Wrapper)
15+
}
16+

tests/run/polymorphic-functions.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ object Test extends App {
8585
val v0a: String = v0
8686
assert(v0 == "foo")
8787

88+
// Used to fail with: Found: ... => List[T]
89+
// Expected: ... => List[x.type]
90+
val md2: [T] => (x: T) => List[x.type] = [T] => (x: T) => List(x)
91+
val x = 1
92+
val v1 = md2(x)
93+
val v1a: List[x.type] = v1
94+
8895
// Contextual
8996
trait Show[T] { def show(t: T): String }
9097
implicit val si: Show[Int] =

0 commit comments

Comments
 (0)