Skip to content

Commit afddb45

Browse files
committed
Attempt to pass and check capability from parents correctly
1 parent 69ff121 commit afddb45

File tree

6 files changed

+55
-19
lines changed

6 files changed

+55
-19
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,17 @@ extension (tp: Type)
207207
case _: TypeRef | _: AppliedType => tp.typeSymbol.hasAnnotation(defn.CapabilityAnnot)
208208
case _ => false
209209

210+
/** Check if the class has universal capability, which means:
211+
* 1. the class has a capability annotation,
212+
* 2. the class is an impure function type,
213+
* 3. or one of its base classes has universal capability.
214+
*/
215+
def hasUniversalCapability(using Context): Boolean = tp match
216+
case CapturingType(parent, ref) =>
217+
ref.isUniversal || parent.hasUniversalCapability
218+
case tp =>
219+
tp.isCapabilityClassRef || tp.parents.exists(_.hasUniversalCapability)
220+
210221
/** Drop @retains annotations everywhere */
211222
def dropAllRetains(using Context): Type = // TODO we should drop retains from inferred types before unpickling
212223
val tm = new TypeMap:

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -519,16 +519,6 @@ class CheckCaptures extends Recheck, SymTransformer:
519519
if sym.isConstructor then
520520
val cls = sym.owner.asClass
521521

522-
/** Check if the class or one of its parents has a root capability,
523-
* which means that the class has a capability annotation or an impure
524-
* function type.
525-
*/
526-
def hasUniversalCapability(tp: Type): Boolean = tp match
527-
case CapturingType(parent, ref) =>
528-
ref.isUniversal || hasUniversalCapability(parent)
529-
case tp =>
530-
tp.isCapabilityClassRef || tp.parents.exists(hasUniversalCapability)
531-
532522
/** First half of result pair:
533523
* Refine the type of a constructor call `new C(t_1, ..., t_n)`
534524
* to C{val x_1: T_1, ..., x_m: T_m} where x_1, ..., x_m are the tracked
@@ -538,7 +528,7 @@ class CheckCaptures extends Recheck, SymTransformer:
538528
*/
539529
def addParamArgRefinements(core: Type, initCs: CaptureSet): (Type, CaptureSet) =
540530
var refined: Type = core
541-
var allCaptures: CaptureSet = if hasUniversalCapability(core)
531+
var allCaptures: CaptureSet = if core.hasUniversalCapability
542532
then CaptureSet.universal else initCs
543533
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
544534
val getter = cls.info.member(getterName).suchThat(_.is(ParamAccessor)).symbol

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,6 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
269269
CapturingType(fntpe, cs, boxed = false)
270270
else fntpe
271271

272-
/** Map references to capability classes C to C^ */
273-
private def expandCapabilityClass(tp: Type): Type =
274-
if tp.isCapabilityClassRef
275-
then CapturingType(tp, defn.expandedUniversalSet, boxed = false)
276-
else tp
277-
278272
private def recur(t: Type): Type = normalizeCaptures(mapOver(t))
279273

280274
def apply(t: Type) =
@@ -297,7 +291,8 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
297291
case t: TypeVar =>
298292
this(t.underlying)
299293
case t =>
300-
if t.isCapabilityClassRef
294+
// Map references to capability classes C to C^
295+
if t.hasUniversalCapability
301296
then CapturingType(t, defn.expandedUniversalSet, boxed = false)
302297
else recur(t)
303298
end expandAliases

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,13 +893,20 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
893893
canWidenAbstract && acc(true, tp)
894894

895895
def tryBaseType(cls2: Symbol) =
896-
val base = nonExprBaseType(tp1, cls2)
896+
var base = nonExprBaseType(tp1, cls2)
897897
if base.exists && (base ne tp1)
898898
&& (!caseLambda.exists
899899
|| widenAbstractOKFor(tp2)
900900
|| tp1.widen.underlyingClassRef(refinementOK = true).exists)
901901
then
902902
def checkBase =
903+
// Strip existing capturing set from base type
904+
base = base.stripCapturing
905+
// Pass capture set of tp1 to base type
906+
tp1 match
907+
case tp1 @ CapturingType(_, refs1) =>
908+
base = CapturingType(base, refs1, tp1.isBoxed)
909+
case _ =>
903910
isSubType(base, tp2, if tp1.isRef(cls2) then approx else approx.addLow)
904911
&& recordGadtUsageIf { MatchType.thatReducesUsingGadt(tp1) }
905912
if tp1.widenDealias.isInstanceOf[AndType] || base.isInstanceOf[OrType] then
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import annotation.capability
2+
3+
class C1
4+
@capability class C2 extends C1
5+
class C3 extends C2
6+
7+
def test =
8+
val x1: C1 = new C1
9+
val x2: C1 = new C2 // error
10+
val x3: C1 = new C3 // error
11+
12+
val y1: C2 = new C2
13+
val y2: C2 = new C3
14+
15+
val z1: C3 = new C3
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
class F extends (Int => Unit) {
2+
def apply(x: Int): Unit = ()
3+
}
4+
5+
def test =
6+
val x1 = new (Int => Unit) {
7+
def apply(x: Int): Unit = ()
8+
}
9+
10+
val x2: Int -> Unit = new (Int => Unit) { // error
11+
def apply(x: Int): Unit = ()
12+
}
13+
14+
val y1: Int => Unit = new F
15+
val y2: Int -> Unit = new F // error
16+
17+
val z1 = () => ()
18+
val z2: () -> Unit = () => ()

0 commit comments

Comments
 (0)