Skip to content

Commit 9d08db1

Browse files
committed
Support inferring dependent polymorphic lambdas from the expected type
Reuse the existing `DependentTypeTree` mechanism already in place for monomorphic lambdas to compute the result type of polymorphic lambdas based on their expected type.
1 parent 871e23f commit 9d08db1

File tree

7 files changed

+36
-16
lines changed

7 files changed

+36
-16
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,15 +1761,9 @@ object desugar {
17611761
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
17621762

17631763
def typeTree(tp: Type) = tp match
1764-
case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1765-
var bail = false
1766-
def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match
1767-
case tp: TypeRef => ref(tp)
1768-
case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name)
1769-
case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_)))
1770-
case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree }
1771-
val mapped = mapper(mt.resultType, topLevel = true)
1772-
if bail then TypeTree() else mapped
1764+
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1765+
untpd.DependentTypeTree((tsyms, vsyms) =>
1766+
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
17731767
case _ => TypeTree()
17741768

17751769
val applyVParams = vargs.asInstanceOf[List[ValDef]]

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
151151
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
152152

153153
/** Short-lived usage in typer, does not need copy/transform/fold infrastructure */
154-
case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree
154+
case class DependentTypeTree(tp: (List[TypeSymbol], List[TermSymbol]) => Type)(implicit @constructorOnly src: SourceFile) extends Tree
155155

156156
@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY)(NoSource) with WithoutTypeOrPos[Untyped] {
157157
override def isEmpty: Boolean = true

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,15 +1692,17 @@ class Namer { typer: Typer =>
16921692
def valOrDefDefSig(mdef: ValOrDefDef, sym: Symbol, paramss: List[List[Symbol]], paramFn: Type => Type)(using Context): Type = {
16931693

16941694
def inferredType = inferredResultType(mdef, sym, paramss, paramFn, WildcardType)
1695-
lazy val termParamss = paramss.collect { case TermSymbols(vparams) => vparams }
16961695

16971696
val tptProto = mdef.tpt match {
16981697
case _: untpd.DerivedTypeTree =>
16991698
WildcardType
17001699
case TypeTree() =>
17011700
checkMembersOK(inferredType, mdef.srcPos)
17021701
case DependentTypeTree(tpFun) =>
1703-
val tpe = tpFun(termParamss.head)
1702+
// A lambda has at most one type parameter list followed by exactly one term parameter list.
1703+
val tpe = (paramss: @unchecked) match
1704+
case TypeSymbols(tparams) :: TermSymbols(vparams) :: Nil => tpFun(tparams, vparams)
1705+
case TermSymbols(vparams) :: Nil => tpFun(Nil, vparams)
17041706
if (isFullyDefined(tpe, ForceDegree.none)) tpe
17051707
else typedAheadExpr(mdef.rhs, tpe).tpe
17061708
case TypedSplice(tpt: TypeTree) if !isFullyDefined(tpt.tpe, ForceDegree.none) =>
@@ -1724,7 +1726,8 @@ class Namer { typer: Typer =>
17241726
// So fixing levels at instantiation avoids the soundness problem but apparently leads
17251727
// to type inference problems since it comes too late.
17261728
if !Config.checkLevelsOnConstraints then
1727-
val hygienicType = TypeOps.avoid(rhsType, termParamss.flatten)
1729+
val termParams = paramss.collect { case TermSymbols(vparams) => vparams }.flatten
1730+
val hygienicType = TypeOps.avoid(rhsType, termParams)
17281731
if (!hygienicType.isValueType || !(hygienicType <:< tpt.tpe))
17291732
report.error(
17301733
em"""return type ${tpt.tpe} of lambda cannot be made hygienic

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,14 +1323,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13231323
(pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
13241324
case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe))
13251325
if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
1326-
(formals, untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef))))
1326+
(formals, untpd.DependentTypeTree((_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
13271327
case pt1 @ SAMType(mt @ MethodTpe(_, formals, _)) if !SAMType.isParamDependentRec(mt) =>
13281328
val restpe = mt.resultType match
13291329
case mt: MethodType => mt.toFunctionType(isJava = pt1.classSymbol.is(JavaDefined))
13301330
case tp => tp
13311331
(formals,
13321332
if (mt.isResultDependent)
1333-
untpd.DependentTypeTree(syms => restpe.substParams(mt, syms.map(_.termRef)))
1333+
untpd.DependentTypeTree((_, syms) => restpe.substParams(mt, syms.map(_.termRef)))
13341334
else
13351335
typeTree(restpe))
13361336
case _ =>
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 ---------------------------------------------
22
1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error
33
| ^
4-
| Found: [T] => (x: Int) => Int
4+
| Found: [T] => (x: Int) => x.type
55
| Required: [T] => (x: T) => x.type
66
|
77
| longer explanation available when compiling with `-explain`

tests/pos/i16756.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
class DependentPoly {
2+
3+
sealed trait Col[V] {
4+
5+
trait Wrapper
6+
val wrapper: Wrapper = ???
7+
}
8+
9+
object Col1 extends Col[Int]
10+
11+
object Col2 extends Col[Double]
12+
13+
val polyFn: [C <: DependentPoly.this.Col[?]] => (x: C) => x.Wrapper =
14+
[C <: Col[?]] => (x: C) => (x.wrapper: x.Wrapper)
15+
}
16+

tests/run/polymorphic-functions.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ object Test extends App {
8585
val v0a: String = v0
8686
assert(v0 == "foo")
8787

88+
// Used to fail with: Found: ... => List[T]
89+
// Expected: ... => List[x.type]
90+
val md2: [T] => (x: T) => List[x.type] = [T] => (x: T) => List(x)
91+
val x = 1
92+
val v1 = md2(x)
93+
val v1a: List[x.type] = v1
94+
8895
// Contextual
8996
trait Show[T] { def show(t: T): String }
9097
implicit val si: Show[Int] =

0 commit comments

Comments
 (0)