Skip to content

Commit 60a7057

Browse files
committed
Bugfix: Fix handling for call-by-name arguments of applied types
We need some special treatment for types such as `(=> A) => B`
1 parent cbf6be2 commit 60a7057

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,8 @@ class CheckCaptures extends Recheck, SymTransformer:
432432
block match
433433
case closureDef(mdef) =>
434434
pt.dealias match
435-
case defn.FunctionOf(ptformals, _, _, _) if ptformals.forall(_.captureSet.isAlwaysEmpty) =>
435+
case defn.FunctionOf(ptformals, _, _, _)
436+
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
436437
// Redo setup of the anonymous function so that formal parameters don't
437438
// get capture sets. This is important to avoid false widenings to `*`
438439
// when taking the base type of the actual closures's dependent function
@@ -442,9 +443,14 @@ class CheckCaptures extends Recheck, SymTransformer:
442443
// First, undo the previous setup which installed a completer for `meth`.
443444
atPhase(preRecheckPhase.prev)(meth.denot.copySymDenotation())
444445
.installAfter(preRecheckPhase)
446+
447+
//atPhase(preRecheckPhase.prev)(meth.denot.copySymDenotation())
448+
// .installAfter(thisPhase)
445449
// Next, update all parameter symbols to match expected formals
446450
meth.paramSymss.head.lazyZip(ptformals).foreach { (psym, pformal) =>
447-
psym.copySymDenotation(info = pformal).installAfter(preRecheckPhase)
451+
psym.copySymDenotation(info = pformal.mapExprType).installAfter(preRecheckPhase)
452+
// psym.copySymDenotation(info = pformal).installAfter(thisPhase)
453+
// println(i"UPDATE $psym to ${pformal.mapExprType}, was $pformal")
448454
}
449455
// Next, update types of parameter ValDefs
450456
mdef.paramss.head.lazyZip(ptformals).foreach { (param, pformal) =>

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import StdNames.nme
2222
import reporting.trace
2323
import annotation.constructorOnly
2424
import cc.CaptureSet.IdempotentCaptRefMap
25+
import dotty.tools.dotc.core.Denotations.SingleDenotation
2526

2627
object Recheck:
2728
import tpd.*
@@ -91,6 +92,18 @@ object Recheck:
9192

9293
def hasRememberedType: Boolean = tree.hasAttachment(RecheckedType)
9394

95+
extension (tpe: Type)
96+
97+
/** Map ExprType => T to () ?=> T (and analogously for pure versions).
98+
* Even though this phase runs after ElimByName, ExprTypes can still occur
99+
* as by-name arguments of applied types. See note in doc comment for
100+
* ElimByName phase. Test case is bynamefun.scala.
101+
*/
102+
def mapExprType(using Context): Type = tpe match
103+
case ExprType(rt) => defn.ByNameFunction(rt)
104+
case _ => tpe
105+
106+
94107
/** A base class that runs a simplified typer pass over an already re-typed program. The pass
95108
* does not transform trees but returns instead the re-typed type of each tree as it is
96109
* traversed. The Recheck phase must be directly preceded by a phase of type PreRecheck.
@@ -152,15 +165,27 @@ abstract class Recheck extends Phase, SymTransformer:
152165
else AnySelectionProto
153166
recheckSelection(tree, recheck(qual, proto).widenIfUnstable, name, pt)
154167

168+
/** When we select the `apply` of a function with type such as `(=> A) => B`,
169+
* we need to convert the parameter type `=> A` to `() ?=> A`. See doc comment
170+
* of `mapExprType`.
171+
*/
172+
def normalizeByName(mbr: SingleDenotation)(using Context): SingleDenotation = mbr.info match
173+
case mt: MethodType if mt.paramInfos.exists(_.isInstanceOf[ExprType]) =>
174+
mbr.derivedSingleDenotation(mbr.symbol,
175+
mt.derivedLambdaType(paramInfos = mt.paramInfos.map(_.mapExprType)))
176+
case _ =>
177+
mbr
178+
155179
def recheckSelection(tree: Select, qualType: Type, name: Name,
156180
sharpen: Denotation => Denotation)(using Context): Type =
157181
if name.is(OuterSelectName) then tree.tpe
158182
else
159183
//val pre = ta.maybeSkolemizePrefix(qualType, name)
160-
val mbr = sharpen(
184+
val mbr = normalizeByName(
185+
sharpen(
161186
qualType.findMember(name, qualType,
162187
excluded = if tree.symbol.is(Private) then EmptyFlags else Private
163-
)).suchThat(tree.symbol == _)
188+
)).suchThat(tree.symbol == _))
164189
constFold(tree, qualType.select(name, mbr))
165190
//.showing(i"recheck select $qualType . $name : ${mbr.info} = $result")
166191

@@ -215,15 +240,16 @@ abstract class Recheck extends Phase, SymTransformer:
215240
mt.instantiate(argTypes)
216241

217242
def recheckApply(tree: Apply, pt: Type)(using Context): Type =
218-
recheck(tree.fun).widen match
243+
val funtpe = recheck(tree.fun)
244+
funtpe.widen match
219245
case fntpe: MethodType =>
220246
assert(fntpe.paramInfos.hasSameLengthAs(tree.args))
221247
val formals =
222248
if tree.symbol.is(JavaDefined) then mapJavaArgs(fntpe.paramInfos)
223249
else fntpe.paramInfos
224250
def recheckArgs(args: List[Tree], formals: List[Type], prefs: List[ParamRef]): List[Type] = args match
225251
case arg :: args1 =>
226-
val argType = recheck(arg, formals.head)
252+
val argType = recheck(arg, formals.head.mapExprType)
227253
val formals1 =
228254
if fntpe.isParamDependent
229255
then formals.tail.map(_.substParam(prefs.head, argType))
@@ -235,6 +261,8 @@ abstract class Recheck extends Phase, SymTransformer:
235261
val argTypes = recheckArgs(tree.args, formals, fntpe.paramRefs)
236262
constFold(tree, instantiate(fntpe, argTypes, tree.fun.symbol))
237263
//.showing(i"typed app $tree : $fntpe with ${tree.args}%, % : $argTypes%, % = $result")
264+
case tp =>
265+
assert(false, i"unexpected type of ${tree.fun}: $funtpe")
238266

239267
def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type =
240268
recheck(tree.fun).widen match
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
object test:
2+
class Plan(elem: Plan)
3+
object SomePlan extends Plan(???)
4+
def f1(expr: (-> Plan) -> Plan): Plan = expr(SomePlan)
5+
f1 { onf => Plan(onf) }
6+
def f2(expr: (=> Plan) -> Plan): Plan = ???
7+
f2 { onf => Plan(onf) }
8+
def f3(expr: (-> Plan) => Plan): Plan = ???
9+
f1 { onf => Plan(onf) }
10+
def f4(expr: (=> Plan) => Plan): Plan = ???
11+
f2 { onf => Plan(onf) }

0 commit comments

Comments
 (0)