From 34c2918fb39e12211d0b0eca2ef974fa5657707f Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Fri, 1 Jul 2022 19:30:09 +0100 Subject: [PATCH] Add the expected type to Poly's desugaring By doing so, the expected type can drive the correct GADT casting in the match cases, which means the overall poly function tree will have a type that conforms to its expected type. --- .../src/dotty/tools/dotc/ast/Desugar.scala | 28 ++++++++++++++----- .../src/dotty/tools/dotc/typer/Typer.scala | 4 +-- tests/pos/i15554.scala | 8 ++++++ 3 files changed, 31 insertions(+), 9 deletions(-) create mode 100644 tests/pos/i15554.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index fb045b8a5f64..10d4fed7f058 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1504,7 +1504,7 @@ object desugar { .withSpan(original.span.withPoint(named.span.start)) /** Main desugaring method */ - def apply(tree: Tree)(using Context): Tree = { + def apply(tree: Tree, pt: Type = NoType)(using Context): Tree = { /** Create tree for for-comprehension `` or * `` where mapName and flatMapName are chosen @@ -1698,11 +1698,11 @@ object desugar { } } - def makePolyFunction(targs: List[Tree], body: Tree): Tree = body match { + def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match { case Parens(body1) => - makePolyFunction(targs, body1) + makePolyFunction(targs, body1, pt) case Block(Nil, body1) => - makePolyFunction(targs, body1) + makePolyFunction(targs, body1, pt) case Function(vargs, res) => assert(targs.nonEmpty) // TODO: Figure out if we need a `PolyFunctionWithMods` instead. @@ -1726,12 +1726,26 @@ object desugar { } else { // Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body - // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body } + // with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R + // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body } + // where R2 is R, with all references to S_1..S_M replaced with T1..T_M. + + def typeTree(tp: Type) = tp match + case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass => + var bail = false + def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match + case tp: TypeRef => ref(tp) + case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name) + case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_))) + case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree } + val mapped = mapper(mt.resultType, topLevel = true) + if bail then TypeTree() else mapped + case _ => TypeTree() val applyVParams = vargs.asInstanceOf[List[ValDef]] .map(varg => varg.withAddedFlags(mods.flags | Param)) New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef, - List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, TypeTree(), res)) + List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res)) )) } case _ => @@ -1753,7 +1767,7 @@ object desugar { val desugared = tree match { case PolyFunction(targs, body) => - makePolyFunction(targs, body) orElse tree + makePolyFunction(targs, body, pt) orElse tree case SymbolLit(str) => Apply( ref(defn.ScalaSymbolClass.companionModule.termRef), diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 71a8872343b4..830131311c12 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2871,7 +2871,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedTypeOrClassDef case tree: untpd.Labeled => typedLabeled(tree) - case _ => typedUnadapted(desugar(tree), pt, locked) + case _ => typedUnadapted(desugar(tree, pt), pt, locked) } } @@ -2924,7 +2924,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case tree: untpd.Splice => typedSplice(tree, pt) case tree: untpd.MacroTree => report.error("Unexpected macro", tree.srcPos); tpd.nullLiteral // ill-formed code may reach here case tree: untpd.Hole => typedHole(tree, pt) - case _ => typedUnadapted(desugar(tree), pt, locked) + case _ => typedUnadapted(desugar(tree, pt), pt, locked) } try diff --git a/tests/pos/i15554.scala b/tests/pos/i15554.scala new file mode 100644 index 000000000000..8573a5fff549 --- /dev/null +++ b/tests/pos/i15554.scala @@ -0,0 +1,8 @@ +enum PingMessage[Response]: + case Ping(from: String) extends PingMessage[String] + +val pongBehavior: [O] => (Unit, PingMessage[O]) => (Unit, O) = + [P] => + (state: Unit, msg: PingMessage[P]) => + msg match + case PingMessage.Ping(from) => ((), s"Pong from $from")