Skip to content

Commit 0022e5b

Browse files
committed
Implement polymorphic lambdas using Closure nodes for efficiency
Previously, we desugared them manually into anonymous class instances, but by using a Closure node instead, we ensure that they get translated into indy lambdas on the JVM. Also cleaned up and added a TODO in the desugaring of polymorphic function types into refinement types since I realized that purity wasn't taken into account.
1 parent 75ab141 commit 0022e5b

File tree

4 files changed

+87
-77
lines changed

4 files changed

+87
-77
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 41 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,40 @@ object desugar {
10201020
name
10211021
}
10221022

1023+
/** Strip parens and empty blocks around the body of `tree`. */
1024+
def normalizePolyFunction(tree: PolyFunction)(using Context): PolyFunction =
1025+
def stripped(body: Tree): Tree = body match
1026+
case Parens(body1) =>
1027+
stripped(body1)
1028+
case Block(Nil, body1) =>
1029+
stripped(body1)
1030+
case _ => body
1031+
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]
1032+
1033+
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
1034+
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1035+
*/
1036+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1037+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1038+
val funFlags = fun match
1039+
case fun: FunctionWithMods =>
1040+
fun.mods.flags
1041+
case _ => EmptyFlags
1042+
1043+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1044+
// val isImpure = funFlags.is(Impure)
1045+
1046+
// Function flags to be propagated to each parameter in the desugared method type.
1047+
val paramFlags = funFlags.toTermFlags & Given
1048+
val vparams = vparamTypes.zipWithIndex.map:
1049+
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
1050+
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1051+
1052+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1053+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
1054+
)).withSpan(tree.span)
1055+
end makePolyFunctionType
1056+
10231057
/** Invent a name for an anonympus given of type or template `impl`. */
10241058
def inventGivenOrExtensionName(impl: Tree)(using Context): SimpleName =
10251059
val str = impl match
@@ -1413,14 +1447,17 @@ object desugar {
14131447
}
14141448

14151449
/** Make closure corresponding to function.
1416-
* params => body
1450+
* [tparams] => params => body
14171451
* ==>
1418-
* def $anonfun(params) = body
1452+
* def $anonfun[tparams](params) = body
14191453
* Closure($anonfun)
14201454
*/
1421-
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
1455+
def makeClosure(tparams: List[TypeDef], vparams: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
1456+
val paramss: List[ParamClause] =
1457+
if tparams.isEmpty then vparams :: Nil
1458+
else tparams :: vparams :: Nil
14221459
Block(
1423-
DefDef(nme.ANON_FUN, params :: Nil, if (tpt == null) TypeTree() else tpt, body)
1460+
DefDef(nme.ANON_FUN, paramss, if (tpt == null) TypeTree() else tpt, body)
14241461
.withSpan(span)
14251462
.withMods(synthetic | Artifact),
14261463
Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
@@ -1712,56 +1749,6 @@ object desugar {
17121749
}
17131750
}
17141751

1715-
def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
1716-
case Parens(body1) =>
1717-
makePolyFunction(targs, body1, pt)
1718-
case Block(Nil, body1) =>
1719-
makePolyFunction(targs, body1, pt)
1720-
case Function(vargs, res) =>
1721-
assert(targs.nonEmpty)
1722-
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
1723-
val mods = body match {
1724-
case body: FunctionWithMods => body.mods
1725-
case _ => untpd.EmptyModifiers
1726-
}
1727-
val polyFunctionTpt = ref(defn.PolyFunctionType)
1728-
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1729-
if (ctx.mode.is(Mode.Type)) {
1730-
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1731-
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1732-
1733-
val applyVParams = vargs.zipWithIndex.map {
1734-
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
1735-
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags.toTermFlags)
1736-
}
1737-
RefinedTypeTree(polyFunctionTpt, List(
1738-
DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree).withFlags(Synthetic)
1739-
))
1740-
}
1741-
else {
1742-
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1743-
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1744-
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1745-
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1746-
1747-
def typeTree(tp: Type) = tp match
1748-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1749-
untpd.DependentTypeTree((tsyms, vsyms) =>
1750-
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
1751-
case _ => TypeTree()
1752-
1753-
val applyVParams = vargs.asInstanceOf[List[ValDef]]
1754-
.map(varg => varg.withAddedFlags(mods.flags | Param))
1755-
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1756-
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
1757-
))
1758-
}
1759-
case _ =>
1760-
// may happen for erroneous input. An error will already have been reported.
1761-
assert(ctx.reporter.errorsReported)
1762-
EmptyTree
1763-
}
1764-
17651752
// begin desugar
17661753

17671754
// Special case for `Parens` desugaring: unlike all the desugarings below,
@@ -1774,8 +1761,6 @@ object desugar {
17741761
}
17751762

17761763
val desugared = tree match {
1777-
case PolyFunction(targs, body) =>
1778-
makePolyFunction(targs, body, pt) orElse tree
17791764
case SymbolLit(str) =>
17801765
Apply(
17811766
ref(defn.ScalaSymbolClass.companionModule.termRef),

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1842,6 +1842,8 @@ object Types {
18421842
if alwaysDependent || mt.isResultDependent then
18431843
RefinedType(funType, nme.apply, mt)
18441844
else funType
1845+
case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent =>
1846+
RefinedType(defn.PolyFunctionType, nme.apply, poly)
18451847
}
18461848

18471849
/** The signature of this type. This is by default NotAMethod,

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,12 +1633,32 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16331633
)
16341634
cpy.ValDef(param)(tpt = paramTpt)
16351635
if isErased then param0.withAddedFlags(Flags.Erased) else param0
1636-
desugared = desugar.makeClosure(inferredParams, fnBody, resultTpt, tree.span)
1636+
desugared = desugar.makeClosure(Nil, inferredParams, fnBody, resultTpt, tree.span)
16371637

16381638
typed(desugared, pt)
16391639
.showing(i"desugared fun $tree --> $desugared with pt = $pt", typr)
16401640
}
16411641

1642+
1643+
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
1644+
val tree1 = desugar.normalizePolyFunction(tree)
1645+
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1646+
else typedPolyFunctionValue(tree1, pt)
1647+
1648+
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
1649+
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
1650+
val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked
1651+
1652+
val resultTpt = pt.dealias match
1653+
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1654+
untpd.DependentTypeTree((tsyms, vsyms) =>
1655+
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
1656+
case _ => untpd.TypeTree()
1657+
1658+
val desugared = desugar.makeClosure(tparams, vparams, body, resultTpt, tree.span)
1659+
typed(desugared, pt)
1660+
end typedPolyFunctionValue
1661+
16421662
def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
16431663
val env1 = tree.env mapconserve (typed(_))
16441664
val meth1 = typedUnadapted(tree.meth)
@@ -1676,6 +1696,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16761696
else
16771697
EmptyTree
16781698
}
1699+
case _: PolyType =>
1700+
// Polymorphic SAMs are not currently supported (#6904).
1701+
EmptyTree
16791702
case tp =>
16801703
if !tp.isErroneous then
16811704
throw new java.lang.Error(i"internal error: closing over non-method $tp, pos = ${tree.span}")
@@ -2433,7 +2456,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24332456
case rhs => typedExpr(rhs, tpt1.tpe.widenExpr)
24342457
}
24352458
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
2436-
postProcessInfo(sym)
2459+
postProcessInfo(vdef1, sym)
24372460
vdef1.setDefTree
24382461
}
24392462

@@ -2536,19 +2559,31 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
25362559

25372560
val ddef2 = assignType(cpy.DefDef(ddef)(name, paramss1, tpt1, rhs1), sym)
25382561

2539-
postProcessInfo(sym)
2562+
postProcessInfo(ddef2, sym)
25402563
ddef2.setDefTree
25412564
//todo: make sure dependent method types do not depend on implicits or by-name params
25422565
}
25432566

25442567
/** (1) Check that the signature of the class member does not return a repeated parameter type
25452568
* (2) If info is an erased class, set erased flag of member
2569+
* (3) Check that erased classes are not parameters of polymorphic functions.
25462570
*/
2547-
private def postProcessInfo(sym: Symbol)(using Context): Unit =
2571+
private def postProcessInfo(mdef: MemberDef, sym: Symbol)(using Context): Unit =
25482572
if (!sym.isOneOf(Synthetic | InlineProxy | Param) && sym.info.finalResultType.isRepeatedParam)
25492573
report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos)
25502574
if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then
25512575
sym.setFlag(Erased)
2576+
if
2577+
sym.info.isInstanceOf[PolyType] &&
2578+
((sym.name eq nme.ANON_FUN) ||
2579+
(sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass))
2580+
then
2581+
mdef match
2582+
case DefDef(_, _ :: vparams :: Nil, _, _) =>
2583+
vparams.foreach: vparam =>
2584+
if vparam.symbol.is(Erased) then
2585+
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos)
2586+
case _ =>
25522587

25532588
def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
25542589
val TypeDef(name, rhs) = tdef
@@ -2695,19 +2730,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
26952730
// check value class constraints
26962731
checkDerivedValueClass(cls, body1)
26972732

2698-
// check PolyFunction constraints (no erased functions!)
2699-
if parents1.exists(_.tpe.classSymbol eq defn.PolyFunctionClass) then
2700-
body1.foreach {
2701-
case ddef: DefDef =>
2702-
ddef.paramss.foreach { params =>
2703-
val erasedParam = params.collectFirst { case vdef: ValDef if vdef.symbol.is(Erased) => vdef }
2704-
erasedParam.foreach { p =>
2705-
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", p.srcPos)
2706-
}
2707-
}
2708-
case _ =>
2709-
}
2710-
27112733
val effectiveOwner = cls.owner.skipWeakOwner
27122734
if !cls.isRefinementClass
27132735
&& !cls.isAllOf(PrivateLocal)
@@ -3059,6 +3081,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
30593081
case tree: untpd.Block => typedBlock(desugar.block(tree), pt)(using ctx.fresh.setNewScope)
30603082
case tree: untpd.If => typedIf(tree, pt)
30613083
case tree: untpd.Function => typedFunction(tree, pt)
3084+
case tree: untpd.PolyFunction => typedPolyFunction(tree, pt)
30623085
case tree: untpd.Closure => typedClosure(tree, pt)
30633086
case tree: untpd.Import => typedImport(tree)
30643087
case tree: untpd.Export => typedExport(tree)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 ---------------------------------------------
1+
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:33 ---------------------------------------------
22
1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error
3-
| ^
4-
| Found: [T] => (x: Int) => x.type
5-
| Required: [T] => (x: T) => x.type
3+
| ^^^^^^^^^^^^^^^^^^^^
4+
| Found: [T] => (x: Int) => x.type
5+
| Required: [T] => (x: T) => x.type
66
|
77
| longer explanation available when compiling with `-explain`

0 commit comments

Comments
 (0)