Skip to content

Commit 34c2918

Browse files
committed
Add the expected type to Poly's desugaring
By doing so, the expected type can drive the correct GADT casting in the match cases, which means the overall poly function tree will have a type that conforms to its expected type.
1 parent fad3175 commit 34c2918

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,7 +1504,7 @@ object desugar {
15041504
.withSpan(original.span.withPoint(named.span.start))
15051505

15061506
/** Main desugaring method */
1507-
def apply(tree: Tree)(using Context): Tree = {
1507+
def apply(tree: Tree, pt: Type = NoType)(using Context): Tree = {
15081508

15091509
/** Create tree for for-comprehension `<for (enums) do body>` or
15101510
* `<for (enums) yield body>` where mapName and flatMapName are chosen
@@ -1698,11 +1698,11 @@ object desugar {
16981698
}
16991699
}
17001700

1701-
def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match {
1701+
def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
17021702
case Parens(body1) =>
1703-
makePolyFunction(targs, body1)
1703+
makePolyFunction(targs, body1, pt)
17041704
case Block(Nil, body1) =>
1705-
makePolyFunction(targs, body1)
1705+
makePolyFunction(targs, body1, pt)
17061706
case Function(vargs, res) =>
17071707
assert(targs.nonEmpty)
17081708
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
@@ -1726,12 +1726,26 @@ object desugar {
17261726
}
17271727
else {
17281728
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1729-
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1729+
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1730+
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1731+
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1732+
1733+
def typeTree(tp: Type) = tp match
1734+
case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1735+
var bail = false
1736+
def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match
1737+
case tp: TypeRef => ref(tp)
1738+
case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name)
1739+
case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_)))
1740+
case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree }
1741+
val mapped = mapper(mt.resultType, topLevel = true)
1742+
if bail then TypeTree() else mapped
1743+
case _ => TypeTree()
17301744

17311745
val applyVParams = vargs.asInstanceOf[List[ValDef]]
17321746
.map(varg => varg.withAddedFlags(mods.flags | Param))
17331747
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1734-
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, TypeTree(), res))
1748+
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
17351749
))
17361750
}
17371751
case _ =>
@@ -1753,7 +1767,7 @@ object desugar {
17531767

17541768
val desugared = tree match {
17551769
case PolyFunction(targs, body) =>
1756-
makePolyFunction(targs, body) orElse tree
1770+
makePolyFunction(targs, body, pt) orElse tree
17571771
case SymbolLit(str) =>
17581772
Apply(
17591773
ref(defn.ScalaSymbolClass.companionModule.termRef),

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2871,7 +2871,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28712871

28722872
typedTypeOrClassDef
28732873
case tree: untpd.Labeled => typedLabeled(tree)
2874-
case _ => typedUnadapted(desugar(tree), pt, locked)
2874+
case _ => typedUnadapted(desugar(tree, pt), pt, locked)
28752875
}
28762876
}
28772877

@@ -2924,7 +2924,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
29242924
case tree: untpd.Splice => typedSplice(tree, pt)
29252925
case tree: untpd.MacroTree => report.error("Unexpected macro", tree.srcPos); tpd.nullLiteral // ill-formed code may reach here
29262926
case tree: untpd.Hole => typedHole(tree, pt)
2927-
case _ => typedUnadapted(desugar(tree), pt, locked)
2927+
case _ => typedUnadapted(desugar(tree, pt), pt, locked)
29282928
}
29292929

29302930
try

tests/pos/i15554.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
enum PingMessage[Response]:
2+
case Ping(from: String) extends PingMessage[String]
3+
4+
val pongBehavior: [O] => (Unit, PingMessage[O]) => (Unit, O) =
5+
[P] =>
6+
(state: Unit, msg: PingMessage[P]) =>
7+
msg match
8+
case PingMessage.Ping(from) => ((), s"Pong from $from")

0 commit comments

Comments
 (0)