Skip to content

Commit 4f806e2

Browse files
committed
Handle captures in by-name parameters
1. Infrastructure to deal with capturesets in byname parameters 2. Handle retainsByName annotations in ElimByName Convert them to regular annotations on the generated function types. This enables capture checking on by-name parameters. 3. Add a style warning for misleading by-name parameter type formatting. By-name types should be formatted `{...}-> T`. `{...} -> T` looks too much like a function type.
1 parent 9030732 commit 4f806e2

24 files changed

+233
-84
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ object desugar {
454454

455455
if mods.is(Trait) then
456456
for vparams <- originalVparamss; vparam <- vparams do
457-
if vparam.tpt.isInstanceOf[ByNameTypeTree] then
457+
if isByNameType(vparam.tpt) then
458458
report.error(em"implementation restriction: traits cannot have by name parameters", vparam.srcPos)
459459

460460
// Annotations on class _type_ parameters are set on the derived parameters
@@ -562,9 +562,8 @@ object desugar {
562562
appliedTypeTree(tycon, targs)
563563
}
564564

565-
def isRepeated(tree: Tree): Boolean = tree match {
565+
def isRepeated(tree: Tree): Boolean = stripByNameType(tree) match {
566566
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
567-
case ByNameTypeTree(tree1) => isRepeated(tree1)
568567
case _ => false
569568
}
570569

@@ -1735,8 +1734,13 @@ object desugar {
17351734
case ext: ExtMethods =>
17361735
Block(List(ext), Literal(Constant(())).withSpan(ext.span))
17371736
case CapturingTypeTree(refs, parent) =>
1738-
val annot = New(scalaDot(tpnme.retains), List(refs))
1739-
Annotated(parent, annot)
1737+
def annotate(annotName: TypeName, tp: Tree) =
1738+
Annotated(tp, New(scalaDot(annotName), List(refs)))
1739+
parent match
1740+
case ByNameTypeTree(restpt) =>
1741+
cpy.ByNameTypeTree(parent)(annotate(tpnme.retainsByName, restpt))
1742+
case _ =>
1743+
annotate(tpnme.retains, parent)
17401744
}
17411745
desugared.withSpan(tree.span)
17421746
}

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
172172
}
173173

174174
/** Is tpt a vararg type of the form T* or => T*? */
175-
def isRepeatedParamType(tpt: Tree)(using Context): Boolean = tpt match {
176-
case ByNameTypeTree(tpt1) => isRepeatedParamType(tpt1)
175+
def isRepeatedParamType(tpt: Tree)(using Context): Boolean = stripByNameType(tpt) match {
177176
case tpt: TypeTree => tpt.typeOpt.isRepeatedParam
178177
case AppliedTypeTree(Select(_, tpnme.REPEATED_PARAM_CLASS), _) => true
179178
case _ => false
@@ -190,6 +189,16 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
190189
case arg => arg.typeOpt.widen.isRepeatedParam
191190
}
192191

192+
def isByNameType(tree: Tree)(using Context): Boolean =
193+
stripByNameType(tree) ne tree
194+
195+
def stripByNameType(tree: Tree)(using Context): Tree = unsplice(tree) match
196+
case ByNameTypeTree(t1) => t1
197+
case untpd.CapturingTypeTree(_, parent) =>
198+
val parent1 = stripByNameType(parent)
199+
if parent1 eq parent then tree else parent1
200+
case _ => tree
201+
193202
/** All type and value parameter symbols of this DefDef */
194203
def allParamSyms(ddef: DefDef)(using Context): List[Symbol] =
195204
ddef.paramss.flatten.map(_.symbol)
@@ -389,6 +398,16 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
389398
case _ => None
390399
}
391400
}
401+
402+
object ImpureByNameTypeTree:
403+
def apply(tp: ByNameTypeTree)(using Context): untpd.CapturingTypeTree =
404+
untpd.CapturingTypeTree(
405+
Ident(nme.CAPTURE_ROOT).withSpan(tp.span.startPos) :: Nil, tp)
406+
def unapply(tp: Tree)(using Context): Option[ByNameTypeTree] = tp match
407+
case untpd.CapturingTypeTree(id @ Ident(nme.CAPTURE_ROOT) :: Nil, bntp: ByNameTypeTree)
408+
if id.span == bntp.span.startPos => Some(bntp)
409+
case _ => None
410+
end ImpureByNameTypeTree
392411
}
393412

394413
trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>

compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import printing.Printer
1212
import printing.Texts.Text
1313

1414

15-
case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation:
15+
case class CaptureAnnotation(refs: CaptureSet, kind: CapturingKind) extends Annotation:
1616
import CaptureAnnotation.*
1717
import tpd.*
1818

@@ -25,25 +25,26 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio
2525
val arg = repeated(elems, TypeTree(defn.AnyType))
2626
New(symbol.typeRef, arg :: Nil)
2727

28-
override def symbol(using Context) = defn.RetainsAnnot
28+
override def symbol(using Context) =
29+
if kind == CapturingKind.ByName then defn.RetainsByNameAnnot else defn.RetainsAnnot
2930

3031
override def derivedAnnotation(tree: Tree)(using Context): Annotation =
3132
unsupported("derivedAnnotation(Tree)")
3233

33-
def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation =
34-
if (this.refs eq refs) && (this.boxed == boxed) then this
35-
else CaptureAnnotation(refs, boxed)
34+
def derivedAnnotation(refs: CaptureSet, kind: CapturingKind)(using Context): Annotation =
35+
if (this.refs eq refs) && (this.kind == kind) then this
36+
else CaptureAnnotation(refs, kind)
3637

3738
override def sameAnnotation(that: Annotation)(using Context): Boolean = that match
38-
case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2
39+
case CaptureAnnotation(refs2, kind2) => refs == refs2 && kind == kind2
3940
case _ => false
4041

4142
override def mapWith(tp: TypeMap)(using Context) =
4243
val elems = refs.elems.toList
4344
val elems1 = elems.mapConserve(tp)
4445
if elems1 eq elems then this
4546
else if elems1.forall(_.isInstanceOf[CaptureRef])
46-
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed)
47+
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), kind)
4748
else EmptyAnnotation
4849

4950
override def refersToParamOf(tl: TermLambda)(using Context): Boolean =
@@ -54,10 +55,11 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotatio
5455

5556
override def toText(printer: Printer): Text = refs.toText(printer)
5657

57-
override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0)
58+
override def hash: Int =
59+
(refs.hashCode << 1) | (if kind == CapturingKind.Regular then 0 else 1)
5860

5961
override def eql(that: Annotation) = that match
60-
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed)
62+
case that: CaptureAnnotation => (this.refs eq that.refs) && (this.kind == kind)
6163
case _ => false
6264

6365
end CaptureAnnotation

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ extension (tree: Tree)
4343
extension (tp: Type)
4444

4545
def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match
46-
case CapturingType(p, r, b) =>
46+
case CapturingType(p, r, k) =>
4747
if (parent eq p) && (refs eq r) then tp
48-
else CapturingType(parent, refs, b)
48+
else CapturingType(parent, refs, k)
4949

5050
/** If this is type variable instantiated or upper bounded with a capturing type,
5151
* the capture set associated with that type. Extended to and-or types and
@@ -54,7 +54,8 @@ extension (tp: Type)
5454
*/
5555
def boxedCaptured(using Context): CaptureSet =
5656
def getBoxed(tp: Type): CaptureSet = tp match
57-
case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty
57+
case CapturingType(_, refs, CapturingKind.Boxed) => refs
58+
case CapturingType(_, _, _) => CaptureSet.empty
5859
case tp: TypeProxy => getBoxed(tp.superType)
5960
case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2)
6061
case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2)

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ sealed abstract class CaptureSet extends Showable:
209209
((NoType: Type) /: elems) ((tp, ref) =>
210210
if tp.exists then OrType(tp, ref, soft = false) else ref)
211211

212-
def toRegularAnnotation(using Context): Annotation =
213-
Annotation(CaptureAnnotation(this, boxed = false).tree)
212+
def toRegularAnnotation(byName: Boolean)(using Context): Annotation =
213+
val kind = if byName then CapturingKind.ByName else CapturingKind.Regular
214+
Annotation(CaptureAnnotation(this, kind).tree)
214215

215216
override def toText(printer: Printer): Text =
216217
Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}")
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
5+
/** Possible kinds of captures */
6+
enum CapturingKind:
7+
case Regular // normal capture
8+
case Boxed // capture under box
9+
case ByName // capture applies to enclosing by-name type (only possible before ElimByName)

compiler/src/dotty/tools/dotc/cc/CapturingType.scala

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,46 @@ package cc
55
import core.*
66
import Types.*, Symbols.*, Contexts.*
77

8+
/** A capturing type. This is internally represented as an annotated type with a `retains`
9+
* annotation, but the extractor will succeed only at phase CheckCaptures.
10+
* Annotated types with `@retainsByName` annotation can also be created that way, by
11+
* giving a `CapturingKind.ByName` as `kind` argument, but they are never extracted,
12+
* since they have already been converted to regular capturing types before CheckCaptures.
13+
*/
814
object CapturingType:
915

10-
def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type =
16+
def apply(parent: Type, refs: CaptureSet, kind: CapturingKind)(using Context): Type =
1117
if refs.isAlwaysEmpty then parent
12-
else AnnotatedType(parent, CaptureAnnotation(refs, boxed))
13-
14-
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
15-
if ctx.phase == Phases.checkCapturesPhase then EventuallyCapturingType.unapply(tp)
18+
else AnnotatedType(parent, CaptureAnnotation(refs, kind))
19+
20+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] =
21+
if ctx.phase == Phases.checkCapturesPhase then
22+
val r = EventuallyCapturingType.unapply(tp)
23+
r match
24+
case Some((_, _, CapturingKind.ByName)) => None
25+
case _ => r
1626
else None
1727

1828
end CapturingType
1929

30+
/** An extractor for types that will be capturing types at phase CheckCaptures. Also
31+
* included are types that indicate captures on enclosing call-by-name parameters
32+
* before phase ElimByName
33+
*/
2034
object EventuallyCapturingType:
2135

22-
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] =
23-
if tp.annot.symbol == defn.RetainsAnnot then
36+
def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, CapturingKind)] =
37+
val sym = tp.annot.symbol
38+
if sym == defn.RetainsAnnot || sym == defn.RetainsByNameAnnot then
2439
tp.annot match
25-
case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed))
40+
case ann: CaptureAnnotation =>
41+
Some((tp.parent, ann.refs, ann.kind))
2642
case ann =>
27-
try Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing))
43+
val kind =
44+
if ann.tree.isBoxedCapturing then CapturingKind.Boxed
45+
else if sym == defn.RetainsByNameAnnot then CapturingKind.ByName
46+
else CapturingKind.Regular
47+
try Some((tp.parent, ann.tree.toCaptureSet, kind))
2848
catch case ex: IllegalCaptureRef => None
2949
else None
3050

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ extends tpd.TreeTraverser:
2525
.toFunctionType(isJava = false, alwaysDependent = true)
2626

2727
private def box(tp: Type)(using Context): Type = tp match
28-
case CapturingType(parent, refs, false) => CapturingType(parent, refs, true)
28+
case CapturingType(parent, refs, CapturingKind.Regular) =>
29+
CapturingType(parent, refs, CapturingKind.Boxed)
2930
case _ => tp
3031

3132
private def setBoxed(tp: Type)(using Context) = tp match
@@ -77,7 +78,7 @@ extends tpd.TreeTraverser:
7778
cls.paramGetters.foldLeft(tp) { (core, getter) =>
7879
if getter.termRef.isTracked then
7980
val getterType = tp.memberInfo(getter).strippedDealias
80-
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false))
81+
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), CapturingKind.Regular))
8182
.showing(i"add capture refinement $tp --> $result", capt)
8283
else
8384
core
@@ -130,7 +131,7 @@ extends tpd.TreeTraverser:
130131
case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) =>
131132
CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2)
132133
case _ if canHaveInferredCapture(tp) =>
133-
CapturingType(tp, CaptureSet.Var(), boxed = false)
134+
CapturingType(tp, CaptureSet.Var(), CapturingKind.Regular)
134135
case _ =>
135136
tp
136137

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import typer.ImportInfo.RootRef
1414
import Comments.CommentsContext
1515
import Comments.Comment
1616
import util.Spans.NoSpan
17-
import cc.{CapturingType, CaptureSet}
17+
import cc.{CapturingType, CaptureSet, CapturingKind, EventuallyCapturingType}
1818

1919
import scala.annotation.tailrec
2020

@@ -117,9 +117,9 @@ class Definitions {
117117
*
118118
* ErasedFunctionN and ErasedContextFunctionN erase to Function0.
119119
*
120-
* EffXYZFunctionN afollow this template:
120+
* ImpureXYZFunctionN follow this template:
121121
*
122-
* type EffXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
122+
* type ImpureXYZFunctionN[-T0,...,-T{N-1}, +R] = {*} XYZFunctionN[T0,...,T{N-1}, R]
123123
*/
124124
private def newFunctionNType(name: TypeName): Symbol = {
125125
val impure = name.startsWith("Impure")
@@ -135,7 +135,7 @@ class Definitions {
135135
HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)(
136136
tl => List.fill(arity + 1)(TypeBounds.empty),
137137
tl => CapturingType(underlyingClass.typeRef.appliedTo(tl.paramRefs),
138-
CaptureSet.universal, boxed = false)
138+
CaptureSet.universal, CapturingKind.Regular)
139139
))
140140
else
141141
val cls = denot.asClass.classSymbol
@@ -965,6 +965,7 @@ class Definitions {
965965
@tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs")
966966
@tu lazy val SinceAnnot: ClassSymbol = requiredClass("scala.annotation.since")
967967
@tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains")
968+
@tu lazy val RetainsByNameAnnot: ClassSymbol = requiredClass("scala.retainsByName")
968969

969970
@tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable")
970971

@@ -1098,9 +1099,16 @@ class Definitions {
10981099
}
10991100
}
11001101

1102+
/** Extractor for function types representing by-name parameters, of the form
1103+
* `() ?=> T`.
1104+
* Under -Ycc, this becomes `() ?-> T` or `{r1, ..., rN} () ?-> T`.
1105+
*/
11011106
object ByNameFunction:
1102-
def apply(tp: Type)(using Context): Type =
1103-
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
1107+
def apply(tp: Type)(using Context): Type = tp match
1108+
case EventuallyCapturingType(tp1, refs, CapturingKind.ByName) =>
1109+
CapturingType(apply(tp1), refs, CapturingKind.Regular)
1110+
case _ =>
1111+
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
11041112
def unapply(tp: Type)(using Context): Option[Type] = tp match
11051113
case tp @ AppliedType(tycon, arg :: Nil) if defn.isByNameFunctionClass(tycon.typeSymbol) =>
11061114
Some(arg)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ object StdNames {
579579
val reify : N = "reify"
580580
val releaseFence : N = "releaseFence"
581581
val retains: N = "retains"
582+
val retainsByName: N = "retainsByName"
582583
val rootMirror : N = "rootMirror"
583584
val run: N = "run"
584585
val runOrElse: N = "runOrElse"

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import typer.ProtoTypes.constrained
2323
import typer.Applications.productSelectorTypes
2424
import reporting.trace
2525
import annotation.constructorOnly
26-
import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing}
26+
import cc.{CapturingType, derivedCapturingType, CaptureSet, CapturingKind, stripCapturing}
2727

2828
/** Provides methods to compare types.
2929
*/
@@ -841,7 +841,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
841841
tp1 match
842842
case tp1: CaptureRef if tp1.isTracked =>
843843
val stripped = tp1w.stripCapturing
844-
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false)
844+
tp1w = CapturingType(stripped, tp1.singletonCaptureSet, CapturingKind.Regular)
845845
case _ =>
846846
isSubType(tp1w, tp2, approx.addLow)
847847
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import scala.util.hashing.{ MurmurHash3 => hashing }
3535
import config.Printers.{core, typr, matchTypes}
3636
import reporting.{trace, Message}
3737
import java.lang.ref.WeakReference
38-
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing}
38+
import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing, CapturingKind}
3939
import CaptureSet.CompareResult
4040

4141
import scala.annotation.internal.sharable
@@ -1866,13 +1866,15 @@ object Types {
18661866

18671867
def capturing(ref: CaptureRef)(using Context): Type =
18681868
if captureSet.accountsFor(ref) then this
1869-
else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing)
1869+
else CapturingType(this, ref.singletonCaptureSet,
1870+
if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular)
18701871

18711872
def capturing(cs: CaptureSet)(using Context): Type =
18721873
if cs.isConst && cs.subCaptures(captureSet, frozen = true).isOK then this
18731874
else this match
18741875
case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs)
1875-
case _ => CapturingType(this, cs, this.isBoxedCapturing)
1876+
case _ => CapturingType(this, cs,
1877+
if this.isBoxedCapturing then CapturingKind.Boxed else CapturingKind.Regular)
18761878

18771879
/** The set of distinct symbols referred to by this type, after all aliases are expanded */
18781880
def coveringSet(using Context): Set[Symbol] =
@@ -3793,10 +3795,11 @@ object Types {
37933795
CapturingType(parent1, CaptureSet.universal, boxed))
37943796
case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) =>
37953797
val parent1 = mapOver(parent)
3796-
if ann.symbol == defn.RetainsAnnot then
3798+
if ann.symbol == defn.RetainsAnnot || ann.symbol == defn.RetainsByNameAnnot then
3799+
val byName = ann.symbol == defn.RetainsByNameAnnot
37973800
range(
3798-
AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation),
3799-
AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation))
3801+
AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation(byName)),
3802+
AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation(byName)))
38003803
else
38013804
parent1
38023805
case _ => mapOver(tp)

0 commit comments

Comments
 (0)