Skip to content

Commit f6d7a91

Browse files
committed
Streamline treatment of CaptureRefs
- use isTrackableRef everywhere for discrimination (instead of just checking the CaptureRef type) - streamline treatment of reach refs through `stripReach`
1 parent b176eca commit f6d7a91

File tree

6 files changed

+13
-14
lines changed

6 files changed

+13
-14
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean)(cls: Symbol) exte
6363
val elems = refs.elems.toList
6464
val elems1 = elems.mapConserve(tm)
6565
if elems1 eq elems then this
66-
else if elems1.forall(_.isInstanceOf[CaptureRef])
66+
else if elems1.forall(_.isTrackableRef)
6767
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed)
6868
else EmptyAnnotation
6969

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ extension (tree: Tree)
7878

7979
/** Map tree with CaptureRef type to its type, throw IllegalCaptureRef otherwise */
8080
def toCaptureRef(using Context): CaptureRef = tree.tpe match
81-
case ref: CaptureRef => ref
81+
case ref: CaptureRef if ref.isTrackableRef => ref
8282
case tpe => throw IllegalCaptureRef(tpe) // if this was compiled from cc syntax, problem should have been reported at Typer
8383

8484
/** Convert a @retains or @retainsByName annotation tree to the capture set it represents.

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,8 @@ sealed abstract class CaptureSet extends Showable:
156156
case y: TermRef => !y.isReach && (y.prefix eq x)
157157
case _ => false
158158
|| x.match
159-
case x: TermRef if x.isReach =>
160-
y.match
161-
case y: TermRef if y.isReach => x.reachPrefix.subsumes(y.reachPrefix)
162-
case _ => x.reachPrefix.subsumes(y)
163-
case _ =>
164-
false
159+
case x: TermRef if x.isReach => x.stripReach.subsumes(y.stripReach)
160+
case _ => false
165161

166162
/** {x} <:< this where <:< is subcapturing, but treating all variables
167163
* as frozen.
@@ -505,7 +501,7 @@ object CaptureSet:
505501
if elem.isRootCapability then !noUniversal
506502
else elem match
507503
case elem: TermRef =>
508-
if elem.isReach then levelOK(elem.reachPrefix)
504+
if elem.isReach then levelOK(elem.stripReach)
509505
else if levelLimit.exists then
510506
var sym = elem.symbol
511507
if sym.isLevelOwner then sym = sym.owner

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,15 +1249,15 @@ class CheckCaptures extends Recheck, SymTransformer:
12491249
val checker = new TypeTraverser:
12501250
private var allowed: SimpleIdentitySet[TermParamRef] = SimpleIdentitySet.empty
12511251

1252-
private def isAllowed(ref: CaptureRef): Boolean = ref match
1252+
private def isAllowed(ref: CaptureRef): Boolean = ref.stripReach match
12531253
case ref: TermParamRef => allowed.contains(ref)
12541254
case _ => true
12551255

12561256
private def healCaptureSet(cs: CaptureSet): Unit =
12571257
cs.ensureWellformed: elem =>
12581258
ctx ?=>
12591259
var seen = new util.HashSet[CaptureRef]
1260-
def recur(ref: CaptureRef): Unit = ref match
1260+
def recur(ref: CaptureRef): Unit = ref.stripReach match
12611261
case ref: TermParamRef
12621262
if !allowed.contains(ref) && !seen.contains(ref) =>
12631263
seen += ref

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2177,6 +2177,8 @@ object Types {
21772177
/** Is this a reach reference of the form `x*`? */
21782178
def isReach(using Context): Boolean = false // overridden in TermRef
21792179

2180+
def stripReach(using Context): CaptureRef = this // overridden in TermRef
2181+
21802182
/** Is this reference the generic root capability `cap` ? */
21812183
def isRootCapability(using Context): Boolean = false
21822184

@@ -2916,13 +2918,14 @@ object Types {
29162918
override def isReach(using Context): Boolean =
29172919
name == nme.CC_REACH && symbol == defn.Any_ccReach
29182920

2919-
def reachPrefix: CaptureRef = prefix.asInstanceOf[CaptureRef]
2921+
override def stripReach(using Context): CaptureRef =
2922+
if isReach then prefix.asInstanceOf[CaptureRef] else this
29202923

29212924
override def isRootCapability(using Context): Boolean =
29222925
name == nme.CAPTURE_ROOT && symbol == defn.captureRoot
29232926

29242927
override def normalizedRef(using Context): CaptureRef =
2925-
if isReach then TermRef(reachPrefix.normalizedRef, name, denot)
2928+
if isReach then TermRef(stripReach.normalizedRef, name, denot)
29262929
else if isTrackableRef then symbol.termRef
29272930
else this
29282931
}

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
377377
def toTextRef(tp: SingletonType): Text = controlled {
378378
tp match {
379379
case tp: TermRef =>
380-
if tp.isReach then toTextRef(tp.reachPrefix) ~ "*"
380+
if tp.isReach then toTextRef(tp.stripReach) ~ "*"
381381
else toTextPrefixOf(tp) ~ selectionString(tp)
382382
case tp: ThisType =>
383383
nameString(tp.cls) + ".this"

0 commit comments

Comments
 (0)