Skip to content

Commit 980213b

Browse files
committed
Make the expandion of context bounds for poly types slightly more elegant
1 parent 42d914e commit 980213b

File tree

4 files changed

+77
-33
lines changed

4 files changed

+77
-33
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,7 @@ object desugar {
527527
makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span)
528528

529529
if meth.hasAttachment(PolyFunctionApply) then
530-
meth.removeAttachment(PolyFunctionApply)
531-
// (kπ): deffer this until we can type the result?
530+
// meth.removeAttachment(PolyFunctionApply)
532531
if ctx.mode.is(Mode.Type) then
533532
cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params))
534533
else
@@ -1238,29 +1237,35 @@ object desugar {
12381237
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12391238
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12401239
*/
1241-
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1242-
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1243-
val paramFlags = fun match
1244-
case fun: FunctionWithMods =>
1245-
// TODO: make use of this in the desugaring when pureFuns is enabled.
1246-
// val isImpure = funFlags.is(Impure)
1247-
1248-
// Function flags to be propagated to each parameter in the desugared method type.
1249-
val givenFlag = fun.mods.flags.toTermFlags & Given
1250-
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1251-
case _ =>
1252-
vparamTypes.map(_ => EmptyFlags)
1253-
1254-
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1255-
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1256-
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1257-
}.toList
1258-
1259-
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1260-
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1261-
.withFlags(Synthetic)
1262-
.withAttachment(PolyFunctionApply, List.empty)
1263-
)).withSpan(tree.span)
1240+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = tree match
1241+
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
1242+
val paramFlags = fun match
1243+
case fun: FunctionWithMods =>
1244+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1245+
// val isImpure = funFlags.is(Impure)
1246+
1247+
// Function flags to be propagated to each parameter in the desugared method type.
1248+
val givenFlag = fun.mods.flags.toTermFlags & Given
1249+
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1250+
case _ =>
1251+
vparamTypes.map(_ => EmptyFlags)
1252+
1253+
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1254+
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1255+
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1256+
}.toList
1257+
1258+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1259+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
1260+
.withFlags(Synthetic)
1261+
.withAttachment(PolyFunctionApply, List.empty)
1262+
)).withSpan(tree.span)
1263+
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, res) =>
1264+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1265+
DefDef(nme.apply, tparams :: Nil, res, EmptyTree)
1266+
.withFlags(Synthetic)
1267+
.withAttachment(PolyFunctionApply, List.empty)
1268+
)).withSpan(tree.span)
12641269
end makePolyFunctionType
12651270

12661271
/** Invent a name for an anonympus given of type or template `impl`. */

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,6 +1761,8 @@ object Parsers {
17611761
getFunction(body) match
17621762
case Some(f) =>
17631763
PolyFunction(tparams, body)
1764+
case None if tparams.exists(_.rhs.isInstanceOf[ContextBounds]) =>
1765+
PolyFunction(tparams, body)
17641766
case None =>
17651767
syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset)
17661768
Ident(nme.ERROR.toTypeName)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3590,14 +3590,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35903590

35913591
private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match {
35923592
case tpe: MethodType =>
3593-
MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3593+
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
35943594
case tpe: PolyType =>
3595-
PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3595+
tpe.derivedLambdaType(tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
35963596
case tpe: RefinedType =>
3597-
// TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3598-
RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3597+
tpe.derivedRefinedType(
3598+
pushDownDeferredEvidenceParams(tpe.parent, params, span),
3599+
tpe.refinedName,
3600+
pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3601+
)
35993602
case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3600-
AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3603+
tpe.derivedAppliedType(tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
36013604
case tpe =>
36023605
val paramNames = params.map(_.name)
36033606
val paramTpts = params.map(_.tpt)
@@ -3606,18 +3609,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
36063609
typed(ctxFunction).tpe
36073610
}
36083611

3609-
private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
3612+
private def extractTopMethodTermParams(tpe: Type)(using Context): (List[TermName], List[Type]) = tpe match {
3613+
case tpe: MethodType =>
3614+
tpe.paramNames -> tpe.paramInfos
3615+
case tpe: RefinedType if defn.isFunctionType(tpe.parent) =>
3616+
extractTopMethodTermParams(tpe.refinedInfo)
3617+
case _ =>
3618+
Nil -> Nil
3619+
}
3620+
3621+
private def removeTopMethodTermParams(tpe: Type)(using Context): Type = tpe match {
3622+
case tpe: MethodType =>
3623+
tpe.resultType
3624+
case tpe: RefinedType if defn.isFunctionType(tpe.parent) =>
3625+
tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3626+
case tpe: AppliedType if defn.isFunctionType(tpe) =>
3627+
tpe.args.last
3628+
case _ =>
3629+
tpe
3630+
}
3631+
3632+
private def healToPolyFunctionType(tree: Tree)(using Context): Tree = tree match {
3633+
case defdef: DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam))) && defdef.paramss.size == 1 =>
3634+
val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3635+
val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3636+
val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef(name, TypeTree(tpe), flags = SyntheticTermParam))
3637+
val newDefDef = cpy.DefDef(defdef)(paramss = defdef.paramss ++ List(newParams), tpt = untpd.TypeTree(newTpe))
3638+
val nestedCtx = ctx.fresh.setNewTyperState()
3639+
typed(newDefDef)(using nestedCtx)
3640+
case _ => tree
3641+
}
3642+
3643+
private def addDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
36103644
tree.getAttachment(desugar.PolyFunctionApply) match
36113645
case Some(params) if params.nonEmpty =>
36123646
tree.removeAttachment(desugar.PolyFunctionApply)
36133647
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
36143648
TypeTree(tpe).withSpan(tree.span) -> tpe
3649+
// case Some(params) if params.isEmpty =>
3650+
// println(s"tree: $tree")
3651+
// healToPolyFunctionType(tree) -> pt
36153652
case _ => tree -> pt
36163653
}
36173654

36183655
/** Interpolate and simplify the type of the given tree. */
36193656
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
3620-
val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt)
3657+
val (tree1, pt1) = addDeferredEvidenceParams(tree, pt)
36213658
if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
36223659
if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
36233660
|| tree1.isDef // ... unless tree is a definition

tests/pos/contextbounds-for-poly-functions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type CmpWeak[X] = X => Boolean
3232
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
3333
val less4_0: [X: Ord] => X => X => Boolean =
3434
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
35-
val less4: Comparer2Weak =
35+
val less4_1: Comparer2Weak =
3636
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
3737

3838
val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

0 commit comments

Comments
 (0)