Skip to content

Commit 5af9f6d

Browse files
committed
Boxed CapturingTypes
1 parent 934b6e0 commit 5af9f6d

23 files changed

+190
-128
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import printing.Printer
1212
import printing.Texts.Text
1313

1414

15-
case class CaptureAnnotation(refs: CaptureSet) extends Annotation:
15+
case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation:
1616
import CaptureAnnotation.*
1717
import tpd.*
1818

@@ -30,19 +30,20 @@ case class CaptureAnnotation(refs: CaptureSet) extends Annotation:
3030
override def derivedAnnotation(tree: Tree)(using Context): Annotation =
3131
unsupported("derivedAnnotation(Tree)")
3232

33-
def derivedAnnotation(refs: CaptureSet)(using Context): Annotation =
34-
if this.refs eq refs then this else CaptureAnnotation(refs)
33+
def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation =
34+
if (this.refs eq refs) && (this.boxed == boxed) then this
35+
else CaptureAnnotation(refs, boxed)
3536

3637
override def sameAnnotation(that: Annotation)(using Context): Boolean = that match
37-
case CaptureAnnotation(refs2) => refs == refs2
38+
case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2
3839
case _ => false
3940

4041
override def mapWith(tp: TypeMap)(using Context) =
4142
val elems = refs.elems.toList
4243
val elems1 = elems.mapConserve(tp)
4344
if elems1 eq elems then this
4445
else if elems1.forall(_.isInstanceOf[CaptureRef])
45-
then CaptureAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*))
46+
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed)
4647
else EmptyAnnotation
4748

4849
override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
@@ -53,10 +54,10 @@ case class CaptureAnnotation(refs: CaptureSet) extends Annotation:
5354

5455
override def toText(printer: Printer): Text = refs.toText(printer)
5556

56-
override def hash: Int = refs.hashCode
57+
override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0)
5758

5859
override def eql(that: Annotation) = that match
59-
case that: CaptureAnnotation => this.refs eq that.refs
60+
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed)
6061
case _ => false
6162

6263
end CaptureAnnotation

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import util.Property.Key
1111
import tpd.*
1212

1313
private val Captures: Key[CaptureSet] = Key()
14+
private val IsBoxed: Key[Unit] = Key()
1415

1516
def retainedElems(tree: Tree)(using Context): List[Tree] = tree match
1617
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems
@@ -29,47 +30,31 @@ extension (tree: Tree)
2930
tree.putAttachment(Captures, refs)
3031
refs
3132

33+
def isBoxedCapturing(using Context) =
34+
tree.hasAttachment(IsBoxed)
35+
3236
extension (tp: Type)
3337

3438
def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match
35-
case CapturingType(p, r) =>
39+
case CapturingType(p, r, b) =>
3640
if (parent eq p) && (refs eq r) then tp
37-
else CapturingType(parent, refs)
41+
else CapturingType(parent, refs, b)
3842

3943
/** If this is type variable instantiated or upper bounded with a capturing type,
4044
* the capture set associated with that type. Extended to and-or types and
4145
* type proxies in the obvious way. If a term has a type with a boxed captureset,
4246
* that captureset counts towards the capture variables of the envirionment.
4347
*/
4448
def boxedCaptured(using Context): CaptureSet =
45-
def getBoxed(tp: Type, enabled: Boolean): CaptureSet = tp match
46-
case CapturingType(_, refs) if enabled => refs
47-
case tp: TypeVar => getBoxed(tp.underlying, enabled = true)
48-
case tp: TypeRef if tp.symbol == defn.AnyClass && enabled => CaptureSet.universal
49-
case tp: TypeProxy => getBoxed(tp.superType, enabled)
50-
case tp: AndType => getBoxed(tp.tp1, enabled) ++ getBoxed(tp.tp2, enabled)
51-
case tp: OrType => getBoxed(tp.tp1, enabled) ** getBoxed(tp.tp2, enabled)
49+
def getBoxed(tp: Type): CaptureSet = tp match
50+
case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty
51+
case tp: TypeProxy => getBoxed(tp.superType)
52+
case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2)
53+
case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2)
5254
case _ => CaptureSet.empty
53-
getBoxed(tp, enabled = false)
55+
getBoxed(tp)
5456

55-
/** If this type appears as an expected type of a term, does it imply
56-
* that the term should be boxed?
57-
* ^^^ Special treat Any? - but the current status is more conservative in that
58-
* it counts free variables in expressions that have Any as expected type.
59-
*/
60-
def needsBox(using Context): Boolean = tp match
61-
case _: TypeVar => true
62-
case tp: TypeRef =>
63-
tp.info match
64-
case TypeBounds(lo, _) => lo.needsBox
65-
case _ => false
66-
case tp: RefinedOrRecType => tp.parent.needsBox
67-
case CapturingType(_, _) => false
68-
case tp: AnnotatedType => tp.parent.needsBox
69-
case tp: LazyRef => tp.ref.needsBox
70-
case tp: AndType => tp.tp1.needsBox || tp.tp2.needsBox
71-
case tp: OrType => tp.tp1.needsBox && tp.tp2.needsBox
72-
case _ => false
57+
def isBoxedCapturing(using Context) = !tp.boxedCaptured.isAlwaysEmpty
7358

7459
def canHaveInferredCapture(using Context): Boolean = tp match
7560
case tp: TypeRef if tp.symbol.isClass =>
@@ -84,7 +69,7 @@ extension (tp: Type)
8469
false
8570

8671
def stripCapturing(using Context): Type = tp.dealiasKeepAnnots match
87-
case CapturingType(parent, _) =>
72+
case CapturingType(parent, _, _) =>
8873
parent.stripCapturing
8974
case atd @ AnnotatedType(parent, annot) =>
9075
atd.derivedAnnotatedType(parent.stripCapturing, annot)

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ sealed abstract class CaptureSet extends Showable:
161161
if tp.exists then OrType(tp, ref, soft = false) else ref)
162162

163163
def toRegularAnnotation(using Context): Annotation =
164-
Annotation(CaptureAnnotation(this).tree)
164+
Annotation(CaptureAnnotation(this, boxed = false).tree)
165165

166166
override def toText(printer: Printer): Text =
167167
Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}")
@@ -400,7 +400,7 @@ object CaptureSet:
400400
tp.captureSet
401401
case _: TypeRef | _: TypeParamRef =>
402402
empty
403-
case CapturingType(parent, refs) =>
403+
case CapturingType(parent, refs, _) =>
404404
recur(parent) ++ refs
405405
case AppliedType(tycon, args) =>
406406
val cs = recur(tycon)

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ import Types.*, Symbols.*, Contexts.*
77

88
object CapturingType:
99

10-
def apply(parent: Type, refs: CaptureSet)(using Context): Type =
10+
def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type =
1111
if refs.isAlwaysEmpty then parent
12-
else AnnotatedType(parent, CaptureAnnotation(refs))
12+
else AnnotatedType(parent, CaptureAnnotation(refs, boxed))
1313

14-
def unapply(tp: AnnotatedType)(using Context) =
14+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
1515
if ctx.phase == Phases.checkCapturesPhase && tp.annot.symbol == defn.RetainsAnnot then
1616
tp.annot match
17-
case ann: CaptureAnnotation => Some((tp.parent, ann.refs))
18-
case ann => Some((tp.parent, ann.tree.toCaptureSet))
17+
case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed))
18+
case ann => Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing))
1919
else None
2020

2121
end CapturingType

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ class Definitions {
265265
*/
266266
@tu lazy val AnyClass: ClassSymbol = completeClass(enterCompleteClassSymbol(ScalaPackageClass, tpnme.Any, Abstract, Nil), ensureCtor = false)
267267
def AnyType: TypeRef = AnyClass.typeRef
268-
@tu lazy val TopType: Type = CapturingType(AnyType, CaptureSet.universal)
268+
@tu lazy val TopType: Type = CapturingType(AnyType, CaptureSet.universal, boxed = false)
269269
@tu lazy val MatchableClass: ClassSymbol = completeClass(enterCompleteClassSymbol(ScalaPackageClass, tpnme.Matchable, Trait, AnyType :: Nil), ensureCtor = false)
270270
def MatchableType: TypeRef = MatchableClass.typeRef
271271
@tu lazy val AnyValClass: ClassSymbol =

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
329329
case tp: TypeVar =>
330330
val underlying1 = recur(tp.underlying, fromBelow)
331331
if underlying1 ne tp.underlying then underlying1 else tp
332-
case CapturingType(parent, refs) =>
332+
case CapturingType(parent, refs, _) =>
333333
val parent1 = recur(parent, fromBelow)
334334
if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp
335335
case tp: AnnotatedType =>

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2166,7 +2166,7 @@ object SymDenotations {
21662166
case tp: TypeParamRef => // uncachable, since baseType depends on context bounds
21672167
recur(TypeComparer.bounds(tp).hi)
21682168

2169-
case CapturingType(parent, refs) =>
2169+
case CapturingType(parent, refs, _) =>
21702170
tp.derivedCapturingType(recur(parent), refs)
21712171

21722172
case tp: TypeProxy =>

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
326326
compareWild
327327
case tp2: LazyRef =>
328328
isBottom(tp1) || !tp2.evaluating && recur(tp1, tp2.ref)
329-
case CapturingType(_, _) =>
329+
case CapturingType(_, _, _) =>
330330
secondTry
331331
case tp2: AnnotatedType if !tp2.isRefining =>
332332
recur(tp1, tp2.parent)
@@ -490,7 +490,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
490490
// and then need to check that they are indeed supertypes of the original types
491491
// under -Ycheck. Test case is i7965.scala.
492492

493-
case CapturingType(parent1, refs1) =>
493+
case CapturingType(parent1, refs1, _) =>
494494
if refs1.subCaptures(tp2.captureSet, frozenConstraint) == CaptureSet.CompareResult.OK then
495495
recur(parent1, tp2)
496496
else
@@ -749,7 +749,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
749749
false
750750
}
751751
compareTypeBounds
752-
case CapturingType(parent2, _) =>
752+
case CapturingType(parent2, _, _) =>
753753
recur(tp1, parent2) || fourthTry
754754
case tp2: AnnotatedType if tp2.isRefining =>
755755
(tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) &&
@@ -797,7 +797,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
797797
case tp: AppliedType => isNullable(tp.tycon)
798798
case AndType(tp1, tp2) => isNullable(tp1) && isNullable(tp2)
799799
case OrType(tp1, tp2) => isNullable(tp1) || isNullable(tp2)
800-
case CapturingType(tp1, _) => isNullable(tp1)
800+
case CapturingType(tp1, _, _) => isNullable(tp1)
801801
case _ => false
802802
}
803803
val sym1 = tp1.symbol
@@ -821,7 +821,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
821821
tp1 match
822822
case tp1: CaptureRef if tp1.isTracked =>
823823
val stripped = tp1w.stripCapturing
824-
tp1w = CapturingType(stripped, tp1.singletonCaptureSet)
824+
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false)
825825
case _ =>
826826
isSubType(tp1w, tp2, approx.addLow)
827827
}
@@ -2395,7 +2395,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
23952395
}
23962396
case tp1: TypeVar if tp1.isInstantiated =>
23972397
tp1.underlying & tp2
2398-
case CapturingType(parent1, refs1) =>
2398+
case CapturingType(parent1, refs1, _) =>
23992399
if tp2.captureSet.subCaptures(refs1, frozenConstraint) == CaptureSet.CompareResult.OK then
24002400
parent1 & tp2
24012401
else

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ object TypeOps:
164164
// with Nulls (which have no base classes). Under -Yexplicit-nulls, we take
165165
// corrective steps, so no widening is wanted.
166166
simplify(l, theMap) | simplify(r, theMap)
167-
case CapturingType(parent, refs) =>
167+
case CapturingType(parent, refs, _) =>
168168
if !ctx.mode.is(Mode.Type)
169169
&& refs.subCaptures(parent.captureSet, frozen = true) == CompareResult.OK then
170170
simplify(parent, theMap)
@@ -283,7 +283,7 @@ object TypeOps:
283283
tp1 match {
284284
case tp1: RecType =>
285285
return tp1.rebind(approximateOr(tp1.parent, tp2))
286-
case CapturingType(parent1, refs1) =>
286+
case CapturingType(parent1, refs1, _) =>
287287
return tp1.derivedCapturingType(approximateOr(parent1, tp2), refs1)
288288
case err: ErrorType =>
289289
return err
@@ -292,7 +292,7 @@ object TypeOps:
292292
tp2 match {
293293
case tp2: RecType =>
294294
return tp2.rebind(approximateOr(tp1, tp2.parent))
295-
case CapturingType(parent2, refs2) =>
295+
case CapturingType(parent2, refs2, _) =>
296296
return tp2.derivedCapturingType(approximateOr(tp1, parent2), refs2)
297297
case err: ErrorType =>
298298
return err

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import scala.util.hashing.{ MurmurHash3 => hashing }
3838
import config.Printers.{core, typr, matchTypes}
3939
import reporting.{trace, Message}
4040
import java.lang.ref.WeakReference
41-
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems}
41+
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing}
4242
import CaptureSet.CompareResult
4343

4444
import scala.annotation.internal.sharable
@@ -203,7 +203,7 @@ object Types {
203203
else this1.underlying.isRef(sym, skipRefined)
204204
case this1: TypeVar =>
205205
this1.instanceOpt.isRef(sym, skipRefined)
206-
case CapturingType(_, _) =>
206+
case CapturingType(_, _, _) =>
207207
false
208208
case this1: AnnotatedType =>
209209
this1.parent.isRef(sym, skipRefined)
@@ -373,7 +373,7 @@ object Types {
373373
case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference
374374
case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference)
375375
case WildcardType(optBounds) => optBounds.unusableForInference
376-
case CapturingType(parent, refs) => parent.unusableForInference || refs.elems.exists(_.unusableForInference)
376+
case CapturingType(parent, refs, _) => parent.unusableForInference || refs.elems.exists(_.unusableForInference)
377377
case _: ErrorType => true
378378
case _ => false
379379

@@ -1382,7 +1382,7 @@ object Types {
13821382
case tp: TypeVar =>
13831383
val tp1 = tp.instanceOpt
13841384
if (tp1.exists) tp1.dealias1(keep) else tp
1385-
case tp @ CapturingType(parent, refs) => // ^^^ merge with below for efficiency?
1385+
case tp @ CapturingType(parent, refs, _) => // ^^^ merge with below for efficiency?
13861386
tp.derivedCapturingType(parent.dealias1(keep), refs)
13871387
case tp: AnnotatedType =>
13881388
val tp1 = tp.parent.dealias1(keep)
@@ -1845,13 +1845,14 @@ object Types {
18451845
}
18461846

18471847
def capturing(ref: CaptureRef)(using Context): Type =
1848-
if captureSet.accountsFor(ref) then this else CapturingType(this, ref.singletonCaptureSet)
1848+
if captureSet.accountsFor(ref) then this
1849+
else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing)
18491850

18501851
def capturing(cs: CaptureSet)(using Context): Type =
18511852
if cs.isConst && cs.subCaptures(captureSet, frozen = true) == CompareResult.OK then this
18521853
else this match
1853-
case CapturingType(parent, cs1) => parent.capturing(cs1 ++ cs)
1854-
case _ => CapturingType(this, cs)
1854+
case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs)
1855+
case _ => CapturingType(this, cs, this.isBoxedCapturing)
18551856

18561857
/** The set of distinct symbols referred to by this type, after all aliases are expanded */
18571858
def coveringSet(using Context): Set[Symbol] =
@@ -3694,7 +3695,7 @@ object Types {
36943695
case tp: AppliedType => tp.fold(status, compute(_, _, theAcc))
36953696
case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional)
36963697
case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps
3697-
case CapturingType(parent, refs) =>
3698+
case CapturingType(parent, refs, _) =>
36983699
(compute(status, parent, theAcc) /: refs.elems) {
36993700
(s, ref) => ref match
37003701
case tp: TermParamRef if tp.binder eq thisLambdaType => combine(s, CaptureDeps)
@@ -3763,7 +3764,7 @@ object Types {
37633764
def apply(tp: Type) = tp match {
37643765
case tp @ TermParamRef(`thisLambdaType`, _) =>
37653766
range(defn.NothingType, atVariance(1)(apply(tp.underlying)))
3766-
case CapturingType(parent, refs) =>
3767+
case CapturingType(parent, refs, boxed) =>
37673768
val parent1 = this(parent)
37683769
val elems1 = refs.elems.filter {
37693770
case tp @ TermParamRef(`thisLambdaType`, _) => false
@@ -3773,8 +3774,8 @@ object Types {
37733774
derivedCapturingType(tp, parent1, refs)
37743775
else
37753776
range(
3776-
CapturingType(parent1, CaptureSet(elems1)),
3777-
CapturingType(parent1, CaptureSet.universal))
3777+
CapturingType(parent1, CaptureSet(elems1), boxed),
3778+
CapturingType(parent1, CaptureSet.universal, boxed))
37783779
case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) =>
37793780
val parent1 = mapOver(parent)
37803781
if ann.symbol == defn.RetainsAnnot then
@@ -5579,7 +5580,7 @@ object Types {
55795580
case tp: ExprType =>
55805581
derivedExprType(tp, this(tp.resultType))
55815582

5582-
case CapturingType(parent, refs) =>
5583+
case CapturingType(parent, refs, _) =>
55835584
mapCapturingType(tp, parent, refs, variance)
55845585

55855586
case tp @ AnnotatedType(underlying, annot) =>
@@ -6051,7 +6052,7 @@ object Types {
60516052
val x2 = atVariance(0)(this(x1, tp.scrutinee))
60526053
foldOver(x2, tp.cases)
60536054

6054-
case CapturingType(parent, refs) =>
6055+
case CapturingType(parent, refs, _) =>
60556056
(this(x, parent) /: refs.elems)(this)
60566057

60576058
case AnnotatedType(underlying, annot) =>

compiler/src/dotty/tools/dotc/core/Variances.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ object Variances {
100100
v
101101
}
102102
varianceInArgs(varianceInType(tycon)(tparam), args, tycon.typeParams)
103-
case CapturingType(tp, _) =>
103+
case CapturingType(tp, _, _) =>
104104
varianceInType(tp)(tparam)
105105
case AnnotatedType(tp, annot) =>
106106
varianceInType(tp)(tparam) & varianceInAnnot(annot)(tparam)

0 commit comments

Comments
 (0)