Skip to content

Commit 42c9eed

Browse files
committed
Handle captures in by-name parameters
1. Infrastructure to deal with capturesets in byname parameters 2. Handle retainsByName annotations in ElimByName Convert them to regular annotations on the generated function types. This enables capture checking on by-name parameters. 3. Add a style warning for misleading by-name parameter type formatting. By-name types should be formatted `{...}-> T`. `{...} -> T` looks too much like a function type.
1 parent 7ba7c89 commit 42c9eed

24 files changed

+233
-84
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ object desugar {
454454

455455
if mods.is(Trait) then
456456
for vparams <- originalVparamss; vparam <- vparams do
457-
if vparam.tpt.isInstanceOf[ByNameTypeTree] then
457+
if isByNameType(vparam.tpt) then
458458
report.error(em"implementation restriction: traits cannot have by name parameters", vparam.srcPos)
459459

460460
// Annotations on class _type_ parameters are set on the derived parameters
@@ -558,9 +558,8 @@ object desugar {
558558
appliedTypeTree(tycon, targs)
559559
}
560560

561-
def isRepeated(tree: Tree): Boolean = tree match {
561+
def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match {
562562
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
563-
case ByNameTypeTree(tree1) => isRepeated(tree1)
564563
case _ => false
565564
}
566565

@@ -1734,8 +1733,13 @@ object desugar {
17341733
case ext: ExtMethods =>
17351734
Block(List(ext), Literal(Constant(())).withSpan(ext.span))
17361735
case CapturingTypeTree(refs, parent) =>
1737-
val annot = New(scalaDot(tpnme.retains), List(refs))
1738-
Annotated(parent, annot)
1736+
def annotate(annotName: TypeName, tp: Tree) =
1737+
Annotated(tp, New(scalaDot(annotName), List(refs)))
1738+
parent match
1739+
case ByNameTypeTree(restpt) =>
1740+
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
1741+
case _ =>
1742+
annotate(tpnme.retains, parent)
17391743
}
17401744
desugared.withSpan(tree.span)
17411745
}

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
172172
}
173173

174174
/** Is tpt a vararg type of the form T* or => T*? */
175-
def isRepeatedParamType(tpt: Tree)(using Context): Boolean = tpt match {
176-
case ByNameTypeTree(tpt1) => isRepeatedParamType(tpt1)
175+
def isRepeatedParamType(tpt: Tree)(using Context): Boolean = stripByNameType(tpt) match {
177176
case tpt: TypeTree => tpt.typeOpt.isRepeatedParam
178177
case AppliedTypeTree(Select(_, tpnme.REPEATED_PARAM_CLASS), _) => true
179178
case _ => false
@@ -190,6 +189,16 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
190189
case arg => arg.typeOpt.widen.isRepeatedParam
191190
}
192191

192+
def isByNameType(tree: Tree)(using Context): Boolean =
193+
stripByNameType(tree) ne tree
194+
195+
def stripByNameType(tree: Tree)(using Context): Tree = unsplice(tree) match
196+
case ByNameTypeTree(t1) => t1
197+
case untpd.CapturingTypeTree(_, parent) =>
198+
val parent1 = stripByNameType(parent)
199+
if parent1 eq parent then tree else parent1
200+
case _ => tree
201+
193202
/** All type and value parameter symbols of this DefDef */
194203
def allParamSyms(ddef: DefDef)(using Context): List[Symbol] =
195204
ddef.paramss.flatten.map(_.symbol)
@@ -389,6 +398,16 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
389398
case _ => None
390399
}
391400
}
401+
402+
object ImpureByNameTypeTree:
403+
def apply(tp: ByNameTypeTree)(using Context): untpd.CapturingTypeTree =
404+
untpd.CapturingTypeTree(
405+
Ident(nme.CAPTURE_ROOT).withSpan(tp.span.startPos) :: Nil, tp)
406+
def unapply(tp: Tree)(using Context): Option[ByNameTypeTree] = tp match
407+
case untpd.CapturingTypeTree(id @ Ident(nme.CAPTURE_ROOT) :: Nil, bntp: ByNameTypeTree)
408+
if id.span == bntp.span.startPos => Some(bntp)
409+
case _ => None
410+
end ImpureByNameTypeTree
392411
}
393412

394413
trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>

compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import printing.Printer
1212
import printing.Texts.Text
1313

1414

15-
case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation:
15+
case class CaptureAnnotation(refs: CaptureSet, kind: CapturingKind) extends Annotation:
1616
import CaptureAnnotation.*
1717
import tpd.*
1818

@@ -25,25 +25,26 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio
2525
val arg = repeated(elems, TypeTree(defn.AnyType))
2626
New(symbol.typeRef, arg :: Nil)
2727

28-
override def symbol(using Context) = defn.RetainsAnnot
28+
override def symbol(using Context) =
29+
if kind == CapturingKind.ByName then defn.RetainsByNameAnnot else defn.RetainsAnnot
2930

3031
override def derivedAnnotation(tree: Tree)(using Context): Annotation =
3132
unsupported("derivedAnnotation(Tree)")
3233

33-
def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation =
34-
if (this.refs eq refs) && (this.boxed == boxed) then this
35-
else CaptureAnnotation(refs, boxed)
34+
def derivedAnnotation(refs: CaptureSet, kind: CapturingKind)(using Context): Annotation =
35+
if (this.refs eq refs) && (this.kind == kind) then this
36+
else CaptureAnnotation(refs, kind)
3637

3738
override def sameAnnotation(that: Annotation)(using Context): Boolean = that match
38-
case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2
39+
case CaptureAnnotation(refs2, kind2) => refs == refs2 && kind == kind2
3940
case _ => false
4041

4142
override def mapWith(tp: TypeMap)(using Context) =
4243
val elems = refs.elems.toList
4344
val elems1 = elems.mapConserve(tp)
4445
if elems1 eq elems then this
4546
else if elems1.forall(_.isInstanceOf[CaptureRef])
46-
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed)
47+
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), kind)
4748
else EmptyAnnotation
4849

4950
override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
@@ -54,10 +55,11 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio
5455

5556
override def toText(printer: Printer): Text = refs.toText(printer)
5657

57-
override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0)
58+
override def hash: Int =
59+
(refs.hashCode << 1) | (if kind == CapturingKind.Regular then 0 else 1)
5860

5961
override def eql(that: Annotation) = that match
60-
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed)
62+
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.kind == kind)
6163
case _ => false
6264

6365
end CaptureAnnotation

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ extension (tree: Tree)
4343
extension (tp: Type)
4444

4545
def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match
46-
case CapturingType(p, r, b) =>
46+
case CapturingType(p, r, k) =>
4747
if (parent eq p) && (refs eq r) then tp
48-
else CapturingType(parent, refs, b)
48+
else CapturingType(parent, refs, k)
4949

5050
/** If this is type variable instantiated or upper bounded with a capturing type,
5151
* the capture set associated with that type. Extended to and-or types and
@@ -54,7 +54,8 @@ extension (tp: Type)
5454
*/
5555
def boxedCaptured(using Context): CaptureSet =
5656
def getBoxed(tp: Type): CaptureSet = tp match
57-
case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty
57+
case CapturingType(_, refs, CapturingKind.Boxed) => refs
58+
case CapturingType(_, _, _) => CaptureSet.empty
5859
case tp: TypeProxy => getBoxed(tp.superType)
5960
case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2)
6061
case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2)

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ sealed abstract class CaptureSet extends Showable:
209209
((NoType: Type) /: elems) ((tp, ref) =>
210210
if tp.exists then OrType(tp, ref, soft = false) else ref)
211211

212-
def toRegularAnnotation(using Context): Annotation =
213-
Annotation(CaptureAnnotation(this, boxed = false).tree)
212+
def toRegularAnnotation(byName: Boolean)(using Context): Annotation =
213+
val kind = if byName then CapturingKind.ByName else CapturingKind.Regular
214+
Annotation(CaptureAnnotation(this, kind).tree)
214215

215216
override def toText(printer: Printer): Text =
216217
Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}")
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
5+
/** Possible kinds of captures */
6+
enum CapturingKind:
7+
case Regular // normal capture
8+
case Boxed // capture under box
9+
case ByName // capture applies to enclosing by-name type (only possible before ElimByName)

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,46 @@ package cc
55
import core.*
66
import Types.*, Symbols.*, Contexts.*
77

8+
/** A capturing type. This is internally represented as an annotated type with a `retains`
9+
* annotation, but the extractor will succeed only at phase CheckCaptures.
10+
* Annotated types with `@retainsByName` annotation can also be created that way, by
11+
* giving a `CapturingKind.ByName` as `kind` argument, but they are never extracted,
12+
* since they have already been converted to regular capturing types before CheckCaptures.
13+
*/
814
object CapturingType:
915

10-
def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type =
16+
def apply(parent: Type, refs: CaptureSet, kind: CapturingKind)(using Context): Type =
1117
if refs.isAlwaysEmpty then parent
12-
else AnnotatedType(parent, CaptureAnnotation(refs, boxed))
13-
14-
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
15-
if ctx.phase == Phases.checkCapturesPhase then EventuallyCapturingType.unapply(tp)
18+
else AnnotatedType(parent, CaptureAnnotation(refs, kind))
19+
20+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] =
21+
if ctx.phase == Phases.checkCapturesPhase then
22+
val r = EventuallyCapturingType.unapply(tp)
23+
r match
24+
case Some((_, _, CapturingKind.ByName)) => None
25+
case _ => r
1626
else None
1727

1828
end CapturingType
1929

30+
/** An extractor for types that will be capturing types at phase CheckCaptures. Also
31+
* included are types that indicate captures on enclosing call-by-name parameters
32+
* before phase ElimByName
33+
*/
2034
object EventuallyCapturingType:
2135

22-
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
23-
if tp.annot.symbol == defn.RetainsAnnot then
36+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] =
37+
val sym = tp.annot.symbol
38+
if sym == defn.RetainsAnnot || sym == defn.RetainsByNameAnnot then
2439
tp.annot match
25-
case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed))
40+
case ann: CaptureAnnotation =>
41+
Some((tp.parent, ann.refs, ann.kind))
2642
case ann =>
27-
try Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing))
43+
val kind =
44+
if ann.tree.isBoxedCapturing then CapturingKind.Boxed
45+
else if sym == defn.RetainsByNameAnnot then CapturingKind.ByName
46+
else CapturingKind.Regular
47+
try Some((tp.parent, ann.tree.toCaptureSet, kind))
2848
catch case ex: IllegalCaptureRef => None
2949
else None
3050

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ extends tpd.TreeTraverser:
2525
.toFunctionType(isJava = false, alwaysDependent = true)
2626

2727
private def box(tp: Type)(using Context): Type = tp match
28-
case CapturingType(parent, refs, false) => CapturingType(parent, refs, true)
28+
case CapturingType(parent, refs, CapturingKind.Regular) =>
29+
CapturingType(parent, refs, CapturingKind.Boxed)
2930
case _ => tp
3031

3132
private def setBoxed(tp: Type)(using Context) = tp match
@@ -77,7 +78,7 @@ extends tpd.TreeTraverser:
7778
cls.paramGetters.foldLeft(tp) { (core, getter) =>
7879
if getter.termRef.isTracked then
7980
val getterType = tp.memberInfo(getter).strippedDealias
80-
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false))
81+
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), CapturingKind.Regular))
8182
.showing(i"add capture refinement $tp --> $result", capt)
8283
else
8384
core
@@ -130,7 +131,7 @@ extends tpd.TreeTraverser:
130131
case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) =>
131132
CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2)
132133
case _ if canHaveInferredCapture(tp) =>
133-
CapturingType(tp, CaptureSet.Var(), boxed = false)
134+
CapturingType(tp, CaptureSet.Var(), CapturingKind.Regular)
134135
case _ =>
135136
tp
136137

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import typer.ImportInfo.RootRef
1414
import Comments.CommentsContext
1515
import Comments.Comment
1616
import util.Spans.NoSpan
17-
import cc.{CapturingType, CaptureSet}
17+
import cc.{CapturingType, CaptureSet, CapturingKind, EventuallyCapturingType}
1818

1919
import scala.annotation.tailrec
2020

@@ -117,9 +117,9 @@ class Definitions {
117117
*
118118
* ErasedFunctionN and ErasedContextFunctionN erase to Function0.
119119
*
120-
* EffXYZFunctionN afollow this template:
120+
* ImpureXYZFunctionN follow this template:
121121
*
122-
* type EffXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
122+
* type ImpureXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
123123
*/
124124
private def newFunctionNType(name: TypeName): Symbol = {
125125
val impure = name.startsWith("Impure")
@@ -135,7 +135,7 @@ class Definitions {
135135
HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)(
136136
tl => List.fill(arity + 1)(TypeBounds.empty),
137137
tl => CapturingType(underlyingClass.typeRef.appliedTo(tl.paramRefs),
138-
CaptureSet.universal, boxed = false)
138+
CaptureSet.universal, CapturingKind.Regular)
139139
))
140140
else
141141
val cls = denot.asClass.classSymbol
@@ -968,6 +968,7 @@ class Definitions {
968968
@tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs")
969969
@tu lazy val SinceAnnot: ClassSymbol = requiredClass("scala.annotation.since")
970970
@tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains")
971+
@tu lazy val RetainsByNameAnnot: ClassSymbol = requiredClass("scala.retainsByName")
971972

972973
@tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable")
973974

@@ -1101,9 +1102,16 @@ class Definitions {
11011102
}
11021103
}
11031104

1105+
/** Extractor for function types representing by-name parameters, of the form
1106+
* `() ?=> T`.
1107+
* Under -Ycc, this becomes `() ?-> T` or `{r1, ..., rN} () ?-> T`.
1108+
*/
11041109
object ByNameFunction:
1105-
def apply(tp: Type)(using Context): Type =
1106-
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
1110+
def apply(tp: Type)(using Context): Type = tp match
1111+
case EventuallyCapturingType(tp1, refs, CapturingKind.ByName) =>
1112+
CapturingType(apply(tp1), refs, CapturingKind.Regular)
1113+
case _ =>
1114+
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
11071115
def unapply(tp: Type)(using Context): Option[Type] = tp match
11081116
case tp @ AppliedType(tycon, arg :: Nil) if defn.isByNameFunctionClass(tycon.typeSymbol) =>
11091117
Some(arg)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@ object StdNames {
580580
val reify : N = "reify"
581581
val releaseFence : N = "releaseFence"
582582
val retains: N = "retains"
583+
val retainsByName: N = "retainsByName"
583584
val rootMirror : N = "rootMirror"
584585
val run: N = "run"
585586
val runOrElse: N = "runOrElse"

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import typer.Applications.productSelectorTypes
2424
import reporting.trace
2525
import NullOpsDecorator._
2626
import annotation.constructorOnly
27-
import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing}
27+
import cc.{CapturingType, derivedCapturingType, CaptureSet, CapturingKind, stripCapturing}
2828

2929
/** Provides methods to compare types.
3030
*/
@@ -832,7 +832,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
832832
tp1 match
833833
case tp1: CaptureRef if tp1.isTracked =>
834834
val stripped = tp1w.stripCapturing
835-
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false)
835+
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, CapturingKind.Regular)
836836
case _ =>
837837
isSubType(tp1w, tp2, approx.addLow)
838838
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import scala.util.hashing.{ MurmurHash3 => hashing }
3838
import config.Printers.{core, typr, matchTypes}
3939
import reporting.{trace, Message}
4040
import java.lang.ref.WeakReference
41-
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing}
41+
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing, CapturingKind}
4242
import CaptureSet.CompareResult
4343

4444
import scala.annotation.internal.sharable
@@ -1869,13 +1869,15 @@ object Types {
18691869

18701870
def capturing(ref: CaptureRef)(using Context): Type =
18711871
if captureSet.accountsFor(ref) then this
1872-
else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing)
1872+
else CapturingType(this, ref.singletonCaptureSet,
1873+
if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular)
18731874

18741875
def capturing(cs: CaptureSet)(using Context): Type =
18751876
if cs.isConst && cs.subCaptures(captureSet, frozen = true).isOK then this
18761877
else this match
18771878
case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs)
1878-
case _ => CapturingType(this, cs, this.isBoxedCapturing)
1879+
case _ => CapturingType(this, cs,
1880+
if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular)
18791881

18801882
/** The set of distinct symbols referred to by this type, after all aliases are expanded */
18811883
def coveringSet(using Context): Set[Symbol] =
@@ -3796,10 +3798,11 @@ object Types {
37963798
CapturingType(parent1, CaptureSet.universal, boxed))
37973799
case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) =>
37983800
val parent1 = mapOver(parent)
3799-
if ann.symbol == defn.RetainsAnnot then
3801+
if ann.symbol == defn.RetainsAnnot || ann.symbol == defn.RetainsByNameAnnot then
3802+
val byName = ann.symbol == defn.RetainsByNameAnnot
38003803
range(
3801-
AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation),
3802-
AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation))
3804+
AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation(byName)),
3805+
AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation(byName)))
38033806
else
38043807
parent1
38053808
case _ => mapOver(tp)

0 commit comments

Comments
 (0)