Skip to content

Commit 8646308

Browse files
committed
Handle outer class roots when instantiating class members
The problem arises if we have a class like the one in pos-custom-args/captures/refs.scala: ```scala class MonoRef(init: Proc): type MonoProc = Proc var x: MonoProc = init def getX: MonoProc = x def setX(x: MonoProc): Unit = this.x = x ``` The type of `getX` and `setX` refer to the local root capability of class `MonoRef`. When we call `m.getX` or `m.setX` in `m: MonoRef`, these occurrences have to be adapted to capture roots in the scope of the selection. We determine these roots by inspecting the capture set of `m` and picking a root that corresponds to it.
1 parent 2d07bd5 commit 8646308

File tree

9 files changed

+256
-20
lines changed

9 files changed

+256
-20
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ trait FollowAliases extends TypeMap:
8989
mapOver(t)
9090

9191
class mapRoots(from0: CaptureRoot, to: CaptureRoot)(using Context) extends BiTypeMap, FollowAliases:
92-
thisMap =>
93-
9492
val from = from0.followAlias
9593

9694
//override val toString = i"mapRoots($from, $to)"
@@ -467,12 +465,10 @@ extension (sym: Symbol)
467465
else newRoot
468466
ccState.localRoots.getOrElseUpdate(owner, lclRoot)
469467

470-
def maxNested(other: Symbol, pickFirstOnConflict: Boolean = false)(using Context): Symbol =
468+
def maxNested(other: Symbol, onConflict: (Symbol, Symbol) => Context ?=> Symbol)(using Context): Symbol =
471469
if !sym.exists || other.isContainedIn(sym) then other
472470
else if !other.exists || sym.isContainedIn(other) then sym
473-
else
474-
assert(pickFirstOnConflict, i"incomparable nesting: $sym and $other")
475-
sym
471+
else onConflict(sym, other)
476472

477473
def minNested(other: Symbol)(using Context): Symbol =
478474
if !other.exists || other.isContainedIn(sym) then sym

compiler/src/dotty/tools/dotc/cc/CaptureRoot.scala

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ package cc
44

55
import core.*
66
import Types.*, Symbols.*, Contexts.*, Annotations.*, Flags.*
7+
import config.Printers.capt
78
import Hashable.Binders
89
import printing.Showable
910
import util.SimpleIdentitySet
10-
import Decorators.i
11+
import Decorators.*
12+
import StdNames.nme
1113
import scala.annotation.constructorOnly
1214
import scala.annotation.internal.sharable
1315

@@ -43,11 +45,31 @@ object CaptureRoot:
4345
case _ =>
4446
myAlias = r
4547

48+
/** A fresh var with the same limits and outerRoots as this one */
49+
def fresh(using Context): Var =
50+
val r = Var(owner, NoSymbol)
51+
r.innerLimit = innerLimit
52+
r.outerLimit = outerLimit
53+
r.outerRoots = outerRoots
54+
r
55+
56+
/** A fresh var that is enclosed by all roots in `rs`.
57+
* @throws A NoCommonRoot exception if this is not possible
58+
* since root scopes dont' overlap.
59+
*/
60+
def freshEnclosedBy(rs: CaptureRoot*)(using Context): CaptureRoot =
61+
val r = fresh
62+
if rs.forall(_.encloses(r)) then r else throw NoCommonRoot(rs*)
63+
4664
def computeHash(bs: Binders): Int = hash
4765
def hash: Int = System.identityHashCode(this)
4866
def underlying(using Context): Type = defn.Caps_Cap.typeRef
4967
end Var
5068

69+
class NoCommonRoot(rs: CaptureRoot*)(using Context) extends Exception(
70+
i"No common capture root nested in ${rs.mkString(" and ")}"
71+
)
72+
5173
extension (r: CaptureRoot)
5274

5375
def followAlias(using Context): CaptureRoot = r match
@@ -83,7 +105,9 @@ object CaptureRoot:
83105
else if !r2.innerLimit.isContainedIn(r1.outerLimit) then false // no overlap
84106
else if r1.outerRoots.contains(r2) then // unify
85107
r1.alias = r2
86-
r2.outerLimit = r1.outerLimit.maxNested(r2.outerLimit)
108+
r2.outerLimit =
109+
r1.outerLimit.maxNested(r2.outerLimit,
110+
onConflict = (_, _) => throw NoCommonRoot(r1, r2))
87111
r2.innerLimit = r1.innerLimit.minNested(r2.innerLimit)
88112
true
89113
else
@@ -93,6 +117,75 @@ object CaptureRoot:
93117
r2.outerRoots -= r2
94118
false
95119
end encloses
120+
end extension
121+
122+
/** The capture root enclosed by `root1` and `root2`.
123+
* If one of these is a Var, create a fresh Var with the appropriate constraints.
124+
* If the scopes of `root1` and `root2` don't overlap, thow a `NoCommonRoot` exception.
125+
*/
126+
def lub(root1: CaptureRoot, root2: CaptureRoot)(using Context): CaptureRoot =
127+
val (r1, r2) = (root1.followAlias, root2.followAlias)
128+
if r1 eq r2 then r1
129+
else (r1, r2) match
130+
case (r1: TermRef, r2: TermRef) =>
131+
r1.localRootOwner.maxNested(r2.localRootOwner,
132+
onConflict = (_, _) => throw NoCommonRoot(r1, r2)
133+
).termRef
134+
case (r1: TermRef, r2: Var) =>
135+
r2.freshEnclosedBy(r1, r2)
136+
case (r1: Var, r2) =>
137+
r1.freshEnclosedBy(r1, r2)
138+
139+
/** A map that instantiates all outer class roots in the info of `sym`
140+
* according to prefix `pre`. This is called for adapting the info of
141+
* a selection `pre.sym`. The logic of the function is modeled after
142+
* AsSeenFrom. But where AsSeenFrom maps a `this` of class `C` to a corresponding
143+
* prefix, the present method maps a local root corresponding to a class to
144+
* the root implied by the capture set of the corresponding prefix.
145+
* @param sym the class member symbol whose info is mapped
146+
* @param pre the prefix from which `sym` is selected
147+
* @param deafilt the capture root to use if the capture set of the corresponding
148+
* prfefix is empty.
149+
*/
150+
class instantiateOuterClassRoots(sym: Symbol, pre: Type, default: CaptureRoot)(using Context) extends ApproximatingTypeMap:
151+
val cls = sym.owner.asClass
152+
153+
def apply(tp: Type): Type =
154+
155+
/** Analogous to `toPrefix` in `AssSeenFromMap`, but result prefix gets
156+
* further mapped to a capture root via `impliedRoot`.
157+
*/
158+
def mapCaptureRoot(pre: Type, cls: Symbol, thiscls: ClassSymbol, fallBack: CaptureRoot): CaptureRoot =
159+
if (pre eq NoType) || (pre eq NoPrefix) || (cls is PackageClass) then
160+
fallBack
161+
else pre match
162+
case pre: SuperType =>
163+
mapCaptureRoot(pre.thistpe, cls, thiscls, fallBack)
164+
case _ =>
165+
if thiscls.derivesFrom(cls) && pre.baseType(thiscls).exists then
166+
pre.captureSet.impliedRoot(default)
167+
else if pre.termSymbol.is(Package) && !thiscls.is(Package) then
168+
mapCaptureRoot(pre.select(nme.PACKAGE), cls, thiscls, fallBack)
169+
else
170+
mapCaptureRoot(pre.baseType(cls).normalizedPrefix, cls.owner, thiscls, fallBack)
171+
172+
def instRoot(elem: CaptureRef): CaptureRef = elem match
173+
case elem: TermRef
174+
if elem.name == nme.LOCAL_CAPTURE_ROOT && elem.symbol.owner.isLocalDummy =>
175+
mapCaptureRoot(pre, cls, elem.localRootOwner.asClass, elem)
176+
.showing(i"mapped capture root $elem in $cls to $result", capt)
177+
case _ =>
178+
elem
179+
180+
tp match
181+
case t @ CapturingType(parent, refs) =>
182+
val elems = refs.elems.toList
183+
val elems1 = elems.mapConserve(instRoot)
184+
val refs1 = if elems1 eq elems then refs else CaptureSet(elems1*)
185+
t.derivedCapturingType(apply(parent), refs1)
186+
case _ =>
187+
mapOver(tp)
188+
end instantiateOuterClassRoots
96189

97190
end CaptureRoot
98191

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ sealed abstract class CaptureSet extends Showable:
233233
else if that.subCaptures(this, frozen = true).isOK then this
234234
else if this.isConst && that.isConst then Const(this.elems ++ that.elems)
235235
else Var(
236-
this.levelLimit.maxNested(that.levelLimit, pickFirstOnConflict = true),
236+
this.levelLimit.maxNested(that.levelLimit, onConflict = (sym1, sym2) => sym1),
237237
this.elems ++ that.elems)
238238
.addAsDependentTo(this).addAsDependentTo(that)
239239

@@ -315,6 +315,22 @@ sealed abstract class CaptureSet extends Showable:
315315
def substParams(tl: BindingType, to: List[Type])(using Context) =
316316
map(Substituters.SubstParamsMap(tl, to))
317317

318+
/** The capture root that corresponds to this capture set. This is:
319+
* - if the capture set is a Var with a defined level limit, the
320+
* associated capture root,
321+
* - otherwise, if the set is nonempty, the innermost root such
322+
* that some element of the set subcaptures this root,
323+
* - otherwise, if the set is empty, `default`.
324+
*/
325+
def impliedRoot(default: CaptureRoot)(using Context): CaptureRoot =
326+
if levelLimit.exists then levelLimit.localRoot.termRef
327+
else if elems.isEmpty then default
328+
else elems.toList
329+
.map:
330+
case elem: CaptureRoot if elem.isLocalRootCapability => elem
331+
case elem => elem.captureSetOfInfo.impliedRoot(default)
332+
.reduce((x: CaptureRoot, y: CaptureRoot) => CaptureRoot.lub(x, y))
333+
318334
/** Invoke handler if this set has (or later aquires) the root capability `cap` */
319335
def disallowRootCapability(handler: () => Context ?=> Unit)(using Context): this.type =
320336
if isUniversal then handler()

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ class CheckCaptures extends Recheck, SymTransformer:
319319
includeCallCaptures(tree.symbol, tree.srcPos)
320320
else
321321
markFree(tree.symbol, tree.srcPos)
322-
instantiateLocalRoots(tree.symbol, pt):
322+
instantiateLocalRoots(tree.symbol, NoPrefix, pt):
323323
super.recheckIdent(tree, pt)
324324

325325
/** A specialized implementation of the selection rule.
@@ -349,7 +349,7 @@ class CheckCaptures extends Recheck, SymTransformer:
349349

350350
val selType = recheckSelection(tree, qualType, name, disambiguate)
351351
val selCs = selType.widen.captureSet
352-
instantiateLocalRoots(tree.symbol, pt):
352+
instantiateLocalRoots(tree.symbol, qualType, pt):
353353
if selCs.isAlwaysEmpty || selType.widen.isBoxedCapturing || qualType.isBoxedCapturing then
354354
selType
355355
else
@@ -370,14 +370,23 @@ class CheckCaptures extends Recheck, SymTransformer:
370370
* - `tp` is the type of a function that gets applied, either as a method
371371
* or as a function value that gets applied.
372372
*/
373-
def instantiateLocalRoots(sym: Symbol, pt: Type)(tp: Type)(using Context): Type =
373+
def instantiateLocalRoots(sym: Symbol, pre: Type, pt: Type)(tp: Type)(using Context): Type =
374374
def canInstantiate =
375375
sym.is(Method, butNot = Accessor)
376376
|| sym.isTerm && defn.isFunctionType(sym.info) && pt == AnySelectionProto
377-
if sym.skipConstructor.isLevelOwner && canInstantiate then
377+
if canInstantiate then
378378
val tpw = tp.widen
379-
val tp1 = mapRoots(sym.localRoot.termRef, CaptureRoot.Var(ctx.owner, sym))(tpw)
380-
.showing(i"INST $sym: $tp, ${sym.localRoot} = $result", ccSetup)
379+
var tp1 = tpw
380+
val rootVar = CaptureRoot.Var(ctx.owner, sym)
381+
if sym.skipConstructor.isLevelOwner then
382+
tp1 = mapRoots(sym.localRoot.termRef, rootVar)(tp1)
383+
if tp1 ne tpw then
384+
ccSetup.println(i"INST local $sym: $tp, ${sym.localRoot} = $tp1")
385+
if sym.owner.isClass then
386+
val tp2 = CaptureRoot.instantiateOuterClassRoots(sym, pre, rootVar)(tp1)
387+
if tp2 ne tp1 then
388+
ccSetup.println(i"INST class $sym: $tp, ${sym.localRoot} in $pre = $tp2")
389+
tp1 = tp2
381390
if tpw eq tp1 then tp else tp1
382391
else
383392
tp
@@ -695,6 +704,9 @@ class CheckCaptures extends Recheck, SymTransformer:
695704
case _ =>
696705
val res =
697706
try super.recheck(tree, pt)
707+
catch case ex: CaptureRoot.NoCommonRoot =>
708+
report.error(ex.getMessage.nn)
709+
tree.tpe
698710
finally curEnv = saved
699711
if tree.isTerm && !pt.isBoxedCapturing then
700712
markFree(res.boxedCaptureSet, tree.srcPos)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/pairs.scala:15:31 ----------------------------------------
2+
15 | val x1c: Cap^ ->{c} Unit = x1 // error
3+
| ^^
4+
| Found: (x$0: Cap^{cap[test]}) ->{x1} Unit
5+
| Required: Cap^{cap[x1c]} ->{c} Unit
6+
|
7+
| longer explanation available when compiling with `-explain`
8+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/pairs.scala:17:30 ----------------------------------------
9+
17 | val y1c: Cap ->{d} Unit = y1 // error
10+
| ^^
11+
| Found: (x$0: Cap^{cap[test]}) ->{y1} Unit
12+
| Required: Cap^{cap[y1c]} ->{d} Unit
13+
|
14+
| longer explanation available when compiling with `-explain`
15+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/pairs.scala:34:30 ----------------------------------------
16+
34 | val x1c: Cap ->{c} Unit = x1 // error
17+
| ^^
18+
| Found: (x$0: Cap^{cap[test]}) ->{x1} Unit
19+
| Required: Cap^{cap[x1c]} ->{c} Unit
20+
|
21+
| longer explanation available when compiling with `-explain`
22+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/pairs.scala:36:30 ----------------------------------------
23+
36 | val y1c: Cap ->{d} Unit = y1 // error
24+
| ^^
25+
| Found: (x$0: Cap^{cap[test]}) ->{y1} Unit
26+
| Required: Cap^{cap[y1c]} ->{d} Unit
27+
|
28+
| longer explanation available when compiling with `-explain`
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
@annotation.capability class Cap
2+
3+
object Monomorphic:
4+
5+
class Pair(x: Cap => Unit, y: Cap => Unit):
6+
type PCap = Cap
7+
def fst: PCap ->{x} Unit = x
8+
def snd: PCap ->{y} Unit = y
9+
10+
def test(c: Cap, d: Cap) =
11+
def f(x: Cap): Unit = if c == x then ()
12+
def g(x: Cap): Unit = if d == x then ()
13+
val p = Pair(f, g)
14+
val x1 = p.fst
15+
val x1c: Cap^ ->{c} Unit = x1 // error
16+
val y1 = p.snd
17+
val y1c: Cap ->{d} Unit = y1 // error
18+
19+
object Monomorphic2:
20+
21+
class Pair(x: Cap => Unit, y: Cap => Unit):
22+
def fst: Cap^{cap[Pair]} ->{x} Unit = x
23+
def snd: Cap^{cap[Pair]} ->{y} Unit = y
24+
25+
class Pair2(x: Cap => Unit, y: Cap => Unit):
26+
def fst: Cap^{cap[Pair2]} => Unit = x
27+
def snd: Cap^{cap[Pair2]} => Unit = y
28+
29+
def test(c: Cap, d: Cap) =
30+
def f(x: Cap): Unit = if c == x then ()
31+
def g(x: Cap): Unit = if d == x then ()
32+
val p = Pair(f, g)
33+
val x1 = p.fst
34+
val x1c: Cap ->{c} Unit = x1 // error
35+
val y1 = p.snd
36+
val y1c: Cap ->{d} Unit = y1 // error
37+
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
class C
2+
type Cap = C^
3+
4+
class A
5+
class B
6+
7+
class Foo(x: Cap):
8+
9+
def foo: A ->{cap[Foo]} Unit = ???
10+
11+
class Bar(y: Cap):
12+
13+
def bar: B ->{cap[Bar]} Unit = ???
14+
15+
def f(a: A ->{cap[Foo]} Unit, b: B ->{cap[Bar]} Unit)
16+
: (A ->{a} Unit, B ->{b} Unit)
17+
= (a, b)
18+
19+
def test(c1: Cap, c2: Cap) =
20+
val x = Foo(c1)
21+
val y = x.Bar(c2)
22+
val xfoo = x.foo
23+
val ybar = y.bar
24+
val z1 = y.f(xfoo, ybar)
25+
val z2 = y.f(x.foo, y.bar)
26+
()

tests/pos-custom-args/captures/pairs.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ object Monomorphic:
2929
def g(x: Cap): Unit = if d == x then ()
3030
val p = Pair(f, g)
3131
val x1 = p.fst
32-
val x1c: Cap ->{c} Unit = x1
32+
val x1c: Cap^{cap[test]} ->{c} Unit = x1
3333
val y1 = p.snd
34-
val y1c: Cap ->{d} Unit = y1
34+
val y1c: Cap^{cap[test]} ->{d} Unit = y1
3535

3636
object Monomorphic2:
3737

@@ -48,7 +48,6 @@ object Monomorphic2:
4848
def g(x: Cap): Unit = if d == x then ()
4949
val p = Pair(f, g)
5050
val x1 = p.fst
51-
val x1c: Cap ->{c} Unit = x1
51+
val x1c: Cap^{cap[test]} ->{c} Unit = x1
5252
val y1 = p.snd
53-
val y1c: Cap ->{d} Unit = y1
54-
53+
val y1c: Cap^{cap[test]} ->{d} Unit = y1
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
type Proc = () => Unit
2+
3+
class MonoRef(init: Proc):
4+
type MonoProc = Proc
5+
var x: MonoProc = init
6+
def getX: MonoProc = x
7+
def setX(x: MonoProc): Unit = this.x = x
8+
9+
def test(p: Proc) =
10+
val x = MonoRef(p)
11+
x.setX(p)
12+
val y = x.getX
13+
val yc1: Proc = y
14+
val yc2: () ->{x} Unit = y
15+
val yc3: () ->{cap[test]} Unit = y
16+
17+
class MonoRef2(init: () => Unit):
18+
var x: () ->{cap[MonoRef2]} Unit = init
19+
def getX: () ->{cap[MonoRef2]} Unit = x
20+
def setX(x: () ->{cap[MonoRef2]} Unit): Unit = this.x = x
21+
22+
def test2(p: Proc) =
23+
val x = MonoRef2(p)
24+
x.setX(p)
25+
val y = x.getX
26+
val yc1: Proc = y
27+
val yc2: () ->{x} Unit = y
28+
val yc3: () ->{cap[test2]} Unit = y
29+

0 commit comments

Comments
 (0)