Skip to content

Commit 8f5ef66

Browse files
committed
Add capture checks for mutable variables
- Mutable variables have boxed types, so that we do not need to track them when computing capture sets of classes. - Mutable variable types cannot capture `*` in order to prevent scope extrusion.
1 parent e12648b commit 8f5ef66

File tree

5 files changed

+102
-18
lines changed

5 files changed

+102
-18
lines changed

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
7676
val symd = sym.denot
7777
symd.validFor.firstPhaseId == thisPhase.id && (sym.originDenotation ne symd)
7878

79-
def transformType(tp: Type, inferred: Boolean)(using Context): Type = tp
79+
def transformType(tp: Type, inferred: Boolean, boxed: Boolean = false)(using Context): Type = tp
8080

8181
object transformTypes extends TreeTraverser:
8282

@@ -110,12 +110,19 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
110110
mapOver(t)
111111
end SubstParams
112112

113+
private def transformTT(tree: TypeTree, boxed: Boolean)(using Context) =
114+
transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree], boxed).rememberFor(tree)
115+
113116
def traverse(tree: Tree)(using Context) =
114-
traverseChildren(tree)
115117
tree match
116-
118+
case tree @ ValDef(_, tpt: TypeTree, _) if tree.symbol.is(Mutable) =>
119+
transformTT(tpt, boxed = true)
120+
traverse(tree.rhs)
121+
case _ =>
122+
traverseChildren(tree)
123+
tree match
117124
case tree: TypeTree =>
118-
transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree]).rememberFor(tree)
125+
transformTT(tree, boxed = false)
119126
case tree: ValOrDefDef =>
120127
val sym = tree.symbol
121128

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class CheckCaptures extends Recheck:
115115
class CaptureChecker(ictx: Context) extends Rechecker(ictx):
116116
import ast.tpd.*
117117

118-
override def transformType(tp: Type, inferred: Boolean)(using Context): Type =
118+
override def transformType(tp: Type, inferred: Boolean, boxed: Boolean)(using Context): Type =
119119

120120
def addInnerVars(tp: Type): Type = tp match
121121
case tp @ AppliedType(tycon, args) =>
@@ -191,15 +191,15 @@ class CheckCaptures extends Recheck:
191191
apply(parent)
192192
case _ =>
193193
mapOver(t)
194-
addVars(addFunctionRefinements(cleanup(tp)))
194+
addVars(addFunctionRefinements(cleanup(tp)), boxed)
195195
.showing(i"reinfer $tp --> $result", capt)
196196
else
197-
val addBoxes = new TypeTraverser:
198-
def setBoxed(t: Type) = t match
199-
case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot =>
200-
annot.tree.setBoxedCapturing()
201-
case _ =>
197+
def setBoxed(t: Type) = t match
198+
case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot =>
199+
annot.tree.setBoxedCapturing()
200+
case _ =>
202201

202+
val addBoxes = new TypeTraverser:
203203
def traverse(t: Type) =
204204
t match
205205
case AppliedType(tycon, args) if !defn.isNonRefinedFunction(t) =>
@@ -208,8 +208,8 @@ class CheckCaptures extends Recheck:
208208
setBoxed(lo); setBoxed(hi)
209209
case _ =>
210210
traverseChildren(t)
211-
end addBoxes
212211

212+
if boxed then setBoxed(tp)
213213
addBoxes.traverse(tp)
214214
tp
215215
end transformType
@@ -417,12 +417,15 @@ class CheckCaptures extends Recheck:
417417
val what = if ref.isRootCapability then "universal" else "global"
418418
if isGlobal then
419419
val notAllowed = i" is not allowed to capture the $what capability $ref"
420-
def msg = tree match
421-
case tree: InferredTypeTree =>
422-
i"""inferred type argument ${knownType(tree)}$notAllowed
423-
|
424-
|The inferred arguments are: [${allArgs.map(knownType)}%, %]"""
425-
case _ => s"type argument$notAllowed"
420+
def msg =
421+
if allArgs.isEmpty then
422+
i"type of mutable variable ${knownType(tree)}$notAllowed"
423+
else tree match
424+
case tree: InferredTypeTree =>
425+
i"""inferred type argument ${knownType(tree)}$notAllowed
426+
|
427+
|The inferred arguments are: [${allArgs.map(knownType)}%, %]"""
428+
case _ => s"type argument$notAllowed"
426429
report.error(msg, tree.srcPos)
427430

428431
object PostRefinerCheck extends TreeTraverser:
@@ -463,6 +466,8 @@ class CheckCaptures extends Recheck:
463466
|The type needs to be declared explicitly.""", t.srcPos)
464467
case _ =>
465468
inferred.foreachPart(checkPure, StopAt.Static)
469+
case t: ValDef if t.symbol.is(Mutable) =>
470+
checkNotGlobal(t.tpt)
466471
case _ =>
467472
traverseChildren(tree)
468473

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/vars.scala:11:24 -----------------------------------------
2+
11 | val z2c: () => Unit = z2 // error
3+
| ^^
4+
| Found: (z2 : {x, cap1} () => Unit)
5+
| Required: () => Unit
6+
7+
longer explanation available when compiling with `-explain`
8+
-- Error: tests/neg-custom-args/captures/vars.scala:13:10 --------------------------------------------------------------
9+
13 | var a: {*} String => String = f // error
10+
| ^^^^^^^^^^^^^^^^^^^
11+
| type of mutable variable box {*} String => String is not allowed to capture the universal capability (* : Any)
12+
-- Error: tests/neg-custom-args/captures/vars.scala:27:2 ---------------------------------------------------------------
13+
27 | local { cap3 => // error
14+
| ^^^^^
15+
|inferred type argument {*} (x$0: ? String) => ? String is not allowed to capture the universal capability (* : Any)
16+
|
17+
|The inferred arguments are: [{*} (x$0: ? String) => ? String]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
class CC
2+
type Cap = {*} CC
3+
4+
def test(cap1: Cap, cap2: Cap) =
5+
def f(x: String): String = if cap1 == cap1 then "" else "a"
6+
var x = f
7+
val y = x
8+
val z = () => if x("") == "" then "a" else "b"
9+
val zc: {cap1} () => String = z
10+
val z2 = () => { x = identity }
11+
val z2c: () => Unit = z2 // error
12+
13+
var a: {*} String => String = f // error
14+
15+
def scope =
16+
val cap3: Cap = CC()
17+
def g(x: String): String = if cap3 == cap3 then "" else "a"
18+
a = g
19+
val gc = g
20+
g
21+
22+
val s = scope
23+
val sc: {*} String => String = scope
24+
25+
def local[T](op: Cap => T): T = op(CC())
26+
27+
local { cap3 => // error
28+
def g(x: String): String = if cap3 == cap3 then "" else "a"
29+
g
30+
}
31+
32+
class Ref:
33+
var elem: {cap1} String => String = null
34+
35+
val r = Ref()
36+
r.elem = f
37+
val fc = r.elem
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
class CC
2+
type Cap = {*} CC
3+
4+
def test(cap1: Cap, cap2: Cap) =
5+
def f(x: String): String = if cap1 == cap1 then "" else "a"
6+
var x = f
7+
val y = x
8+
val z = () => if x("") == "" then "a" else "b"
9+
val zc: {cap1} () => String = z
10+
val z2 = () => { x = identity }
11+
val z2c: {cap1} () => Unit = z2
12+
13+
class Ref:
14+
var elem: {cap1} String => String = null
15+
16+
val r = Ref()
17+
r.elem = f
18+
val fc: {cap1} String => String = r.elem

0 commit comments

Comments
 (0)