Skip to content

Commit e434e5f

Browse files
committed
Allow new capture set syntax
- `^{xs}` in postfix - `->{xs}` after arrow - `any` instead of `*`
1 parent d02ecd0 commit e434e5f

File tree

12 files changed

+174
-134
lines changed

12 files changed

+174
-134
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,13 +1827,11 @@ object desugar {
18271827
case CapturingTypeTree(refs, parent) =>
18281828
// convert `{refs} T` to `T @retains refs`
18291829
// `{refs}-> T` to `-> (T @retainsByName refs)`
1830-
def annotate(annotName: TypeName, tp: Tree) =
1831-
Annotated(tp, New(scalaAnnotationDot(annotName), List(refs)))
18321830
parent match
18331831
case ByNameTypeTree(restpt) =>
1834-
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
1832+
cpy.ByNameTypeTree(parent)(makeRetaining(restpt, refs, tpnme.retainsByName))
18351833
case _ =>
1836-
annotate(tpnme.retains, parent)
1834+
makeRetaining(parent, refs, tpnme.retains)
18371835
case f: FunctionWithMods if f.hasErasedParams => makeFunctionWithValDefs(f, pt)
18381836
}
18391837
desugared.withSpan(tree.span)
@@ -1927,7 +1925,7 @@ object desugar {
19271925
}
19281926
tree match
19291927
case tree: FunctionWithMods =>
1930-
untpd.FunctionWithMods(applyVParams, tree.body, tree.mods, tree.erasedParams)
1928+
untpd.FunctionWithMods(applyVParams, result, tree.mods, tree.erasedParams)
19311929
case _ => untpd.Function(applyVParams, result)
19321930
}
19331931
}

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
150150
/** {x1, ..., xN} T (only relevant under captureChecking) */
151151
case class CapturingTypeTree(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
152152

153+
/** {x1, ..., xN} T (only relevant under captureChecking) */
154+
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
155+
153156
/** Short-lived usage in typer, does not need copy/transform/fold infrastructure */
154157
case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree
155158

@@ -501,6 +504,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
501504
def captureRoot(using Context): Select =
502505
Select(scalaDot(nme.caps), nme.CAPTURE_ROOT)
503506

507+
def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated =
508+
Annotated(parent, New(scalaAnnotationDot(annotName), List(refs)))
509+
504510
def makeConstructor(tparams: List[TypeDef], vparamss: List[List[ValDef]], rhs: Tree = EmptyTree)(using Context): DefDef =
505511
DefDef(nme.CONSTRUCTOR, joinParams(tparams, vparamss), TypeTree(), rhs)
506512

@@ -658,6 +664,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
658664
case tree: Number if (digits == tree.digits) && (kind == tree.kind) => tree
659665
case _ => finalize(tree, untpd.Number(digits, kind))
660666
}
667+
def CapturesAndResult(tree: Tree)(refs: List[Tree], parent: Tree)(using Context): Tree = tree match
668+
case tree: CapturesAndResult if (refs eq tree.refs) && (parent eq tree.parent) => tree
669+
case _ => finalize(tree, untpd.CapturesAndResult(refs, parent))
670+
661671
def CapturingTypeTree(tree: Tree)(refs: List[Tree], parent: Tree)(using Context): Tree = tree match
662672
case tree: CapturingTypeTree if (refs eq tree.refs) && (parent eq tree.parent) => tree
663673
case _ => finalize(tree, untpd.CapturingTypeTree(refs, parent))
@@ -723,6 +733,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
723733
tree
724734
case MacroTree(expr) =>
725735
cpy.MacroTree(tree)(transform(expr))
736+
case CapturesAndResult(refs, parent) =>
737+
cpy.CapturesAndResult(tree)(transform(refs), transform(parent))
726738
case CapturingTypeTree(refs, parent) =>
727739
cpy.CapturingTypeTree(tree)(transform(refs), transform(parent))
728740
case _ =>
@@ -782,6 +794,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
782794
this(x, splice)
783795
case MacroTree(expr) =>
784796
this(x, expr)
797+
case CapturesAndResult(refs, parent) =>
798+
this(this(x, refs), parent)
785799
case CapturingTypeTree(refs, parent) =>
786800
this(this(x, refs), parent)
787801
case _ =>

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ object StdNames {
287287

288288
// Compiler-internal
289289
val CAPTURE_ROOT: N = "*"
290+
val CAPTURE_ROOT_ALT: N = "any"
290291
val CONSTRUCTOR: N = "<init>"
291292
val STATIC_CONSTRUCTOR: N = "<clinit>"
292293
val EVT2U: N = "evt2u$"
@@ -301,6 +302,7 @@ object StdNames {
301302
val THROWS: N = "$throws"
302303
val U2EVT: N = "u2evt$"
303304
val ALLARGS: N = "$allArgs"
305+
val UPARROW: N = "^"
304306

305307
final val Nil: N = "Nil"
306308
final val Predef: N = "Predef"

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,6 +1467,7 @@ object Parsers {
14671467
if in.token == THIS then simpleRef()
14681468
else termIdent() match
14691469
case Ident(nme.CAPTURE_ROOT) => captureRoot
1470+
case Ident(nme.CAPTURE_ROOT_ALT) => captureRoot
14701471
case id => id
14711472

14721473
/** CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` -- under captureChecking
@@ -1475,6 +1476,11 @@ object Parsers {
14751476
if in.token == RBRACE then Nil else commaSeparated(captureRef)
14761477
}
14771478

1479+
def capturesAndResult(core: () => Tree): Tree =
1480+
if in.token == LBRACE && in.offset == in.lastOffset
1481+
then CapturesAndResult(captureSet(), core())
1482+
else core()
1483+
14781484
/** Type ::= FunType
14791485
* | HkTypeParamClause ‘=>>’ Type
14801486
* | FunParamClause ‘=>>’ Type
@@ -1519,7 +1525,7 @@ object Parsers {
15191525
else
15201526
accept(ARROW)
15211527

1522-
val resultType = typ()
1528+
val resultType = capturesAndResult(typ)
15231529
if token == TLARROW then
15241530
for case ValDef(_, tpt, _) <- params do
15251531
if isByNameType(tpt) then
@@ -1690,19 +1696,27 @@ object Parsers {
16901696
infixOps(t, canStartInfixTypeTokens, refinedTypeFn, Location.ElseWhere, ParseKind.Type,
16911697
isOperator = !followingIsVararg() && !isPureArrow)
16921698

1693-
/** RefinedType ::= WithType {[nl] Refinement}
1699+
/** RefinedType ::= WithType {[nl] (Refinement} [`^` CaptureSet]
16941700
*/
16951701
val refinedTypeFn: Location => Tree = _ => refinedType()
16961702

16971703
def refinedType() = refinedTypeRest(withType())
16981704

16991705
def refinedTypeRest(t: Tree): Tree = {
17001706
argumentStart()
1701-
if (in.isNestedStart)
1707+
if in.isNestedStart then
17021708
refinedTypeRest(atSpan(startOffset(t)) {
17031709
RefinedTypeTree(rejectWildcardType(t), refinement(indentOK = true))
17041710
})
1705-
else t
1711+
else if in.isIdent(nme.UPARROW) then
1712+
val upArrowStart = in.offset
1713+
in.nextToken()
1714+
def cs =
1715+
if in.token == LBRACE then captureSet()
1716+
else atSpan(upArrowStart)(captureRoot) :: Nil
1717+
makeRetaining(t, cs, tpnme.retains)
1718+
else
1719+
t
17061720
}
17071721

17081722
/** WithType ::= AnnotType {`with' AnnotType} (deprecated)
@@ -1929,7 +1943,7 @@ object Parsers {
19291943
def paramTypeOf(core: () => Tree): Tree =
19301944
if in.token == ARROW || isPureArrow(nme.PUREARROW) then
19311945
val isImpure = in.token == ARROW
1932-
val tp = atSpan(in.skipToken()) { ByNameTypeTree(core()) }
1946+
val tp = atSpan(in.skipToken()) { ByNameTypeTree(capturesAndResult(core)) }
19331947
if isImpure && Feature.pureFunsEnabled then ImpureByNameTypeTree(tp) else tp
19341948
else if in.token == LBRACE && followingIsCaptureSet() then
19351949
val start = in.offset

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
730730
val contentText = toTextGlobal(content)
731731
val tptText = toTextGlobal(tpt)
732732
prefix ~~ idx.toString ~~ "|" ~~ tptText ~~ "|" ~~ argsText ~~ "|" ~~ contentText ~~ postfix
733+
case CapturesAndResult(refs, parent) =>
734+
changePrec(GlobalPrec)("^{" ~ Text(refs.map(toText), ", ") ~ "} " ~ toText(parent))
733735
case CapturingTypeTree(refs, parent) =>
734736
parent match
735737
case ImpureByNameTypeTree(bntpt) =>

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13891389

13901390
def typedFunctionType(tree: untpd.Function, pt: Type)(using Context): Tree = {
13911391
val untpd.Function(args, body) = tree
1392+
body match
1393+
case untpd.CapturesAndResult(refs, result) =>
1394+
return typedUnadapted(untpd.makeRetaining(
1395+
cpy.Function(tree)(args, result), refs, tpnme.retains), pt)
1396+
case _ =>
13921397
var (funFlags, erasedParams) = tree match {
13931398
case tree: untpd.FunctionWithMods => (tree.mods.flags, tree.erasedParams)
13941399
case _ => (EmptyFlags, args.map(_ => false))
@@ -2274,10 +2279,13 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
22742279
assignType(cpy.MatchTypeTree(tree)(bound1, sel1, cases1), bound1, sel1, cases1)
22752280
}
22762281

2277-
def typedByNameTypeTree(tree: untpd.ByNameTypeTree)(using Context): ByNameTypeTree = {
2278-
val result1 = typed(tree.result)
2279-
assignType(cpy.ByNameTypeTree(tree)(result1), result1)
2280-
}
2282+
def typedByNameTypeTree(tree: untpd.ByNameTypeTree)(using Context): ByNameTypeTree = tree.result match
2283+
case untpd.CapturesAndResult(refs, tpe) =>
2284+
typedByNameTypeTree(
2285+
cpy.ByNameTypeTree(tree)(untpd.makeRetaining(tpe, refs, tpnme.retainsByName)))
2286+
case _ =>
2287+
val result1 = typed(tree.result)
2288+
assignType(cpy.ByNameTypeTree(tree)(result1), result1)
22812289

22822290
def typedTypeBoundsTree(tree: untpd.TypeBoundsTree, pt: Type)(using Context): Tree =
22832291
val TypeBoundsTree(lo, hi, alias) = tree

tests/pos-custom-args/captures/byname.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ type Cap = {*} CC
44

55
class I
66

7-
def test(cap1: Cap, cap2: Cap): {cap1} I =
7+
def test(cap1: Cap, cap2: Cap): I{ref cap1} =
88
def f() = if cap1 == cap1 then I() else I()
9-
def h(x: {cap1}-> I) = x
9+
def h(x: ->{ref any} I) = x
1010
h(f()) // OK
1111
def hh(x: -> I @retainsByName(cap1)) = x
1212
h(f())

tests/pos-custom-args/captures/capt-test.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ def map[A, B](f: A => B)(xs: LIST[A]): LIST[B] =
1919
xs.map(f)
2020

2121
class C
22-
type Cap = {*} C
22+
type Cap = C{ref any}
2323

2424
class Foo(x: Cap):
25-
this: {x} Foo =>
25+
this: Foo{ref x} =>
2626

2727
def test(c: Cap, d: Cap) =
2828
def f(x: Cap): Unit = if c == x then ()
@@ -32,7 +32,7 @@ def test(c: Cap, d: Cap) =
3232
val zs =
3333
val z = g
3434
CONS(z, ys)
35-
val zsc: LIST[{d, y} Cap -> Unit] = zs
35+
val zsc: LIST[Cap ->{ref d, y} Unit] = zs
3636

3737
val a4 = zs.map(identity)
38-
val a4c: LIST[{d, y} Cap -> Unit] = a4
38+
val a4c: LIST[Cap ->{ref d, y} Unit] = a4

tests/pos-custom-args/captures/capt1.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
class C
2-
type Cap = {*} C
3-
def f1(c: Cap): {c} () -> c.type = () => c // ok
2+
type Cap = C^
3+
def f1(c: Cap): () ->{c} c.type = () => c // ok
44

55
def f2: Int =
6-
val g: {*} Boolean -> Int = ???
6+
val g: Boolean ->{any} Int = ???
77
val x = g(true)
88
x
99

@@ -13,11 +13,11 @@ def f3: Int =
1313
val x = g.apply(true)
1414
x
1515

16-
def foo(): {*} C =
17-
val x: {*} C = ???
18-
val y: {x} C = x
19-
val x2: {x} () -> C = ???
20-
val y2: {x} () -> {x} C = x2
16+
def foo(): C^ =
17+
val x: C^ = ???
18+
val y: C^{x} = x
19+
val x2: () ->{x} C = ???
20+
val y2: () ->{x} C^{x} = x2
2121

2222
val z1: () => Cap = f1(x)
2323
def h[X](a: X)(b: X) = a
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
import annotation.retains
22
class B
3-
type Cap = {*} B
3+
type Cap = B{ref any}
44
class C(val n: Cap):
5-
this: {n} C =>
6-
def foo(): {n} B = n
5+
this: C{ref n} =>
6+
def foo(): B{ref n} = n
77

88

99
def test(x: Cap, y: Cap, z: Cap) =
1010
val c0 = C(x)
11-
val c1: {x} C {val n: {x} B} = c0
11+
val c1: C{ref x}{val n: B{ref x}} = c0
1212
val d = c1.foo()
13-
d: {x} B
13+
d: B{ref x}
1414

1515
val c2 = if ??? then C(x) else C(y)
1616
val c2a = identity(c2)
17-
val c3: {x, y} C { val n: {x, y} B } = c2
17+
val c3: C{ref x, y}{ val n: B{ref x, y} } = c2
1818
val d1 = c3.foo()
19-
d1: B @retains(x, y)
19+
d1: B{ref x, y}
2020

2121
class Local:
2222

@@ -29,7 +29,7 @@ def test(x: Cap, y: Cap, z: Cap) =
2929
end Local
3030

3131
val l = Local()
32-
val l1: {x, y} Local = l
32+
val l1: Local{ref x, y} = l
3333
val l2 = Local(x)
34-
val l3: {x, y, z} Local = l2
34+
val l3: Local{ref x, y, z} = l2
3535

0 commit comments

Comments
 (0)