From b00a8b547a99b0e7593ee4f51c2aacf81ac213c3 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 29 Sep 2021 17:46:49 +0200 Subject: [PATCH] First version of capture checker. A squashed version of the following commits: Handle byname parameters Don't force symbol completion when printing flags or annotations Check overrides and disallow non-local inferred capture types Handle `this` in capture sets Print capture variable dependencies under -Ydebug-cc Avoid spurious error message Avoid spurious error message "cannot be tracked since its capture set is empty". This arose in lazyref.scala for a DependentTypeTree in an anaonymois function. Dependent type trees map to normal TypeTrees, not InferredTypeTrees (and things go wrong if we try to change that). Drop TopType Consider bounds of type variables to be boxed More tests Avoid multiple maps when creating symbol infos Use a single BiTypeMap to map from inferred result and parameters to method info. This improves efficiency and debuggability by reducing the frequence of multiple stacked maps capture sets. Refactor with CompareResult#andAlso Refactoring: use isOK on CompareResult Reflect inferred parameter types in enclosing method type The variables in the inferred parameter type of an anonymous function need to also show up in the closure type itself, so that they can be constrained. Don't interpolate parameters of anonymous functions Here, we should wait until we get the info from the outside, which can be arbitrarily much later. Compute upper approximation of bimapped sets from both sides Fail when trying to add new elements to mapped sets It's the safe option. Print full origin trail of derived capture sets under -Ycc-debug Fix isEmpty condition in well-formedness check Make printing capture sets dependent on -Ycc-debug Recursion brake for upperApprox Fixes to upperApprox Make instantiteRT a BiTypeMap Otherwise we will not be able to do upper approximations of parameters. Interpolate only variables at negative polarity Interpolating covariant variables risks restricting capture sets to early. For instance, when a variable has the capture set of a called function in its capture set. When we have indirectly recursive calls it could be that the capture set of a called function is not yet fully formed. Interpolate type variables when symbols are completed Allow for possibility that variables are constant Only recomplete symbols if their info changes Add completions to Rechecker Complete val and def definitions lazily on first access. Now, recheckDefDef and recheckValDef are called the first time the new info of the defined symbol is needed, or, if the info is never needed, when the typer gets to the definitions. This only applied to definitions with inferred types. The others are handled in typer sequence, as before. The motivation of the change is that some modifications to inferred types of symbols can be made in subclasses without running into ordering problems. More fixes for subCapture New setting -Ycc-debug for more info on capture variables Fix subCapture in frozen state Previously, we still OKed two empty variables to be compared with subcapture in the frozen state. This should give an error. Direct comparisons of dependent function types Revert: Special treatment of dependent functions in TypeComparer change test Also treat explicit capturing type arguments as boxed Print subcapturing steps in -explain traces Don't decorate type variables with additional capture sets Boxed CapturingTypes Drop unsound capture suppression if expected type is boxed If expected type is boxed, the expression still contributes to the captured variables of its environment. Re-infer result types of anonymous functions Keep erased implicit args Special treatment of dependent functions in TypeComparer Fix addFunctionRefinements Always print refined function types as dependent functions. Makes it easier to see what goes on. Make CaptureSet ++ and ** simplify more Refine function types when reinferring so that they can be dependent Fix avoidance problem when typing blocks We should not pass en expected type when rechecking the expression of a block since that can add local references to global capture set variables. Also: tests for lists and pairs Print empty variables with "?" Fix printing untyped annotations Fix printing annotations in trees Drop redundant code Refactor map operations on capture sets Intoduce Bi-Mapped CaptureSets Report an error is a simply mapped capture set gets new elements that do not come from the original souurce. Introduce a new abstraction of bi-mapped sets that accept new elements and propagate them to the original source. Add map operation to SimpleIdentitySet Restrict tracked class parameters to vals Handle local classes and secondary constructors Fix CapturingType precedence when printing First stab at handling classes Bug fixes 1. Fix canBeTracked for TermRefs only TermRefs where prefix is NoPrefix or `this` can be tracked. The others have to be widened. 2. Fix rule for comparing capture refs on the left 3. Be more careful where comparisons are frozen Capture checker for functions --- compiler/src/dotty/tools/dotc/Compiler.scala | 5 +- compiler/src/dotty/tools/dotc/Run.scala | 4 +- .../src/dotty/tools/dotc/ast/Desugar.scala | 5 + compiler/src/dotty/tools/dotc/ast/Trees.scala | 8 +- compiler/src/dotty/tools/dotc/ast/untpd.scala | 17 +- .../tools/dotc/cc/CaptureAnnotation.scala | 63 ++ .../src/dotty/tools/dotc/cc/CaptureOps.scala | 82 +++ .../src/dotty/tools/dotc/cc/CaptureSet.scala | 577 ++++++++++++++++++ .../dotty/tools/dotc/cc/CapturingType.scala | 21 + .../src/dotty/tools/dotc/config/Config.scala | 5 + .../dotty/tools/dotc/config/Printers.scala | 1 + .../tools/dotc/config/ScalaSettings.scala | 2 + .../dotty/tools/dotc/core/Annotations.scala | 6 +- .../dotty/tools/dotc/core/Definitions.scala | 17 +- .../tools/dotc/core/OrderingConstraint.scala | 4 + .../src/dotty/tools/dotc/core/Phases.scala | 16 +- .../src/dotty/tools/dotc/core/StdNames.scala | 4 +- .../dotty/tools/dotc/core/Substituters.scala | 6 +- .../tools/dotc/core/SymDenotations.scala | 9 +- .../dotty/tools/dotc/core/TypeComparer.scala | 194 ++++-- .../dotty/tools/dotc/core/TypeErrors.scala | 1 + .../src/dotty/tools/dotc/core/TypeOps.scala | 26 +- .../src/dotty/tools/dotc/core/Types.scala | 223 ++++++- .../src/dotty/tools/dotc/core/Variances.scala | 3 + .../tools/dotc/core/tasty/TreeUnpickler.scala | 10 +- .../dotty/tools/dotc/parsing/Parsers.scala | 34 +- .../src/dotty/tools/dotc/parsing/Tokens.scala | 4 +- .../tools/dotc/printing/PlainPrinter.scala | 32 +- .../dotty/tools/dotc/printing/Printer.scala | 8 +- .../tools/dotc/printing/RefinedPrinter.scala | 23 +- .../dotty/tools/dotc/reporting/messages.scala | 1 - .../src/dotty/tools/dotc/sbt/ExtractAPI.scala | 1 + .../tools/dotc/transform/EmptyPhase.scala | 19 + .../tools/dotc/transform/PostTyper.scala | 6 +- .../dotty/tools/dotc/transform/Recheck.scala | 238 ++++++-- .../tools/dotc/transform/TreeChecker.scala | 4 +- .../dotc/transform/TryCatchPatterns.scala | 2 +- .../tools/dotc/transform/TypeTestsCasts.scala | 4 +- .../tools/dotc/typer/CheckCaptures.scala | 468 ++++++++++++++ .../src/dotty/tools/dotc/typer/Checking.scala | 6 +- .../dotty/tools/dotc/typer/Inferencing.scala | 8 +- .../dotty/tools/dotc/typer/RefChecks.scala | 2 +- .../dotty/tools/dotc/typer/TypeAssigner.scala | 26 +- .../src/dotty/tools/dotc/typer/Typer.scala | 9 +- .../tools/dotc/util/SimpleIdentitySet.scala | 23 + .../dotty/tools/dotc/CompilationTests.scala | 3 + library/src-bootstrapped/scala/Retains.scala | 6 + .../scala/annotation/ability.scala | 9 + .../scala/runtime/stdLibPatches/Predef.scala | 1 + .../neg-custom-args/captures/capt-wf.scala | 19 + .../neg-custom-args/captures/try2.check | 38 ++ .../neg-custom-args/captures/try2.scala | 55 ++ tests/disabled/pos/lazylist.scala | 51 ++ .../allow-deep-subtypes}/i9325.scala | 0 tests/neg-custom-args/capt-wf.scala | 35 ++ tests/neg-custom-args/captures/bounded.scala | 14 + tests/neg-custom-args/captures/boxmap.check | 7 + tests/neg-custom-args/captures/boxmap.scala | 14 + tests/neg-custom-args/captures/byname.scala | 10 + .../captures/capt-box-env.scala | 12 + tests/neg-custom-args/captures/capt-box.scala | 13 + .../captures/capt-depfun.scala | 7 + .../captures/capt-depfun2.scala | 10 + tests/neg-custom-args/captures/capt-env.scala | 13 + .../neg-custom-args/captures/capt-test.scala | 26 + .../captures/capt-wf-typer.scala | 10 + tests/neg-custom-args/captures/capt1.check | 46 ++ tests/neg-custom-args/captures/capt1.scala | 34 ++ tests/neg-custom-args/captures/capt2.scala | 9 + tests/neg-custom-args/captures/capt3.scala | 26 + tests/neg-custom-args/captures/cc1.scala | 4 + tests/neg-custom-args/captures/classes.scala | 12 + tests/neg-custom-args/captures/io.scala | 22 + tests/neg-custom-args/captures/lazylist.check | 42 ++ tests/neg-custom-args/captures/lazylist.scala | 41 ++ tests/neg-custom-args/captures/lazyref.check | 28 + tests/neg-custom-args/captures/lazyref.scala | 25 + tests/neg-custom-args/captures/try.check | 25 + tests/neg-custom-args/captures/try.scala | 53 ++ tests/neg-custom-args/captures/try3.scala | 27 + tests/neg/multiLineOps.scala | 2 +- tests/neg/polymorphic-functions1.check | 7 + tests/neg/polymorphic-functions1.scala | 1 + tests/pos-custom-args/captures/bounded.scala | 14 + .../captures/boxmap-paper.scala | 38 ++ tests/pos-custom-args/captures/boxmap.scala | 20 + tests/pos-custom-args/captures/byname.scala | 10 + .../captures/capt-depfun.scala | 18 + .../captures/capt-depfun2.scala | 8 + .../pos-custom-args/captures/capt-test.scala | 35 ++ tests/pos-custom-args/captures/capt0.scala | 7 + tests/pos-custom-args/captures/capt1.scala | 27 + tests/pos-custom-args/captures/capt2.scala | 20 + .../pos-custom-args/captures/cc-expand.scala | 21 + tests/pos-custom-args/captures/classes.scala | 34 ++ .../pos-custom-args/captures/iterators.scala | 23 + tests/pos-custom-args/captures/lazyref.scala | 25 + .../captures/list-encoding.scala | 23 + tests/pos-custom-args/captures/lists.scala | 91 +++ tests/pos-custom-args/captures/pairs.scala | 33 + tests/pos-custom-args/captures/try.scala | 26 + tests/pos-custom-args/captures/try3.scala | 51 ++ 102 files changed, 3274 insertions(+), 234 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala create mode 100644 compiler/src/dotty/tools/dotc/cc/CaptureOps.scala create mode 100644 compiler/src/dotty/tools/dotc/cc/CaptureSet.scala create mode 100644 compiler/src/dotty/tools/dotc/cc/CapturingType.scala create mode 100644 compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala create mode 100644 compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala create mode 100644 library/src-bootstrapped/scala/Retains.scala create mode 100644 library/src-bootstrapped/scala/annotation/ability.scala create mode 100644 tests/disabled/neg-custom-args/captures/capt-wf.scala create mode 100644 tests/disabled/neg-custom-args/captures/try2.check create mode 100644 tests/disabled/neg-custom-args/captures/try2.scala create mode 100644 tests/disabled/pos/lazylist.scala rename tests/{neg => neg-custom-args/allow-deep-subtypes}/i9325.scala (100%) create mode 100644 tests/neg-custom-args/capt-wf.scala create mode 100644 tests/neg-custom-args/captures/bounded.scala create mode 100644 tests/neg-custom-args/captures/boxmap.check create mode 100644 tests/neg-custom-args/captures/boxmap.scala create mode 100644 tests/neg-custom-args/captures/byname.scala create mode 100644 tests/neg-custom-args/captures/capt-box-env.scala create mode 100644 tests/neg-custom-args/captures/capt-box.scala create mode 100644 tests/neg-custom-args/captures/capt-depfun.scala create mode 100644 tests/neg-custom-args/captures/capt-depfun2.scala create mode 100644 tests/neg-custom-args/captures/capt-env.scala create mode 100644 tests/neg-custom-args/captures/capt-test.scala create mode 100644 tests/neg-custom-args/captures/capt-wf-typer.scala create mode 100644 tests/neg-custom-args/captures/capt1.check create mode 100644 tests/neg-custom-args/captures/capt1.scala create mode 100644 tests/neg-custom-args/captures/capt2.scala create mode 100644 tests/neg-custom-args/captures/capt3.scala create mode 100644 tests/neg-custom-args/captures/cc1.scala create mode 100644 tests/neg-custom-args/captures/classes.scala create mode 100644 tests/neg-custom-args/captures/io.scala create mode 100644 tests/neg-custom-args/captures/lazylist.check create mode 100644 tests/neg-custom-args/captures/lazylist.scala create mode 100644 tests/neg-custom-args/captures/lazyref.check create mode 100644 tests/neg-custom-args/captures/lazyref.scala create mode 100644 tests/neg-custom-args/captures/try.check create mode 100644 tests/neg-custom-args/captures/try.scala create mode 100644 tests/neg-custom-args/captures/try3.scala create mode 100644 tests/neg/polymorphic-functions1.check create mode 100644 tests/neg/polymorphic-functions1.scala create mode 100644 tests/pos-custom-args/captures/bounded.scala create mode 100644 tests/pos-custom-args/captures/boxmap-paper.scala create mode 100644 tests/pos-custom-args/captures/boxmap.scala create mode 100644 tests/pos-custom-args/captures/byname.scala create mode 100644 tests/pos-custom-args/captures/capt-depfun.scala create mode 100644 tests/pos-custom-args/captures/capt-depfun2.scala create mode 100644 tests/pos-custom-args/captures/capt-test.scala create mode 100644 tests/pos-custom-args/captures/capt0.scala create mode 100644 tests/pos-custom-args/captures/capt1.scala create mode 100644 tests/pos-custom-args/captures/capt2.scala create mode 100644 tests/pos-custom-args/captures/cc-expand.scala create mode 100644 tests/pos-custom-args/captures/classes.scala create mode 100644 tests/pos-custom-args/captures/iterators.scala create mode 100644 tests/pos-custom-args/captures/lazyref.scala create mode 100644 tests/pos-custom-args/captures/list-encoding.scala create mode 100644 tests/pos-custom-args/captures/lists.scala create mode 100644 tests/pos-custom-args/captures/pairs.scala create mode 100644 tests/pos-custom-args/captures/try.scala create mode 100644 tests/pos-custom-args/captures/try3.scala diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 33b78e9fe945..45b1c56ab4be 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -4,6 +4,7 @@ package dotc import core._ import Contexts._ import typer.{TyperPhase, RefChecks} +import cc.CheckCaptures import parsing.Parser import Phases.Phase import transform._ @@ -78,6 +79,8 @@ class Compiler { new RefChecks, // Various checks mostly related to abstract members and overriding new TryCatchPatterns, // Compile cases in try/catch new PatternMatcher) :: // Compile pattern matches + List(new PreRecheck) :: // Preparations for check captures phase, enabled under -Ycc + List(new CheckCaptures) :: // Check captures, enabled under -Ycc List(new ElimOpaque, // Turn opaque into normal aliases new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only) new ExplicitOuter, // Add accessors to outer classes from nested ones. @@ -101,8 +104,6 @@ class Compiler { new TupleOptimizations, // Optimize generic operations on tuples new LetOverApply, // Lift blocks from receivers of applications new ArrayConstructors) :: // Intercept creation of (non-generic) arrays and intrinsify. - List(new PreRecheck) :: // Preparations for recheck phase, enabled under -Yrecheck - List(new TestRecheck) :: // Test rechecking, enabled under -Yrecheck List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new PureStats, // Remove pure stats from blocks diff --git a/compiler/src/dotty/tools/dotc/Run.scala b/compiler/src/dotty/tools/dotc/Run.scala index b9552d97fca7..32b7b2feaeb3 100644 --- a/compiler/src/dotty/tools/dotc/Run.scala +++ b/compiler/src/dotty/tools/dotc/Run.scala @@ -20,9 +20,7 @@ import reporting.{Reporter, Suppression, Action} import reporting.Diagnostic import reporting.Diagnostic.Warning import rewrites.Rewrites - import profile.Profiler -import printing.XprintMode import parsing.Parsers.Parser import parsing.JavaParsers.JavaParser import typer.ImplicitRunInfo @@ -328,7 +326,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint val fusedPhase = ctx.base.fusedContaining(prevPhase) val echoHeader = f"[[syntax trees at end of $fusedPhase%25s]] // ${unit.source}" val tree = if ctx.isAfterTyper then unit.tpdTree else unit.untpdTree - val treeString = tree.show(using ctx.withProperty(XprintMode, Some(()))) + val treeString = fusedPhase.show(tree) last match { case SomePrintedTree(phase, lastTreeString) if lastTreeString == treeString => diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index a71fc3d40e92..5e7e6bd57c29 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1752,6 +1752,9 @@ object desugar { flatTree(pats1 map (makePatDef(tree, mods, _, rhs))) case ext: ExtMethods => Block(List(ext), Literal(Constant(())).withSpan(ext.span)) + case CapturingTypeTree(refs, parent) => + val annot = New(scalaDot(tpnme.retains), List(refs)) + Annotated(parent, annot) } desugared.withSpan(tree.span) } @@ -1890,6 +1893,8 @@ object desugar { case _ => traverseChildren(tree) } }.traverse(expr) + case CapturingTypeTree(refs, parent) => + collect(parent) case _ => } collect(tree) diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 5bdd18705051..267c6b114a8c 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -260,16 +260,10 @@ object Trees { /** Tree's denotation can be derived from its type */ abstract class DenotingTree[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends Tree[T] { type ThisTree[-T >: Untyped] <: DenotingTree[T] - override def denot(using Context): Denotation = typeOpt match { + override def denot(using Context): Denotation = typeOpt.stripped match case tpe: NamedType => tpe.denot case tpe: ThisType => tpe.cls.denot - case tpe: AnnotatedType => tpe.stripAnnots match { - case tpe: NamedType => tpe.denot - case tpe: ThisType => tpe.cls.denot - case _ => NoDenotation - } case _ => NoDenotation - } } /** Tree's denot/isType/isTerm properties come from a subtree diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 40467dc5be3f..b9960cbb4652 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -147,6 +147,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case Floating } + /** {x1, ..., xN} T (only relevant under -Ycc) */ + case class CapturingTypeTree(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree + /** Short-lived usage in typer, does not need copy/transform/fold infrastructure */ case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree @@ -458,7 +461,11 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def AppliedTypeTree(tpt: Tree, arg: Tree)(implicit src: SourceFile): AppliedTypeTree = AppliedTypeTree(tpt, arg :: Nil) - def TypeTree(tpe: Type)(using Context): TypedSplice = TypedSplice(TypeTree().withTypeUnchecked(tpe)) + def TypeTree(tpe: Type)(using Context): TypedSplice = + TypedSplice(TypeTree().withTypeUnchecked(tpe)) + + def InferredTypeTree(tpe: Type)(using Context): TypedSplice = + TypedSplice(new InferredTypeTree().withTypeUnchecked(tpe)) def unitLiteral(implicit src: SourceFile): Literal = Literal(Constant(())) @@ -646,6 +653,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case tree: Number if (digits == tree.digits) && (kind == tree.kind) => tree case _ => finalize(tree, untpd.Number(digits, kind)) } + def CapturingTypeTree(tree: Tree)(refs: List[Tree], parent: Tree)(using Context): Tree = tree match + case tree: CapturingTypeTree if (refs eq tree.refs) && (parent eq tree.parent) => tree + case _ => finalize(tree, untpd.CapturingTypeTree(refs, parent)) + def TypedSplice(tree: Tree)(splice: tpd.Tree)(using Context): ProxyTree = tree match { case tree: TypedSplice if splice `eq` tree.splice => tree case _ => finalize(tree, untpd.TypedSplice(splice)(using ctx)) @@ -711,6 +722,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { tree case MacroTree(expr) => cpy.MacroTree(tree)(transform(expr)) + case CapturingTypeTree(refs, parent) => + cpy.CapturingTypeTree(tree)(transform(refs), transform(parent)) case _ => super.transformMoreCases(tree) } @@ -772,6 +785,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { this(x, splice) case MacroTree(expr) => this(x, expr) + case CapturingTypeTree(refs, parent) => + this(this(x, refs), parent) case _ => super.foldMoreCases(x, tree) } diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala new file mode 100644 index 000000000000..5f73b50a6bbe --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala @@ -0,0 +1,63 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.*, Annotations.* +import ast.Trees.* +import ast.{tpd, untpd} +import Decorators.* +import config.Printers.capt +import printing.Printer +import printing.Texts.Text + + +case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean) extends Annotation: + import CaptureAnnotation.* + import tpd.* + + override def tree(using Context) = + val elems = refs.elems.toList.map { + case cr: TermRef => ref(cr) + case cr: TermParamRef => untpd.Ident(cr.paramName).withType(cr) + case cr: ThisType => This(cr.cls) + } + val arg = repeated(elems, TypeTree(defn.AnyType)) + New(symbol.typeRef, arg :: Nil) + + override def symbol(using Context) = defn.RetainsAnnot + + override def derivedAnnotation(tree: Tree)(using Context): Annotation = + unsupported("derivedAnnotation(Tree)") + + def derivedAnnotation(refs: CaptureSet, boxed: Boolean)(using Context): Annotation = + if (this.refs eq refs) && (this.boxed == boxed) then this + else CaptureAnnotation(refs, boxed) + + override def sameAnnotation(that: Annotation)(using Context): Boolean = that match + case CaptureAnnotation(refs2, boxed2) => refs == refs2 && boxed == boxed2 + case _ => false + + override def mapWith(tp: TypeMap)(using Context) = + val elems = refs.elems.toList + val elems1 = elems.mapConserve(tp) + if elems1 eq elems then this + else if elems1.forall(_.isInstanceOf[CaptureRef]) + then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed) + else EmptyAnnotation + + override def refersToParamOf(tl: TermLambda)(using Context): Boolean = + refs.elems.exists { + case TermParamRef(tl1, _) => tl eq tl1 + case _ => false + } + + override def toText(printer: Printer): Text = refs.toText(printer) + + override def hash: Int = (refs.hashCode << 1) | (if boxed then 1 else 0) + + override def eql(that: Annotation) = that match + case that: CaptureAnnotation => (this.refs eq that.refs) && (this.boxed == boxed) + case _ => false + +end CaptureAnnotation diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala new file mode 100644 index 000000000000..09064314b1bf --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -0,0 +1,82 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.*, Annotations.* +import ast.{tpd, untpd} +import Decorators.* +import config.Printers.capt +import util.Property.Key +import tpd.* + +private val Captures: Key[CaptureSet] = Key() +private val IsBoxed: Key[Unit] = Key() + +def retainedElems(tree: Tree)(using Context): List[Tree] = tree match + case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems + case _ => Nil + +extension (tree: Tree) + + def toCaptureRef(using Context): CaptureRef = tree.tpe.asInstanceOf[CaptureRef] + + def toCaptureSet(using Context): CaptureSet = + tree.getAttachment(Captures) match + case Some(refs) => refs + case None => + val refs = CaptureSet(retainedElems(tree).map(_.toCaptureRef)*) + .showing(i"toCaptureSet $tree --> $result", capt) + tree.putAttachment(Captures, refs) + refs + + def isBoxedCapturing(using Context): Boolean = + tree.hasAttachment(IsBoxed) + + def setBoxedCapturing()(using Context): Unit = + tree.putAttachment(IsBoxed, ()) + +extension (tp: Type) + + def derivedCapturingType(parent: Type, refs: CaptureSet)(using Context): Type = tp match + case CapturingType(p, r, b) => + if (parent eq p) && (refs eq r) then tp + else CapturingType(parent, refs, b) + + /** If this is type variable instantiated or upper bounded with a capturing type, + * the capture set associated with that type. Extended to and-or types and + * type proxies in the obvious way. If a term has a type with a boxed captureset, + * that captureset counts towards the capture variables of the envirionment. + */ + def boxedCaptured(using Context): CaptureSet = + def getBoxed(tp: Type): CaptureSet = tp match + case CapturingType(_, refs, boxed) => if boxed then refs else CaptureSet.empty + case tp: TypeProxy => getBoxed(tp.superType) + case tp: AndType => getBoxed(tp.tp1) ++ getBoxed(tp.tp2) + case tp: OrType => getBoxed(tp.tp1) ** getBoxed(tp.tp2) + case _ => CaptureSet.empty + getBoxed(tp) + + def isBoxedCapturing(using Context) = !tp.boxedCaptured.isAlwaysEmpty + + def canHaveInferredCapture(using Context): Boolean = tp match + case tp: TypeRef if tp.symbol.isClass => + !tp.symbol.isValueClass && tp.symbol != defn.AnyClass + case _: TypeVar | _: TypeParamRef => + false + case tp: TypeProxy => + tp.superType.canHaveInferredCapture + case tp: AndType => + tp.tp1.canHaveInferredCapture && tp.tp2.canHaveInferredCapture + case tp: OrType => + tp.tp1.canHaveInferredCapture || tp.tp2.canHaveInferredCapture + case _ => + false + + def stripCapturing(using Context): Type = tp.dealiasKeepAnnots match + case CapturingType(parent, _, _) => + parent.stripCapturing + case atd @ AnnotatedType(parent, annot) => + atd.derivedAnnotatedType(parent.stripCapturing, annot) + case _ => + tp diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala new file mode 100644 index 000000000000..f8ca2f87e3c5 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -0,0 +1,577 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Flags.*, Contexts.*, Decorators.* +import config.Printers.capt +import Annotations.Annotation +import annotation.threadUnsafe +import annotation.constructorOnly +import annotation.internal.sharable +import reporting.trace +import printing.{Showable, Printer} +import printing.Texts.* +import util.{SimpleIdentitySet, Property} +import util.common.alwaysTrue +import scala.collection.mutable + +/** A class for capture sets. Capture sets can be constants or variables. + * Capture sets support inclusion constraints <:< where <:< is subcapturing. + * They also allow mapping with arbitrary functions from elements to capture sets, + * by supporting a monadic flatMap operation. That is, constraints can be + * of one of the following forms + * + * cs1 <:< cs2 + * cs1 = ∪ {f(x) | x ∈ cs2} + * + * where the `f`s are arbitrary functions from capture references to capture sets. + * We call the resulting constraint system "monadic set constraints". + */ +sealed abstract class CaptureSet extends Showable: + import CaptureSet.* + + /** The elements of this capture set. For capture variables, + * the elements known so far. + */ + def elems: Refs + + /** Is this capture set constant (i.e. not an unsolved capture variable)? + * Solved capture variables count as constant. + */ + def isConst: Boolean + + /** Is this capture set always empty? For capture veraiables, returns + * always false + */ + def isAlwaysEmpty: Boolean + + /** Is this capture set definitely non-empty? */ + final def isNotEmpty: Boolean = !elems.isEmpty + + /** Cast to variable. @pre: @isConst */ + def asVar: Var = + assert(!isConst) + asInstanceOf[Var] + + /** Add new elements to this capture set if allowed. + * @pre `newElems` is not empty and does not overlap with `this.elems`. + * Constant capture sets never allow to add new elements. + * Variables allow it if and only if the new elements can be included + * in all their supersets. + * @param origin The set where the elements come from, or `empty` if not known. + * @return CompareResult.OK if elements were added, or a conflicting + * capture set that prevents addition otherwise. + */ + protected def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult + + /** If this is a variable, add `cs` as a super set */ + protected def addSuper(cs: CaptureSet)(using Context, VarState): CompareResult + + /** If `cs` is a variable, add this capture set as one of its super sets */ + protected def addSub(cs: CaptureSet)(using Context): this.type = + cs.addSuper(this)(using ctx, UnrecordedState) + this + + /** Try to include all references of `elems` that are not yet accounted by this + * capture set. Inclusion is via `addNewElems`. + * @param origin The set where the elements come from, or `empty` if not known. + * @return CompareResult.OK if all unaccounted elements could be added, + * capture set that prevents addition otherwise. + */ + protected final def tryInclude(elems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + val unaccounted = elems.filter(!accountsFor(_)) + if unaccounted.isEmpty then CompareResult.OK + else addNewElems(unaccounted, origin) + + protected final def tryInclude(elem: CaptureRef, origin: CaptureSet)(using Context, VarState): CompareResult = + if accountsFor(elem) then CompareResult.OK + else addNewElems(elem.singletonCaptureSet.elems, origin) + + extension (x: CaptureRef) private def subsumes(y: CaptureRef) = + (x eq y) + || y.match + case y: TermRef => y.prefix eq x // ^^^ y.prefix.subsumes(x) ? + case _ => false + + /** {x} <:< this where <:< is subcapturing, but treating all variables + * as frozen. + */ + def accountsFor(x: CaptureRef)(using ctx: Context): Boolean = + reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true) { + elems.exists(_.subsumes(x)) + || !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK + } + + /** The subcapturing test */ + final def subCaptures(that: CaptureSet, frozen: Boolean)(using Context): CompareResult = + subCaptures(that)(using ctx, if frozen then FrozenState else VarState()) + + private def subCaptures(that: CaptureSet)(using Context, VarState): CompareResult = + def recur(elems: List[CaptureRef]): CompareResult = elems match + case elem :: elems1 => + var result = that.tryInclude(elem, this) + if !result.isOK && !elem.isRootCapability && summon[VarState] != FrozenState then + result = elem.captureSetOfInfo.subCaptures(that) + if result.isOK then + recur(elems1) + else + varState.abort() + result + case Nil => + addSuper(that) + recur(elems.toList) + .showing(i"subcaptures $this <:< $that = ${result.show}", capt) + + def =:= (that: CaptureSet)(using Context): Boolean = + this.subCaptures(that, frozen = true).isOK + && that.subCaptures(this, frozen = true).isOK + + /** The smallest capture set (via <:<) that is a superset of both + * `this` and `that` + */ + def ++ (that: CaptureSet)(using Context): CaptureSet = + if this.subCaptures(that, frozen = true).isOK then that + else if that.subCaptures(this, frozen = true).isOK then this + else if this.isConst && that.isConst then Const(this.elems ++ that.elems) + else Var(this.elems ++ that.elems).addSub(this).addSub(that) + + /** The smallest superset (via <:<) of this capture set that also contains `ref`. + */ + def + (ref: CaptureRef)(using Context): CaptureSet = + this ++ ref.singletonCaptureSet + + /** The largest capture set (via <:<) that is a subset of both `this` and `that` + */ + def **(that: CaptureSet)(using Context): CaptureSet = + if this.subCaptures(that, frozen = true).isOK then this + else if that.subCaptures(this, frozen = true).isOK then that + else if this.isConst && that.isConst then Const(elems.intersect(that.elems)) + else if that.isConst then Intersected(this.asVar, that) + else Intersected(that.asVar, this) + + def -- (that: CaptureSet.Const)(using Context): CaptureSet = + val elems1 = elems.filter(!that.accountsFor(_)) + if elems1.size == elems.size then this + else if this.isConst then Const(elems1) + else Diff(asVar, that) + + def - (ref: CaptureRef)(using Context): CaptureSet = + this -- ref.singletonCaptureSet + + def filter(p: CaptureRef => Boolean)(using Context): CaptureSet = + if this.isConst then Const(elems.filter(p)) + else Filtered(asVar, p) + + /** capture set obtained by applying `f` to all elements of the current capture set + * and joining the results. If the current capture set is a variable, the same + * transformation is applied to all future additions of new elements. + */ + def map(tm: TypeMap)(using Context): CaptureSet = tm match + case tm: BiTypeMap => + val mappedElems = elems.map(tm.forward) + if isConst then Const(mappedElems) + else BiMapped(asVar, tm, mappedElems) + case _ => + val mapped = mapRefs(elems, tm, tm.variance) + if isConst then mapped + else Mapped(asVar, tm, tm.variance, mapped) + + def substParams(tl: BindingType, to: List[Type])(using Context) = + map(Substituters.SubstParamsMap(tl, to)) + + /** An upper approximation of this capture set. This is the set itself + * except for real (non-mapped, non-filtered) capture set variables, where + * it is the intersection of all upper approximations of known supersets + * of the variable. + * The upper approximation is meaningful only if it is constant. If not, + * `upperApprox` can return an arbitrary capture set variable. + */ + protected def upperApprox(origin: CaptureSet)(using Context): CaptureSet + + protected def propagateSolved()(using Context): Unit = () + + def toRetainsTypeArg(using Context): Type = + assert(isConst) + ((NoType: Type) /: elems) ((tp, ref) => + if tp.exists then OrType(tp, ref, soft = false) else ref) + + def toRegularAnnotation(using Context): Annotation = + Annotation(CaptureAnnotation(this, boxed = false).tree) + + override def toText(printer: Printer): Text = + Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}") + +object CaptureSet: + type Refs = SimpleIdentitySet[CaptureRef] + type Vars = SimpleIdentitySet[Var] + type Deps = SimpleIdentitySet[CaptureSet] + + /** If set to `true`, capture stack traces that tell us where sets are created */ + private final val debugSets = false + + private val emptySet = SimpleIdentitySet.empty + @sharable private var varId = 0 + + val empty: CaptureSet.Const = Const(emptySet) + + /** The universal capture set `{*}` */ + def universal(using Context): CaptureSet = + defn.captureRoot.termRef.singletonCaptureSet + + /** Used as a recursion brake */ + @sharable private[dotc] val Pending = Const(SimpleIdentitySet.empty) + + def apply(elems: CaptureRef*)(using Context): CaptureSet.Const = + if elems.isEmpty then empty + else Const(SimpleIdentitySet(elems.map(_.normalizedRef)*)) + + def apply(elems: Refs)(using Context): CaptureSet.Const = + if elems.isEmpty then empty else Const(elems) + + class Const private[CaptureSet] (val elems: Refs) extends CaptureSet: + assert(elems != null) + def isConst = true + def isAlwaysEmpty = elems.isEmpty + + def addNewElems(elems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + CompareResult.fail(this) + + def addSuper(cs: CaptureSet)(using Context, VarState) = CompareResult.OK + + def upperApprox(origin: CaptureSet)(using Context): CaptureSet = this + + override def toString = elems.toString + end Const + + class Var(initialElems: Refs = emptySet) extends CaptureSet: + val id = + varId += 1 + varId + + private var isSolved: Boolean = false + + var elems: Refs = initialElems + var deps: Deps = emptySet + def isConst = isSolved + def isAlwaysEmpty = false + + private def recordElemsState()(using VarState): Boolean = + varState.getElems(this) match + case None => varState.putElems(this, elems) + case _ => true + + private[CaptureSet] def recordDepsState()(using VarState): Boolean = + varState.getDeps(this) match + case None => varState.putDeps(this, deps) + case _ => true + + def resetElems()(using state: VarState): Unit = + elems = state.elems(this) + + def resetDeps()(using state: VarState): Unit = + deps = state.deps(this) + + def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + if !isConst && recordElemsState() then + elems ++= newElems + // assert(id != 2 || elems.size != 2, this) + (CompareResult.OK /: deps) { (r, dep) => + r.andAlso(dep.tryInclude(newElems, this)) + } + else + CompareResult.fail(this) + + def addSuper(cs: CaptureSet)(using Context, VarState): CompareResult = + if (cs eq this) || cs.elems.contains(defn.captureRoot.termRef) || isConst then + CompareResult.OK + else if recordDepsState() then + deps += cs + CompareResult.OK + else + CompareResult.fail(this) + + private var computingApprox = false + + final def upperApprox(origin: CaptureSet)(using Context): CaptureSet = + if computingApprox then universal + else if isConst then this + else + computingApprox = true + try computeApprox(origin).ensuring(_.isConst) + finally computingApprox = false + + protected def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + (universal /: deps) { (acc, sup) => acc ** sup.upperApprox(this) } + + def solve(variance: Int)(using Context): Unit = + if variance < 0 && !isConst then + val approx = upperApprox(empty) + //println(i"solving var $this $approx ${approx.isConst} deps = ${deps.toList}") + if approx.isConst then + val newElems = approx.elems -- elems + if newElems.isEmpty || addNewElems(newElems, empty)(using ctx, VarState()).isOK then + markSolved() + + def markSolved()(using Context): Unit = + isSolved = true + deps.foreach(_.propagateSolved()) + + protected def ids(using Context): String = + val trail = this.match + case dv: DerivedVar => dv.source.ids + case _ => "" + s"$id${getClass.getSimpleName.take(1)}$trail" + + override def toText(printer: Printer): Text = inContext(printer.printerContext) { + for vars <- ctx.property(ShownVars) do vars += this + super.toText(printer) ~ (Str(ids) provided !isConst && ctx.settings.YccDebug.value) + } + + override def toString = s"Var$id$elems" + end Var + + abstract class DerivedVar(initialElems: Refs)(using @constructorOnly ctx: Context) + extends Var(initialElems): + def source: Var + + addSub(source) + + override def propagateSolved()(using Context) = + if source.isConst && !isConst then markSolved() + end DerivedVar + + /** A variable that changes when `source` changes, where all additional new elements are mapped + * using ∪ { f(x) | x <- elems } + */ + class Mapped private[CaptureSet] + (val source: Var, tm: TypeMap, variance: Int, initial: CaptureSet)(using @constructorOnly ctx: Context) + extends DerivedVar(initial.elems): + addSub(initial) + val stack = if debugSets then (new Throwable).getStackTrace().take(20) else null + + private def whereCreated(using Context): String = + if stack == null then "" + else i""" + |Stack trace of variable creation:" + |${stack.mkString("\n")}""" + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + val added = + if origin eq source then + mapRefs(newElems, tm, variance) + else + if variance <= 0 && !origin.isConst && (origin ne initial) then + report.warning(i"trying to add elems $newElems from unrecognized source $origin of mapped set $this$whereCreated") + return CompareResult.fail(this) + Const(newElems) + super.addNewElems(added.elems, origin) + .andAlso { + if added.isConst then CompareResult.OK + else if added.asVar.recordDepsState() then { addSub(added); CompareResult.OK } + else CompareResult.fail(this) + } + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + if source eq origin then universal + else source.upperApprox(this).map(tm) + + override def propagateSolved()(using Context) = + if initial.isConst then super.propagateSolved() + + override def toString = s"Mapped$id($source, elems = $elems)" + end Mapped + + class BiMapped private[CaptureSet] + (val source: Var, bimap: BiTypeMap, initialElems: Refs)(using @constructorOnly ctx: Context) + extends DerivedVar(initialElems): + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + if origin eq source then + super.addNewElems(newElems.map(bimap.forward), origin) + else + super.addNewElems(newElems, origin) + .andAlso { + source.tryInclude(newElems.map(bimap.backward), this) + .showing(i"propagating new elems $newElems backward from $this to $source", capt) + } + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + val supApprox = super.computeApprox(this) + if source eq origin then supApprox.map(bimap.inverseTypeMap) + else source.upperApprox(this).map(bimap) ** supApprox + + override def toString = s"BiMapped$id($source, elems = $elems)" + end BiMapped + + /** A variable with elements given at any time as { x <- source.elems | p(x) } */ + class Filtered private[CaptureSet] + (val source: Var, p: CaptureRef => Boolean)(using @constructorOnly ctx: Context) + extends DerivedVar(source.elems.filter(p)): + + override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = + super.addNewElems(newElems.filter(p), origin) + + override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = + if source eq origin then universal + else source.upperApprox(this).filter(p) + + override def toString = s"${getClass.getSimpleName}$id($source, elems = $elems)" + end Filtered + + /** A variable with elements given at any time as { x <- source.elems | !other.accountsFor(x) } */ + class Diff(source: Var, other: Const)(using Context) + extends Filtered(source, !other.accountsFor(_)) + + /** A variable with elements given at any time as { x <- source.elems | other.accountsFor(x) } */ + class Intersected(source: Var, other: CaptureSet)(using Context) + extends Filtered(source, other.accountsFor(_)): + addSub(other) + + def extrapolateCaptureRef(r: CaptureRef, tm: TypeMap, variance: Int)(using Context): CaptureSet = + val r1 = tm(r) + val upper = r1.captureSet + def isExact = + upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1) + if variance > 0 || isExact then upper + else if variance < 0 then CaptureSet.empty + else assert(false, i"trying to add $upper from $r via ${tm.getClass} in a non-variant setting") + + def mapRefs(xs: Refs, f: CaptureRef => CaptureSet)(using Context): CaptureSet = + ((empty: CaptureSet) /: xs)((cs, x) => cs ++ f(x)) + + def mapRefs(xs: Refs, tm: TypeMap, variance: Int)(using Context): CaptureSet = + mapRefs(xs, extrapolateCaptureRef(_, tm, variance)) + + type CompareResult = CompareResult.Type + + /** None = ok, Some(cs) = failure since not a subset of cs */ + object CompareResult: + opaque type Type = CaptureSet + val OK: Type = Const(emptySet) + def fail(cs: CaptureSet): Type = cs + extension (result: Type) + def isOK: Boolean = result eq OK + def blocking: CaptureSet = result + def show: String = if result.isOK then "OK" else result.toString + def andAlso(op: Context ?=> Type)(using Context): Type = if result.isOK then op else result + + class VarState: + private val elemsMap: util.EqHashMap[Var, Refs] = new util.EqHashMap + private val depsMap: util.EqHashMap[Var, Deps] = new util.EqHashMap + + def elems(v: Var): Refs = elemsMap(v) + def getElems(v: Var): Option[Refs] = elemsMap.get(v) + def putElems(v: Var, elems: Refs): Boolean = { elemsMap(v) = elems; true } + + def deps(v: Var): Deps = depsMap(v) + def getDeps(v: Var): Option[Deps] = depsMap.get(v) + def putDeps(v: Var, deps: Deps): Boolean = { depsMap(v) = deps; true } + + def abort(): Unit = + elemsMap.keysIterator.foreach(_.resetElems()(using this)) + depsMap.keysIterator.foreach(_.resetDeps()(using this)) + end VarState + + @sharable + object FrozenState extends VarState: + override def putElems(v: Var, refs: Refs) = false + override def putDeps(v: Var, deps: Deps) = false + override def abort(): Unit = () + + @sharable + object UnrecordedState extends VarState: + override def putElems(v: Var, refs: Refs) = true + override def putDeps(v: Var, deps: Deps) = true + override def abort(): Unit = () + + def varState(using state: VarState): VarState = state + + def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet = + CaptureSet.empty + /* + def captureSetOf(tp: Type): CaptureSet = tp match + case tp: TypeRef if tp.symbol.is(ParamAccessor) => + def mapArg(accs: List[Symbol], tps: List[Type]): CaptureSet = accs match + case acc :: accs1 if tps.nonEmpty => + if acc == tp.symbol then tps.head.captureSet + else mapArg(accs1, tps.tail) + case _ => + empty + mapArg(cinfo.cls.paramAccessors, argTypes) + case _ => + tp.captureSet + val css = + for + parent <- cinfo.parents if parent.classSymbol == defn.RetainingClass + arg <- parent.argInfos + yield captureSetOf(arg) + css.foldLeft(empty)(_ ++ _) + */ + def ofInfo(ref: CaptureRef)(using Context): CaptureSet = ref match + case ref: ThisType => + val declaredCaptures = ref.cls.givenSelfType.captureSet + ref.cls.paramAccessors.foldLeft(declaredCaptures) ((cs, acc) => + cs ++ acc.termRef.captureSetOfInfo) // ^^^ need to also include outer references of inner classes + .showing(i"cc info $ref with ${ref.cls.paramAccessors.map(_.termRef)}%, % = $result", capt) + case ref: TermRef if ref.isRootCapability => ref.singletonCaptureSet + case _ => ofType(ref.underlying) + + def ofType(tp: Type)(using Context): CaptureSet = + def recur(tp: Type): CaptureSet = tp.dealias match + case tp: TermRef => + tp.captureSet + case tp: TermParamRef => + tp.captureSet + case _: TypeRef | _: TypeParamRef => + empty + case CapturingType(parent, refs, _) => + recur(parent) ++ refs + case AppliedType(tycon, args) => + val cs = recur(tycon) + tycon.typeParams match + case tparams @ (LambdaParam(tl, _) :: _) => cs.substParams(tl, args) + case _ => cs + case tp: TypeProxy => + recur(tp.underlying) + case AndType(tp1, tp2) => + recur(tp1) ** recur(tp2) + case OrType(tp1, tp2) => + recur(tp1) ++ recur(tp2) + case tp: ClassInfo => + ofClass(tp, Nil) + case _ => + empty + recur(tp) + .showing(i"capture set of $tp = $result", capt) + + private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key() + + def withCaptureSetsExplained[T](op: Context ?=> T)(using ctx: Context): T = + if ctx.settings.YccDebug.value then + val shownVars = mutable.Set[Var]() + inContext(ctx.withProperty(ShownVars, Some(shownVars))) { + try op + finally + val reachable = mutable.Set[Var]() + val todo = mutable.Queue[Var]() ++= shownVars + def incl(cv: Var): Unit = + if !reachable.contains(cv) then todo += cv + while todo.nonEmpty do + val cv = todo.dequeue() + if !reachable.contains(cv) then + reachable += cv + cv.deps.foreach { + case cv: Var => incl(cv) + case _ => + } + cv match + case cv: DerivedVar => incl(cv.source) + case _ => + val allVars = reachable.toArray.sortBy(_.id) + println(i"Capture set dependencies:") + for cv <- allVars do + println(i" ${cv.show.padTo(20, ' ')} :: ${cv.deps.toList}%, %") + } + else op +end CaptureSet diff --git a/compiler/src/dotty/tools/dotc/cc/CapturingType.scala b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala new file mode 100644 index 000000000000..2eeb1ff41b72 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/cc/CapturingType.scala @@ -0,0 +1,21 @@ +package dotty.tools +package dotc +package cc + +import core.* +import Types.*, Symbols.*, Contexts.* + +object CapturingType: + + def apply(parent: Type, refs: CaptureSet, boxed: Boolean)(using Context): Type = + if refs.isAlwaysEmpty then parent + else AnnotatedType(parent, CaptureAnnotation(refs, boxed)) + + def unapply(tp: AnnotatedType)(using Context): Option[(Type, CaptureSet, Boolean)] = + if ctx.phase == Phases.checkCapturesPhase && tp.annot.symbol == defn.RetainsAnnot then + tp.annot match + case ann: CaptureAnnotation => Some((tp.parent, ann.refs, ann.boxed)) + case ann => Some((tp.parent, ann.tree.toCaptureSet, ann.tree.isBoxedCapturing)) + else None + +end CapturingType diff --git a/compiler/src/dotty/tools/dotc/config/Config.scala b/compiler/src/dotty/tools/dotc/config/Config.scala index ac1708378e73..a54987b23ecc 100644 --- a/compiler/src/dotty/tools/dotc/config/Config.scala +++ b/compiler/src/dotty/tools/dotc/config/Config.scala @@ -227,4 +227,9 @@ object Config { * reduces the number of allocated denotations by ~50%. */ inline val reuseSymDenotations = true + + /** If true, print capturing types in the form `{c} T`. + * If false, print them in the form `T @retains(c)`. + */ + inline val printCaptureSetsAsPrefix = true } diff --git a/compiler/src/dotty/tools/dotc/config/Printers.scala b/compiler/src/dotty/tools/dotc/config/Printers.scala index b71e1e7f188a..d20d482b062e 100644 --- a/compiler/src/dotty/tools/dotc/config/Printers.scala +++ b/compiler/src/dotty/tools/dotc/config/Printers.scala @@ -12,6 +12,7 @@ object Printers { val default = new Printer + val capt = noPrinter val constr = noPrinter val core = noPrinter val checks = noPrinter diff --git a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala index 56e6ab14fae5..b0e161399e75 100644 --- a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -300,6 +300,8 @@ private sealed trait YSettings: val YcheckInit: Setting[Boolean] = BooleanSetting("-Ysafe-init", "Ensure safe initialization of objects") val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation") val Yrecheck: Setting[Boolean] = BooleanSetting("-Yrecheck", "Run type rechecks (test only)") + val Ycc: Setting[Boolean] = BooleanSetting("-Ycc", "Check captured references") + val YccDebug: Setting[Boolean] = BooleanSetting("-Ycc-debug", "Debug info for captured references") /** Area-specific debug output */ val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.") diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index b8d62210ce26..d0172c82972c 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -48,7 +48,7 @@ object Annotations { /** The tree evaluation has finished. */ def isEvaluated: Boolean = true - /** Normally, type map over all tree nodes of this annotation, but can + /** Normally, applies a type map to all tree nodes of this annotation, but can * be overridden. Returns EmptyAnnotation if type type map produces a range * type, since ranges cannot be types of trees. */ @@ -86,6 +86,10 @@ object Annotations { def sameAnnotation(that: Annotation)(using Context): Boolean = symbol == that.symbol && tree.sameTree(that.tree) + + /** Operations for hash-consing, can be overridden */ + def hash: Int = System.identityHashCode(this) + def eql(that: Annotation) = this eq that } case class ConcreteAnnotation(t: Tree) extends Annotation: diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index c3b462f3b179..ee296c47a305 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -14,6 +14,7 @@ import typer.ImportInfo.RootRef import Comments.CommentsContext import Comments.Comment import util.Spans.NoSpan +import cc.{CapturingType, CaptureSet} import scala.annotation.tailrec @@ -143,11 +144,13 @@ class Definitions { private def enterMethod(cls: ClassSymbol, name: TermName, info: Type, flags: FlagSet = EmptyFlags): TermSymbol = newMethod(cls, name, info, flags).entered - private def enterAliasType(name: TypeName, tpe: Type, flags: FlagSet = EmptyFlags): TypeSymbol = { - val sym = newPermanentSymbol(ScalaPackageClass, name, flags, TypeAlias(tpe)) + private def enterPermanentSymbol(name: Name, info: Type, flags: FlagSet = EmptyFlags): Symbol = + val sym = newPermanentSymbol(ScalaPackageClass, name, flags, info) ScalaPackageClass.currentPackageDecls.enter(sym) sym - } + + private def enterAliasType(name: TypeName, tpe: Type, flags: FlagSet = EmptyFlags): TypeSymbol = + enterPermanentSymbol(name, TypeAlias(tpe), flags).asType private def enterBinaryAlias(name: TypeName, op: (Type, Type) => Type): TypeSymbol = enterAliasType(name, @@ -440,6 +443,7 @@ class Definitions { @tu lazy val andType: TypeSymbol = enterBinaryAlias(tpnme.AND, AndType(_, _)) @tu lazy val orType: TypeSymbol = enterBinaryAlias(tpnme.OR, OrType(_, _, soft = false)) + @tu lazy val captureRoot: TermSymbol = enterPermanentSymbol(nme.CAPTURE_ROOT, AnyType).asTerm /** Marker method to indicate an argument to a call-by-name parameter. * Created by byNameClosures and elimByName, eliminated by Erasure, @@ -941,6 +945,8 @@ class Definitions { @tu lazy val FunctionalInterfaceAnnot: ClassSymbol = requiredClass("java.lang.FunctionalInterface") @tu lazy val TargetNameAnnot: ClassSymbol = requiredClass("scala.annotation.targetName") @tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs") + @tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains") + @tu lazy val AbilityAnnot: ClassSymbol = requiredClass("scala.annotation.ability") @tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable") @@ -1514,6 +1520,9 @@ class Definitions { def isFunctionType(tp: Type)(using Context): Boolean = isNonRefinedFunction(tp.dropDependentRefinement) + def isFunctionOrPolyType(tp: RefinedType)(using Context): Boolean = + isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass) + // Specialized type parameters defined for scala.Function{0,1,2}. @tu lazy val Function1SpecializedParamTypes: collection.Set[TypeRef] = Set(IntType, LongType, FloatType, DoubleType) @@ -1812,7 +1821,7 @@ class Definitions { this.initCtx = ctx if (!isInitialized) { // force initialization of every symbol that is synthesized or hijacked by the compiler - val forced = syntheticCoreClasses ++ syntheticCoreMethods ++ ScalaValueClasses() :+ JavaEnumClass + val forced = syntheticCoreClasses ++ syntheticCoreMethods ++ ScalaValueClasses() ++ List(JavaEnumClass, captureRoot) isInitialized = true } diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 17df7149e9be..b6f8f00a91fb 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -12,6 +12,7 @@ import config.Printers.constr import reflect.ClassTag import annotation.tailrec import annotation.internal.sharable +import cc.{CapturingType, derivedCapturingType} object OrderingConstraint { @@ -328,6 +329,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds, case tp: TypeVar => val underlying1 = recur(tp.underlying, fromBelow) if underlying1 ne tp.underlying then underlying1 else tp + case CapturingType(parent, refs, _) => + val parent1 = recur(parent, fromBelow) + if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp case tp: AnnotatedType => val parent1 = recur(tp.parent, fromBelow) if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 0294de01f36e..a1faa1428188 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -13,10 +13,12 @@ import scala.collection.mutable.ListBuffer import dotty.tools.dotc.transform.MegaPhase._ import dotty.tools.dotc.transform._ import Periods._ -import parsing.{ Parser} +import parsing.Parser +import printing.XprintMode import typer.{TyperPhase, RefChecks} +import cc.CheckCaptures import typer.ImportInfo.withRootImports -import ast.tpd +import ast.{tpd, untpd} import scala.annotation.internal.sharable import scala.util.control.NonFatal @@ -216,6 +218,7 @@ object Phases { private var myCountOuterAccessesPhase: Phase = _ private var myFlattenPhase: Phase = _ private var myGenBCodePhase: Phase = _ + private var myCheckCapturesPhase: Phase = _ final def parserPhase: Phase = myParserPhase final def typerPhase: Phase = myTyperPhase @@ -238,6 +241,7 @@ object Phases { final def countOuterAccessesPhase = myCountOuterAccessesPhase final def flattenPhase: Phase = myFlattenPhase final def genBCodePhase: Phase = myGenBCodePhase + final def checkCapturesPhase: Phase = myCheckCapturesPhase private def setSpecificPhases() = { def phaseOfClass(pclass: Class[?]) = phases.find(pclass.isInstance).getOrElse(NoPhase) @@ -262,7 +266,8 @@ object Phases { myFlattenPhase = phaseOfClass(classOf[Flatten]) myExplicitOuterPhase = phaseOfClass(classOf[ExplicitOuter]) myGettersPhase = phaseOfClass(classOf[Getters]) - myGenBCodePhase = phaseOfClass(classOf[GenBCode]) + myGenBCodePhase = phaseOfClass(classOf[GenBCode]) + myCheckCapturesPhase = phaseOfClass(classOf[CheckCaptures]) } final def isAfterTyper(phase: Phase): Boolean = phase.id > typerPhase.id @@ -312,6 +317,10 @@ object Phases { unitCtx.compilationUnit } + /** Convert a compilation unit's tree to a string; can be overridden */ + def show(tree: untpd.Tree)(using Context): String = + tree.show(using ctx.withProperty(XprintMode, Some(()))) + def description: String = phaseName /** Output should be checkable by TreeChecker */ @@ -438,6 +447,7 @@ object Phases { def lambdaLiftPhase(using Context): Phase = ctx.base.lambdaLiftPhase def flattenPhase(using Context): Phase = ctx.base.flattenPhase def genBCodePhase(using Context): Phase = ctx.base.genBCodePhase + def checkCapturesPhase(using Context): Phase = ctx.base.checkCapturesPhase def unfusedPhases(using Context): Array[Phase] = ctx.base.phases diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 2e0b229ca42c..2a575cd0ad4d 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -275,6 +275,7 @@ object StdNames { // Compiler-internal val ANYname: N = "" + val CAPTURE_ROOT: N = "*" val COMPANION: N = "" val CONSTRUCTOR: N = "" val STATIC_CONSTRUCTOR: N = "" @@ -362,6 +363,7 @@ object StdNames { val AppliedTypeTree: N = "AppliedTypeTree" val ArrayAnnotArg: N = "ArrayAnnotArg" val CAP: N = "CAP" + val ClassManifestFactory: N = "ClassManifestFactory" val Constant: N = "Constant" val ConstantType: N = "ConstantType" val Eql: N = "Eql" @@ -439,7 +441,6 @@ object StdNames { val canEqualAny : N = "canEqualAny" val cbnArg: N = "" val checkInitialized: N = "checkInitialized" - val ClassManifestFactory: N = "ClassManifestFactory" val classOf: N = "classOf" val classType: N = "classType" val clone_ : N = "clone" @@ -571,6 +572,7 @@ object StdNames { val reflectiveSelectable: N = "reflectiveSelectable" val reify : N = "reify" val releaseFence : N = "releaseFence" + val retains: N = "retains" val rootMirror : N = "rootMirror" val run: N = "run" val runOrElse: N = "runOrElse" diff --git a/compiler/src/dotty/tools/dotc/core/Substituters.scala b/compiler/src/dotty/tools/dotc/core/Substituters.scala index f00edcb189c6..b277f2cd8619 100644 --- a/compiler/src/dotty/tools/dotc/core/Substituters.scala +++ b/compiler/src/dotty/tools/dotc/core/Substituters.scala @@ -161,8 +161,9 @@ object Substituters: .mapOver(tp) } - final class SubstBindingMap(from: BindingType, to: BindingType)(using Context) extends DeepTypeMap { + final class SubstBindingMap(from: BindingType, to: BindingType)(using Context) extends DeepTypeMap, BiTypeMap { def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) + def inverse(tp: Type): Type = tp.subst(to, from) } final class Subst1Map(from: Symbol, to: Type)(using Context) extends DeepTypeMap { @@ -177,8 +178,9 @@ object Substituters: def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) } - final class SubstSymMap(from: List[Symbol], to: List[Symbol])(using Context) extends DeepTypeMap { + final class SubstSymMap(from: List[Symbol], to: List[Symbol])(using Context) extends DeepTypeMap, BiTypeMap { def apply(tp: Type): Type = substSym(tp, from, to, this)(using mapCtx) + def inverse(tp: Type) = tp.substSym(to, from) } final class SubstThisMap(from: ClassSymbol, to: Type)(using Context) extends DeepTypeMap { diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index fcfbba208eb9..f5e5ea31d845 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -24,6 +24,7 @@ import config.Config import reporting._ import collection.mutable import transform.TypeUtils._ +import cc.{CapturingType, derivedCapturingType} import scala.annotation.internal.sharable @@ -224,6 +225,8 @@ object SymDenotations { ensureCompleted(); myAnnotations } + final def annotationsUNSAFE(using Context): List[Annotation] = myAnnotations + /** Update the annotations of this denotation */ final def annotations_=(annots: List[Annotation]): Unit = myAnnotations = annots @@ -1509,8 +1512,7 @@ object SymDenotations { case tp: ExprType => hasSkolems(tp.resType) case tp: AppliedType => hasSkolems(tp.tycon) || tp.args.exists(hasSkolems) case tp: LambdaType => tp.paramInfos.exists(hasSkolems) || hasSkolems(tp.resType) - case tp: AndType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) - case tp: OrType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) + case tp: AndOrType => hasSkolems(tp.tp1) || hasSkolems(tp.tp2) case tp: AnnotatedType => hasSkolems(tp.parent) case _ => false } @@ -2166,6 +2168,9 @@ object SymDenotations { case tp: TypeParamRef => // uncachable, since baseType depends on context bounds recur(TypeComparer.bounds(tp).hi) + case CapturingType(parent, refs, _) => + tp.derivedCapturingType(recur(parent), refs) + case tp: TypeProxy => def computeTypeProxy = { val superTp = tp.superType diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index add2030b6a82..b942526fa59e 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -24,6 +24,7 @@ import typer.Applications.productSelectorTypes import reporting.trace import NullOpsDecorator._ import annotation.constructorOnly +import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing} /** Provides methods to compare types. */ @@ -325,6 +326,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling compareWild case tp2: LazyRef => isBottom(tp1) || !tp2.evaluating && recur(tp1, tp2.ref) + case CapturingType(_, _, _) => + secondTry case tp2: AnnotatedType if !tp2.isRefining => recur(tp1, tp2.parent) case tp2: ThisType => @@ -444,8 +447,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // See i859.scala for an example where we hit this case. tp2.isRef(AnyClass, skipRefined = false) || !tp1.evaluating && recur(tp1.ref, tp2) - case tp1: AnnotatedType if !tp1.isRefining => - recur(tp1.parent, tp2) case AndType(tp11, tp12) => if (tp11.stripTypeVar eq tp12.stripTypeVar) recur(tp11, tp2) else thirdTry @@ -489,7 +490,14 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // and then need to check that they are indeed supertypes of the original types // under -Ycheck. Test case is i7965.scala. - case tp1: MatchType => + case CapturingType(parent1, refs1, _) => + if subCaptures(refs1, tp2.captureSet, frozenConstraint).isOK then + recur(parent1, tp2) + else + thirdTry + case tp1: AnnotatedType if !tp1.isRefining => + recur(tp1.parent, tp2) + case tp1: MatchType => val reduced = tp1.reduced if (reduced.exists) recur(reduced, tp2) else thirdTry case _: FlexType => @@ -527,8 +535,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // Note: We would like to replace this by `if (tp1.hasHigherKind)` // but right now we cannot since some parts of the standard library rely on the // idiom that e.g. `List <: Any`. We have to bootstrap without scalac first. - if (cls2 eq AnyClass) return true - if (cls2 == defn.SingletonClass && tp1.isStable) return true + if cls2 eq AnyClass then return true + if cls2 == defn.SingletonClass && tp1.isStable then return true return tryBaseType(cls2) } else if (cls2.is(JavaDefined)) { @@ -597,6 +605,28 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling isSubRefinements(tp1w.asInstanceOf[RefinedType], tp2, skipped2) && recur(tp1, skipped2) + def isSubInfo(info1: Type, info2: Type): Boolean = (info1, info2) match + case (info1: PolyType, info2: PolyType) => + sameLength(info1.paramNames, info2.paramNames) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1)) + case (info1: MethodType, info2: MethodType) => + matchingMethodParams(info1, info2, precise = false) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1)) + case _ => + isSubType(info1, info2) + + if ctx.phase == Phases.checkCapturesPhase then + if defn.isFunctionType(tp2) then + tp1.widenDealias match + case tp1: RefinedType => + return isSubInfo(tp1.refinedInfo, tp2.refinedInfo) + case _ => + else if tp2.parent.typeSymbol == defn.PolyFunctionClass then + tp1.member(nme.apply).info match + case info1: PolyType => + return isSubInfo(info1, tp2.refinedInfo) + case _ => + compareRefined case tp2: RecType => def compareRec = tp1.safeDealias match { @@ -727,13 +757,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def compareTypeBounds = tp1 match { case tp1 @ TypeBounds(lo1, hi1) => ((lo2 eq NothingType) || isSubType(lo2, lo1)) && - ((hi2 eq AnyType) && !hi1.isLambdaSub || (hi2 eq AnyKindType) || isSubType(hi1, hi2)) + ((hi2 eq AnyType) && !hi1.isLambdaSub + || (hi2 eq AnyKindType) + || isSubType(hi1, hi2)) case tp1: ClassInfo => tp2 contains tp1 case _ => false } compareTypeBounds + case CapturingType(parent2, _, _) => + recur(tp1, parent2) || fourthTry case tp2: AnnotatedType if tp2.isRefining => (tp1.derivesAnnotWith(tp2.annot.sameAnnotation) || tp1.isBottomType) && recur(tp1, tp2.parent) @@ -780,6 +814,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp: AppliedType => isNullable(tp.tycon) case AndType(tp1, tp2) => isNullable(tp1) && isNullable(tp2) case OrType(tp1, tp2) => isNullable(tp1) || isNullable(tp2) + case CapturingType(tp1, _, _) => isNullable(tp1) case _ => false } val sym1 = tp1.symbol @@ -798,7 +833,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => false } case _ => false - comparePaths || isSubType(tp1.underlying.widenExpr, tp2, approx.addLow) + comparePaths || { + var tp1w = tp1.underlying.widenExpr + tp1 match + case tp1: CaptureRef if tp1.isTracked => + val stripped = tp1w.stripCapturing + tp1w = CapturingType(stripped, tp1.singletonCaptureSet, boxed = false) + case _ => + isSubType(tp1w, tp2, approx.addLow) + } case tp1: RefinedType => isNewSubType(tp1.parent) case tp1: RecType => @@ -1769,69 +1812,68 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling protected def hasMatchingMember(name: Name, tp1: Type, tp2: RefinedType): Boolean = trace(i"hasMatchingMember($tp1 . $name :? ${tp2.refinedInfo}), mbr: ${tp1.member(name).info}", subtyping) { - def qualifies(m: SingleDenotation): Boolean = - // If the member is an abstract type and the prefix is a path, compare the member itself - // instead of its bounds. This case is needed situations like: - // - // class C { type T } - // val foo: C - // foo.type <: C { type T {= , <: , >:} foo.T } - // - // or like: - // - // class C[T] - // C[?] <: C[TV] - // - // where TV is a type variable. See i2397.scala for an example of the latter. - def matchAbstractTypeMember(info1: Type): Boolean = info1 match { - case TypeBounds(lo, hi) if lo ne hi => - tp2.refinedInfo match { - case rinfo2: TypeBounds if tp1.isStable => - val ref1 = tp1.widenExpr.select(name) - isSubType(rinfo2.lo, ref1) && isSubType(ref1, rinfo2.hi) - case _ => - false - } - case _ => false - } + // If the member is an abstract type and the prefix is a path, compare the member itself + // instead of its bounds. This case is needed situations like: + // + // class C { type T } + // val foo: C + // foo.type <: C { type T {= , <: , >:} foo.T } + // + // or like: + // + // class C[T] + // C[?] <: C[TV] + // + // where TV is a type variable. See i2397.scala for an example of the latter. + def matchAbstractTypeMember(info1: Type): Boolean = info1 match { + case TypeBounds(lo, hi) if lo ne hi => + tp2.refinedInfo match { + case rinfo2: TypeBounds if tp1.isStable => + val ref1 = tp1.widenExpr.select(name) + isSubType(rinfo2.lo, ref1) && isSubType(ref1, rinfo2.hi) + case _ => + false + } + case _ => false + } - // An additional check for type member matching: If the refinement of the - // supertype `tp2` does not refer to a member symbol defined in the parent of `tp2`. - // then the symbol referred to in the subtype must have a signature that coincides - // in its parameters with the refinement's signature. The reason for the check - // is that if the refinement does not refer to a member symbol, we will have to - // resort to reflection to invoke the member. And Java reflection needs to know exact - // erased parameter types. See neg/i12211.scala. Other reflection algorithms could - // conceivably dispatch without knowning precise parameter signatures. One can signal - // this by inheriting from the `scala.reflect.SignatureCanBeImprecise` marker trait, - // in which case the signature test is elided. - def sigsOK(symInfo: Type, info2: Type) = - tp2.underlyingClassRef(refinementOK = true).member(name).exists - || tp2.derivesFrom(defn.WithoutPreciseParameterTypesClass) - || symInfo.isInstanceOf[MethodType] - && symInfo.signature.consistentParams(info2.signature) - - // A relaxed version of isSubType, which compares method types - // under the standard arrow rule which is contravarient in the parameter types, - // but under the condition that signatures might have to match (see sigsOK) - // This relaxed version is needed to correctly compare dependent function types. - // See pos/i12211.scala. - def isSubInfo(info1: Type, info2: Type, symInfo: Type): Boolean = - info2 match - case info2: MethodType => - info1 match - case info1: MethodType => - val symInfo1 = symInfo.stripPoly - matchingMethodParams(info1, info2, precise = false) - && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1), symInfo1.resultType) - && sigsOK(symInfo1, info2) - case _ => isSubType(info1, info2) - case _ => isSubType(info1, info2) + // An additional check for type member matching: If the refinement of the + // supertype `tp2` does not refer to a member symbol defined in the parent of `tp2`. + // then the symbol referred to in the subtype must have a signature that coincides + // in its parameters with the refinement's signature. The reason for the check + // is that if the refinement does not refer to a member symbol, we will have to + // resort to reflection to invoke the member. And Java reflection needs to know exact + // erased parameter types. See neg/i12211.scala. Other reflection algorithms could + // conceivably dispatch without knowning precise parameter signatures. One can signal + // this by inheriting from the `scala.reflect.SignatureCanBeImprecise` marker trait, + // in which case the signature test is elided. + def sigsOK(symInfo: Type, info2: Type) = + tp2.underlyingClassRef(refinementOK = true).member(name).exists + || tp2.derivesFrom(defn.WithoutPreciseParameterTypesClass) + || symInfo.isInstanceOf[MethodType] + && symInfo.signature.consistentParams(info2.signature) + + // A relaxed version of isSubType, which compares method types + // under the standard arrow rule which is contravarient in the parameter types, + // but under the condition that signatures might have to match (see sigsOK) + // This relaxed version is needed to correctly compare dependent function types. + // See pos/i12211.scala. + def isSubInfo(info1: Type, info2: Type, symInfo: Type): Boolean = + info2 match + case info2: MethodType => + info1 match + case info1: MethodType => + val symInfo1 = symInfo.stripPoly + matchingMethodParams(info1, info2, precise = false) + && isSubInfo(info1.resultType, info2.resultType.subst(info2, info1), symInfo1.resultType) + && sigsOK(symInfo1, info2) + case _ => isSubType(info1, info2) + case _ => isSubType(info1, info2) + def qualifies(m: SingleDenotation): Boolean = val info1 = m.info.widenExpr isSubInfo(info1, tp2.refinedInfo.widenExpr, m.symbol.info.orElse(info1)) || matchAbstractTypeMember(m.info) - end qualifies tp1.member(name) match // inlined hasAltWith for performance case mbr: SingleDenotation => qualifies(mbr) @@ -1956,8 +1998,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case formal2 :: rest2 => val formal2a = if (tp2.isParamDependent) formal2.subst(tp2, tp1) else formal2 val paramsMatch = - if precise then isSameTypeWhenFrozen(formal1, formal2a) - else isSubTypeWhenFrozen(formal2a, formal1) + if precise then + isSameTypeWhenFrozen(formal1, formal2a) + else if ctx.phase == Phases.checkCapturesPhase then + isSubType(formal2a, formal1) + else + isSubTypeWhenFrozen(formal2a, formal1) paramsMatch && loop(rest1, rest2) case nil => false @@ -2360,6 +2406,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } case tp1: TypeVar if tp1.isInstantiated => tp1.underlying & tp2 + case CapturingType(parent1, refs1, _) => + if subCaptures(tp2.captureSet, refs1, frozenConstraint).isOK then + parent1 & tp2 + else + tp1.derivedCapturingType(parent1 & tp2, refs1) case tp1: AnnotatedType if !tp1.isRefining => tp1.underlying & tp2 case _ => @@ -2422,6 +2473,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling false } + protected def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult.Type = + refs1.subCaptures(refs2, frozen) + // ----------- Diagnostics -------------------------------------------------- /** A hook for showing subtype traces. Overridden in ExplainingTypeComparer */ @@ -2687,6 +2741,7 @@ object TypeComparer { else res match case ClassInfo(_, cls, _, _, _) => cls.showLocated case bounds: TypeBounds => i"type bounds [$bounds]" + case CaptureSet.CompareResult.OK => "OK" case res: printing.Showable => res.show case _ => String.valueOf(res) @@ -3015,5 +3070,10 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.addConstraint(param, bound, fromBelow) } + override def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult.Type = + traceIndented(i"subcaptures $refs1 <:< $refs2 ${if frozen then "frozen" else ""}") { + super.subCaptures(refs1, refs2, frozen) + } + def lastTrace(header: String): String = header + { try b.toString finally b.clear() } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeErrors.scala b/compiler/src/dotty/tools/dotc/core/TypeErrors.scala index c9ca98f65f5e..9067d0c87142 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErrors.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErrors.scala @@ -73,6 +73,7 @@ class RecursionOverflow(val op: String, details: => String, val previous: Throwa s"""Recursion limit exceeded. |Maybe there is an illegal cyclic reference? |If that's not the case, you could also try to increase the stacksize using the -Xss JVM option. + |For the unprocessed stack trace, compile with -Yno-decode-stacktraces. |A recurring operation is (inner to outer): |${opsString(mostCommon)}""".stripMargin } diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index 6a5145ffd202..dfdcb5d38054 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -19,6 +19,8 @@ import typer.ForceDegree import typer.Inferencing._ import typer.IfBottom import reporting.TestingReporter +import cc.{CapturingType, derivedCapturingType, CaptureSet} +import CaptureSet.CompareResult import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -164,6 +166,12 @@ object TypeOps: // with Nulls (which have no base classes). Under -Yexplicit-nulls, we take // corrective steps, so no widening is wanted. simplify(l, theMap) | simplify(r, theMap) + case CapturingType(parent, refs, _) => + if !ctx.mode.is(Mode.Type) + && refs.subCaptures(parent.captureSet, frozen = true).isOK then + simplify(parent, theMap) + else + mapOver case tp @ AnnotatedType(parent, annot) => val parent1 = simplify(parent, theMap) if annot.symbol == defn.UncheckedVarianceAnnot @@ -273,15 +281,23 @@ object TypeOps: case _ => false } - // Step 1: Get RecTypes and ErrorTypes out of the way, + // Step 1: Get RecTypes and ErrorTypes and CapturingTypes out of the way, tp1 match { - case tp1: RecType => return tp1.rebind(approximateOr(tp1.parent, tp2)) - case err: ErrorType => return err + case tp1: RecType => + return tp1.rebind(approximateOr(tp1.parent, tp2)) + case CapturingType(parent1, refs1, _) => + return tp1.derivedCapturingType(approximateOr(parent1, tp2), refs1) + case err: ErrorType => + return err case _ => } tp2 match { - case tp2: RecType => return tp2.rebind(approximateOr(tp1, tp2.parent)) - case err: ErrorType => return err + case tp2: RecType => + return tp2.rebind(approximateOr(tp1, tp2.parent)) + case CapturingType(parent2, refs2, _) => + return tp2.derivedCapturingType(approximateOr(tp1, parent2), refs2) + case err: ErrorType => + return err case _ => } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index d3c0eeab73d9..84dec9e43784 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -38,6 +38,8 @@ import scala.util.hashing.{ MurmurHash3 => hashing } import config.Printers.{core, typr, matchTypes} import reporting.{trace, Message} import java.lang.ref.WeakReference +import cc.{CapturingType, CaptureSet, derivedCapturingType, retainedElems, isBoxedCapturing} +import CaptureSet.CompareResult import scala.annotation.internal.sharable import scala.annotation.threadUnsafe @@ -67,7 +69,7 @@ object Types { * | | +--- SkolemType * | +- TypeParamRef * | +- RefinedOrRecType -+-- RefinedType - * | | -+-- RecType + * | | +-- RecType * | +- AppliedType * | +- TypeBounds * | +- ExprType @@ -187,7 +189,7 @@ object Types { * It makes no sense for it to be an alias type because isRef would always * return false in that case. */ - def isRef(sym: Symbol, skipRefined: Boolean = true)(using Context): Boolean = stripped match { + def isRef(sym: Symbol, skipRefined: Boolean = true)(using Context): Boolean = this match { case this1: TypeRef => this1.info match { // see comment in Namer#typeDefSig case TypeAlias(tp) => tp.isRef(sym, skipRefined) @@ -199,6 +201,12 @@ object Types { val this2 = this1.dealias if (this2 ne this1) this2.isRef(sym, skipRefined) else this1.underlying.isRef(sym, skipRefined) + case this1: TypeVar => + this1.instanceOpt.isRef(sym, skipRefined) + case this1: AnnotatedType => + this1 match + case CapturingType(_, _, _) => false + case _ => this1.parent.isRef(sym, skipRefined) case _ => false } @@ -365,6 +373,7 @@ object Types { case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference) case WildcardType(optBounds) => optBounds.unusableForInference + case CapturingType(parent, refs, _) => parent.unusableForInference || refs.elems.exists(_.unusableForInference) case _: ErrorType => true case _ => false @@ -1174,9 +1183,13 @@ object Types { */ def stripAnnots(using Context): Type = this - /** Strip TypeVars and Annotation wrappers */ + /** Strip TypeVars and Annotation and CapturingType wrappers */ def stripped(using Context): Type = this + def strippedDealias(using Context): Type = + val tp1 = stripped.dealias + if tp1 ne this then tp1.strippedDealias else this + def rewrapAnnots(tp: Type)(using Context): Type = tp.stripTypeVar match { case AnnotatedType(tp1, annot) => AnnotatedType(rewrapAnnots(tp1), annot) case _ => this @@ -1367,8 +1380,13 @@ object Types { val tp1 = tp.instanceOpt if (tp1.exists) tp1.dealias1(keep) else tp case tp: AnnotatedType => - val tp1 = tp.parent.dealias1(keep) - if keep(tp) then tp.derivedAnnotatedType(tp1, tp.annot) else tp1 + val parent1 = tp.parent.dealias1(keep) + tp match + case tp @ CapturingType(parent, refs, _) => + tp.derivedCapturingType(parent1, refs) + case _ => + if keep(tp) then tp.derivedAnnotatedType(parent1, tp.annot) + else parent1 case tp: LazyRef => tp.ref.dealias1(keep) case _ => this @@ -1461,7 +1479,7 @@ object Types { if (tp.tycon.isLambdaSub) NoType else tp.superType.underlyingClassRef(refinementOK) case tp: AnnotatedType => - tp.underlying.underlyingClassRef(refinementOK) + tp.parent.underlyingClassRef(refinementOK) case tp: RefinedType => if (refinementOK) tp.underlying.underlyingClassRef(refinementOK) else NoType case tp: RecType => @@ -1504,6 +1522,8 @@ object Types { case _ => if (isRepeatedParam) this.argTypesHi.head else this } + def captureSet(using Context): CaptureSet = CaptureSet.ofType(this) + // ----- Normalizing typerefs over refined types ---------------------------- /** If this normalizes* to a refinement type that has a refinement for `name` (which might be followed @@ -1791,7 +1811,7 @@ object Types { * @param dropLast The number of trailing parameters that should be dropped * when forming the function type. */ - def toFunctionType(isJava: Boolean, dropLast: Int = 0)(using Context): Type = this match { + def toFunctionType(isJava: Boolean, dropLast: Int = 0, alwaysDependent: Boolean = false)(using Context): Type = this match { case mt: MethodType if !mt.isParamDependent => val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast val isContextual = mt.isContextualMethod && !ctx.erasedTypes @@ -1803,7 +1823,7 @@ object Types { val funType = defn.FunctionOf( formals1 mapConserve (_.translateFromRepeated(toArray = isJava)), result1, isContextual, isErased) - if (mt.isResultDependent) RefinedType(funType, nme.apply, mt) + if alwaysDependent || mt.isResultDependent then RefinedType(funType, nme.apply, mt) else funType } @@ -1835,6 +1855,16 @@ object Types { case _ => this } + def capturing(ref: CaptureRef)(using Context): Type = + if captureSet.accountsFor(ref) then this + else CapturingType(this, ref.singletonCaptureSet, this.isBoxedCapturing) + + def capturing(cs: CaptureSet)(using Context): Type = + if cs.isConst && cs.subCaptures(captureSet, frozen = true).isOK then this + else this match + case CapturingType(parent, cs1, boxed) => parent.capturing(cs1 ++ cs) + case _ => CapturingType(this, cs, this.isBoxedCapturing) + /** The set of distinct symbols referred to by this type, after all aliases are expanded */ def coveringSet(using Context): Set[Symbol] = (new CoveringSetAccumulator).apply(Set.empty[Symbol], this) @@ -2015,6 +2045,40 @@ object Types { def isOverloaded(using Context): Boolean = false } + /** A trait for references in CaptureSets. These can be NamedTypes, ThisTypes or ParamRefs */ + trait CaptureRef extends SingletonType: + private var myCaptureSet: CaptureSet = _ + private var myCaptureSetRunId: Int = NoRunId + private var mySingletonCaptureSet: CaptureSet.Const = null + + def canBeTracked(using Context): Boolean + final def isTracked(using Context): Boolean = canBeTracked && !captureSetOfInfo.isAlwaysEmpty + def isRootCapability(using Context): Boolean = false + def normalizedRef(using Context): CaptureRef = this + + def singletonCaptureSet(using Context): CaptureSet.Const = + if mySingletonCaptureSet == null then + mySingletonCaptureSet = CaptureSet(this.normalizedRef) + mySingletonCaptureSet + + def captureSetOfInfo(using Context): CaptureSet = + if ctx.runId == myCaptureSetRunId then myCaptureSet + else if myCaptureSet eq CaptureSet.Pending then CaptureSet.empty + else + myCaptureSet = CaptureSet.Pending + val computed = CaptureSet.ofInfo(this) + if ctx.phase != Phases.checkCapturesPhase || underlying.isProvisional then + myCaptureSet = null + else + myCaptureSet = computed + myCaptureSetRunId = ctx.runId + computed + + override def captureSet(using Context): CaptureSet = + val cs = captureSetOfInfo + if canBeTracked && !cs.isAlwaysEmpty then singletonCaptureSet else cs + end CaptureRef + /** A trait for types that bind other types that refer to them. * Instances are: LambdaType, RecType. */ @@ -2062,7 +2126,7 @@ object Types { // --- NamedTypes ------------------------------------------------------------------ - abstract class NamedType extends CachedProxyType with ValueType { self => + abstract class NamedType extends CachedProxyType, ValueType { self => type ThisType >: this.type <: NamedType type ThisName <: Name @@ -2081,6 +2145,9 @@ object Types { private var mySignature: Signature = _ private var mySignatureRunId: Int = NoRunId + private var myCaptureSet: CaptureSet = _ + private var myCaptureSetRunId: Int = NoRunId + // Invariants: // (1) checkedPeriod != Nowhere => lastDenotation != null // (2) lastDenotation != null => lastSymbol != null @@ -2433,7 +2500,7 @@ object Types { val tparam = symbol val cls = tparam.owner val base = pre.baseType(cls) - base match { + base.stripped match { case AppliedType(_, allArgs) => var tparams = cls.typeParams var args = allArgs @@ -2613,7 +2680,7 @@ object Types { */ abstract case class TermRef(override val prefix: Type, private var myDesignator: Designator) - extends NamedType with SingletonType with ImplicitRef { + extends NamedType, ImplicitRef, CaptureRef { type ThisType = TermRef type ThisName = TermName @@ -2637,6 +2704,25 @@ object Types { def implicitName(using Context): TermName = name def underlyingRef: TermRef = this + + /** A term reference can be tracked if it is a local term ref to a value + * or a method term parameter. References to term parameters of classes + * cannot be tracked individually. + * They are subsumed in the capture sets of the enclosing class. + * TODO: ^^^ What avout call-by-name? + */ + def canBeTracked(using Context) = + ((prefix eq NoPrefix) + || symbol.is(ParamAccessor) && (prefix eq symbol.owner.thisType) + || symbol.hasAnnotation(defn.AbilityAnnot) + || isRootCapability + ) && !symbol.is(Method) + + override def isRootCapability(using Context): Boolean = + name == nme.CAPTURE_ROOT && symbol == defn.captureRoot + + override def normalizedRef(using Context): CaptureRef = + if canBeTracked then symbol.termRef else this } abstract case class TypeRef(override val prefix: Type, @@ -2772,7 +2858,7 @@ object Types { * Note: we do not pass a class symbol directly, because symbols * do not survive runs whereas typerefs do. */ - abstract case class ThisType(tref: TypeRef) extends CachedProxyType with SingletonType { + abstract case class ThisType(tref: TypeRef) extends CachedProxyType, CaptureRef { def cls(using Context): ClassSymbol = tref.stableInRunSymbol match { case cls: ClassSymbol => cls case _ if ctx.mode.is(Mode.Interactive) => defn.AnyClass // was observed to happen in IDE mode @@ -2786,6 +2872,8 @@ object Types { // can happen in IDE if `cls` is stale } + def canBeTracked(using Context) = true + override def computeHash(bs: Binders): Int = doHash(bs, tref) override def eql(that: Type): Boolean = that match { @@ -3612,9 +3700,17 @@ object Types { case tp: AppliedType => tp.fold(status, compute(_, _, theAcc)) case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional) case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps - case AnnotatedType(parent, ann) => - if ann.refersToParamOf(thisLambdaType) then TrueDeps - else compute(status, parent, theAcc) + case tp: AnnotatedType => + tp match + case CapturingType(parent, refs, _) => + (compute(status, parent, theAcc) /: refs.elems) { + (s, ref) => ref match + case tp: TermParamRef if tp.binder eq thisLambdaType => combine(s, CaptureDeps) + case _ => s + } + case _ => + if tp.annot.refersToParamOf(thisLambdaType) then TrueDeps + else compute(status, tp.parent, theAcc) case _: ThisType | _: BoundType | NoPrefix => status case _ => (if theAcc != null then theAcc else DepAcc()).foldOver(status, tp) @@ -3653,29 +3749,52 @@ object Types { /** Does result type contain references to parameters of this method type, * which cannot be eliminated by de-aliasing? */ - def isResultDependent(using Context): Boolean = dependencyStatus == TrueDeps + def isResultDependent(using Context): Boolean = + dependencyStatus == TrueDeps || dependencyStatus == CaptureDeps /** Does one of the parameter types contain references to earlier parameters * of this method type which cannot be eliminated by de-aliasing? */ def isParamDependent(using Context): Boolean = paramDependencyStatus == TrueDeps + /** Is there either a true or false type dependency, or does the result + * type capture a parameter? + */ + def isCaptureDependent(using Context) = dependencyStatus == CaptureDeps + def newParamRef(n: Int): TermParamRef = new TermParamRefImpl(this, n) /** The least supertype of `resultType` that does not contain parameter dependencies */ def nonDependentResultApprox(using Context): Type = - if (isResultDependent) { + if isResultDependent then val dropDependencies = new ApproximatingTypeMap { def apply(tp: Type) = tp match { case tp @ TermParamRef(`thisLambdaType`, _) => range(defn.NothingType, atVariance(1)(apply(tp.underlying))) + case CapturingType(parent, refs, boxed) => + val parent1 = this(parent) + val elems1 = refs.elems.filter { + case tp @ TermParamRef(`thisLambdaType`, _) => false + case _ => true + } + if elems1.size == refs.elems.size then + derivedCapturingType(tp, parent1, refs) + else + range( + CapturingType(parent1, CaptureSet(elems1), boxed), + CapturingType(parent1, CaptureSet.universal, boxed)) case AnnotatedType(parent, ann) if ann.refersToParamOf(thisLambdaType) => - mapOver(parent) + val parent1 = mapOver(parent) + if ann.symbol == defn.RetainsAnnot then + range( + AnnotatedType(parent1, CaptureSet.empty.toRegularAnnotation), + AnnotatedType(parent1, CaptureSet.universal.toRegularAnnotation)) + else + parent1 case _ => mapOver(tp) } } dropDependencies(resultType) - } else resultType } @@ -4046,9 +4165,10 @@ object Types { final val Unknown: DependencyStatus = 0 // not yet computed final val NoDeps: DependencyStatus = 1 // no dependent parameters found final val FalseDeps: DependencyStatus = 2 // all dependent parameters are prefixes of non-depended alias types - final val TrueDeps: DependencyStatus = 3 // some truly dependent parameters exist - final val StatusMask: DependencyStatus = 3 // the bits indicating actual dependency status - final val Provisional: DependencyStatus = 4 // set if dependency status can still change due to type variable instantiations + final val CaptureDeps: DependencyStatus = 3 + final val TrueDeps: DependencyStatus = 4 // some truly dependent parameters exist + final val StatusMask: DependencyStatus = 7 // the bits indicating actual dependency status + final val Provisional: DependencyStatus = 8 // set if dependency status can still change due to type variable instantiations } // ----- Type application: LambdaParam, AppliedType --------------------- @@ -4370,8 +4490,9 @@ object Types { /** Only created in `binder.paramRefs`. Use `binder.paramRefs(paramNum)` to * refer to `TermParamRef(binder, paramNum)`. */ - abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef with SingletonType { + abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef, CaptureRef { type BT = TermLambda + def canBeTracked(using Context) = true def kindString: String = "Term" def copyBoundType(bt: BT): Type = bt.paramRefs(paramNum) } @@ -5040,7 +5161,7 @@ object Types { // ----- Annotated and Import types ----------------------------------------------- /** An annotated type tpe @ annot */ - abstract case class AnnotatedType(parent: Type, annot: Annotation) extends CachedProxyType with ValueType { + abstract case class AnnotatedType(parent: Type, annot: Annotation) extends CachedProxyType, ValueType { override def underlying(using Context): Type = parent @@ -5069,16 +5190,16 @@ object Types { // equals comes from case class; no matching override is needed override def computeHash(bs: Binders): Int = - doHash(bs, System.identityHashCode(annot), parent) + doHash(bs, annot.hash, parent) override def hashIsStable: Boolean = parent.hashIsStable override def eql(that: Type): Boolean = that match - case that: AnnotatedType => (parent eq that.parent) && (annot eq that.annot) + case that: AnnotatedType => (parent eq that.parent) && (annot eql that.annot) case _ => false override def iso(that: Any, bs: BinderPairs): Boolean = that match - case that: AnnotatedType => parent.equals(that.parent, bs) && (annot eq that.annot) + case that: AnnotatedType => parent.equals(that.parent, bs) && (annot eql that.annot) case _ => false } @@ -5089,6 +5210,7 @@ object Types { annots.foldLeft(underlying)(apply(_, _)) def apply(parent: Type, annot: Annotation)(using Context): AnnotatedType = unique(CachedAnnotatedType(parent, annot)) + end AnnotatedType // Special type objects and classes ----------------------------------------------------- @@ -5308,7 +5430,7 @@ object Types { /** Common base class of TypeMap and TypeAccumulator */ abstract class VariantTraversal: - protected[core] var variance: Int = 1 + protected[dotc] var variance: Int = 1 inline protected def atVariance[T](v: Int)(op: => T): T = { val saved = variance @@ -5334,6 +5456,24 @@ object Types { } end VariantTraversal + /** A supertrait for some typemaps that are bijections. Used for capture checking + * BiTypeMaps should map capture references to capture references. + */ + trait BiTypeMap extends TypeMap: + thisMap => + def inverse(tp: Type): Type + + def inverseTypeMap(using Context) = new BiTypeMap: + def apply(tp: Type) = thisMap.inverse(tp) + def inverse(tp: Type) = thisMap.apply(tp) + + def forward(ref: CaptureRef): CaptureRef = this(ref) match + case result: CaptureRef if result.canBeTracked => result + + def backward(ref: CaptureRef): CaptureRef = inverse(ref) match + case result: CaptureRef if result.canBeTracked => result + end BiTypeMap + abstract class TypeMap(implicit protected var mapCtx: Context) extends VariantTraversal with (Type => Type) { thisMap => @@ -5361,6 +5501,8 @@ object Types { tp.derivedMatchType(bound, scrutinee, cases) protected def derivedAnnotatedType(tp: AnnotatedType, underlying: Type, annot: Annotation): Type = tp.derivedAnnotatedType(underlying, annot) + protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type = + tp.derivedCapturingType(parent, refs) protected def derivedWildcardType(tp: WildcardType, bounds: Type): Type = tp.derivedWildcardType(bounds) protected def derivedSkolemType(tp: SkolemType, info: Type): Type = @@ -5396,6 +5538,12 @@ object Types { def isRange(tp: Type): Boolean = tp.isInstanceOf[Range] + protected def mapCapturingType(tp: Type, parent: Type, refs: CaptureSet, v: Int): Type = + val saved = variance + variance = v + try derivedCapturingType(tp, this(parent), refs.map(this)) + finally variance = saved + /** Map this function over given type */ def mapOver(tp: Type): Type = { record(s"TypeMap mapOver ${getClass}") @@ -5437,6 +5585,9 @@ object Types { case tp: ExprType => derivedExprType(tp, this(tp.resultType)) + case CapturingType(parent, refs, _) => + mapCapturingType(tp, parent, refs, variance) + case tp @ AnnotatedType(underlying, annot) => val underlying1 = this(underlying) val annot1 = annot.mapWith(this) @@ -5757,6 +5908,13 @@ object Types { if (underlying.isExactlyNothing) underlying else tp.derivedAnnotatedType(underlying, annot) } + override protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type = + parent match // ^^^ handle ranges in capture sets as well + case Range(lo, hi) => + range(derivedCapturingType(tp, lo, refs), derivedCapturingType(tp, hi, refs)) + case _ => + tp.derivedCapturingType(parent, refs) + override protected def derivedWildcardType(tp: WildcardType, bounds: Type): WildcardType = tp.derivedWildcardType(rangeToBounds(bounds)) @@ -5796,6 +5954,12 @@ object Types { tp.derivedLambdaType(tp.paramNames, formals, restpe) } + override def mapCapturingType(tp: Type, parent: Type, refs: CaptureSet, v: Int): Type = + if v == 0 then + range(mapCapturingType(tp, parent, refs, -1), mapCapturingType(tp, parent, refs, 1)) + else + super.mapCapturingType(tp, parent, refs, v) + protected def reapply(tp: Type): Type = apply(tp) } @@ -5893,6 +6057,9 @@ object Types { val x2 = atVariance(0)(this(x1, tp.scrutinee)) foldOver(x2, tp.cases) + case CapturingType(parent, refs, _) => + (this(x, parent) /: refs.elems)(this) + case AnnotatedType(underlying, annot) => this(applyToAnnot(x, annot), underlying) diff --git a/compiler/src/dotty/tools/dotc/core/Variances.scala b/compiler/src/dotty/tools/dotc/core/Variances.scala index 122c7a10e4b7..44dda6b0077e 100644 --- a/compiler/src/dotty/tools/dotc/core/Variances.scala +++ b/compiler/src/dotty/tools/dotc/core/Variances.scala @@ -4,6 +4,7 @@ package core import Types._, Contexts._, Flags._, Symbols._, Annotations._ import TypeApplications.TypeParamInfo import Decorators._ +import cc.CapturingType object Variances { @@ -99,6 +100,8 @@ object Variances { v } varianceInArgs(varianceInType(tycon)(tparam), args, tycon.typeParams) + case CapturingType(tp, _, _) => + varianceInType(tp)(tparam) case AnnotatedType(tp, annot) => varianceInType(tp)(tparam) & varianceInAnnot(annot)(tparam) case AndType(tp1, tp2) => diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index e4a5c0ae8c6d..bbc18cd3f5ff 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -821,7 +821,7 @@ class TreeUnpickler(reader: TastyReader, def TypeDef(rhs: Tree) = ta.assignType(untpd.TypeDef(sym.name.asTypeName, rhs), sym) - def ta = ctx.typeAssigner + def ta = ctx.typeAssigner val name = readName() pickling.println(s"reading def of $name at $start") @@ -1263,11 +1263,9 @@ class TreeUnpickler(reader: TastyReader, // types. This came up in #137 of collection strawman. val tycon = readTpt() val args = until(end)(readTpt()) - val ownType = - if (tycon.symbol == defn.andType) AndType(args(0).tpe, args(1).tpe) - else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe, soft = false) - else tycon.tpe.safeAppliedTo(args.tpes) - untpd.AppliedTypeTree(tycon, args).withType(ownType) + val tree = untpd.AppliedTypeTree(tycon, args) + val ownType = ctx.typeAssigner.processAppliedType(tree, tycon.tpe.safeAppliedTo(args.tpes)) + tree.withType(ownType) case ANNOTATEDtpt => Annotated(readTpt(), readTerm()) case LAMBDAtpt => diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 27c1dead4482..6ee299d34922 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -890,6 +890,24 @@ object Parsers { } } + def followingIsCaptureSet(): Boolean = + val lookahead = in.LookaheadScanner() + def recur(): Boolean = + (lookahead.isIdent || lookahead.token == THIS) && { + lookahead.nextToken() + if lookahead.token == COMMA then + lookahead.nextToken() + recur() + else + lookahead.token == RBRACE && { + lookahead.nextToken() + canStartInfixTypeTokens.contains(lookahead.token) + || lookahead.token == LBRACKET + } + } + lookahead.nextToken() + recur() + /* --------- OPERAND/OPERATOR STACK --------------------------------------- */ var opStack: List[OpInfo] = Nil @@ -1330,17 +1348,25 @@ object Parsers { case _ => false } + /** CaptureRef ::= ident | `this` + */ + def captureRef(): Tree = + if in.token == THIS then simpleRef() else termIdent() + /** Type ::= FunType * | HkTypeParamClause ‘=>>’ Type * | FunParamClause ‘=>>’ Type * | MatchType * | InfixType + * | CaptureSet Type * FunType ::= (MonoFunType | PolyFunType) * MonoFunType ::= FunTypeArgs (‘=>’ | ‘?=>’) Type * PolyFunType ::= HKTypeParamClause '=>' Type * FunTypeArgs ::= InfixType * | `(' [ [ ‘[using]’ ‘['erased'] FunArgType {`,' FunArgType } ] `)' * | '(' [ ‘[using]’ ‘['erased'] TypedFunParam {',' TypedFunParam } ')' + * CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` + * CaptureRef ::= Ident */ def typ(): Tree = { val start = in.offset @@ -1446,6 +1472,10 @@ object Parsers { } else { accept(TLARROW); typ() } } + else if in.token == LBRACE && followingIsCaptureSet() then + val refs = inBraces { commaSeparated(captureRef) } + val t = typ() + CapturingTypeTree(refs, t) else if (in.token == INDENT) enclosed(INDENT, typ()) else infixType() @@ -1514,7 +1544,7 @@ object Parsers { def infixType(): Tree = infixTypeRest(refinedType()) def infixTypeRest(t: Tree): Tree = - infixOps(t, canStartTypeTokens, refinedTypeFn, Location.ElseWhere, + infixOps(t, canStartInfixTypeTokens, refinedTypeFn, Location.ElseWhere, isType = true, isOperator = !followingIsVararg()) @@ -3154,7 +3184,7 @@ object Parsers { ImportSelector( atSpan(in.skipToken()) { Ident(nme.EMPTY) }, bound = - if canStartTypeTokens.contains(in.token) then rejectWildcardType(infixType()) + if canStartInfixTypeTokens.contains(in.token) then rejectWildcardType(infixType()) else EmptyTree) /** id [‘as’ (id | ‘_’) */ diff --git a/compiler/src/dotty/tools/dotc/parsing/Tokens.scala b/compiler/src/dotty/tools/dotc/parsing/Tokens.scala index cba07a6e5a34..7fadf341905d 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Tokens.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Tokens.scala @@ -230,8 +230,8 @@ object Tokens extends TokensCommon { final val canStartExprTokens2: TokenSet = canStartExprTokens3 | BitSet(DO) - final val canStartTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet( - THIS, SUPER, USCORE, LPAREN, AT) + final val canStartInfixTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet( + THIS, SUPER, USCORE, LPAREN, LBRACE, AT) final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT) diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 3a3fca5e7f90..054fe62682dd 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -14,13 +14,16 @@ import Variances.varianceSign import util.SourcePosition import scala.util.control.NonFatal import scala.annotation.switch +import config.Config +import cc.{CapturingType, CaptureSet} class PlainPrinter(_ctx: Context) extends Printer { + /** The context of all public methods in Printer and subclasses. * Overridden in RefinedPrinter. */ - protected def curCtx: Context = _ctx.addMode(Mode.Printing) - protected given [DummyToEnforceDef]: Context = curCtx + def printerContext: Context = _ctx.addMode(Mode.Printing) + protected given [DummyToEnforceDef]: Context = printerContext protected def printDebug = ctx.settings.YprintDebug.value @@ -186,6 +189,22 @@ class PlainPrinter(_ctx: Context) extends Printer { keywordStr(" match ") ~ "{" ~ casesText ~ "}" ~ (" <: " ~ toText(bound) provided !bound.isAny) }.close + case CapturingType(parent, refs, boxed) => + def box = Str("box ") provided boxed + if printDebug && !refs.isConst then + changePrec(GlobalPrec)(box ~ s"$refs " ~ toText(parent)) + else if ctx.settings.YccDebug.value then + changePrec(GlobalPrec)(box ~ refs.toText(this) ~ " " ~ toText(parent)) + else if !refs.isConst && refs.elems.isEmpty then + changePrec(GlobalPrec)("?" ~ " " ~ toText(parent)) + else if Config.printCaptureSetsAsPrefix then + changePrec(GlobalPrec)( + box ~ "{" + ~ Text(refs.elems.toList.map(toTextCaptureRef), ", ") + ~ "} " + ~ toText(parent)) + else + changePrec(InfixPrec)(toText(parent) ~ " retains " ~ box ~ toText(refs.toRetainsTypeArg)) case tp: PreviousErrorType if ctx.settings.XprintTypes.value => "" // do not print previously reported error message because they may try to print this error type again recuresevely case tp: ErrorType => @@ -273,7 +292,7 @@ class PlainPrinter(_ctx: Context) extends Printer { /** If -uniqid is set, the unique id of symbol, after a # */ protected def idString(sym: Symbol): String = - if (showUniqueIds || Printer.debugPrintUnique) "#" + sym.id else "" + if showUniqueIds then "#" + sym.id else "" def nameString(sym: Symbol): String = simpleNameString(sym) + idString(sym) // + "<" + (if (sym.exists) sym.owner else "") + ">" @@ -313,7 +332,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case tp @ ConstantType(value) => toText(value) case pref: TermParamRef => - nameString(pref.binder.paramNames(pref.paramNum)) + nameString(pref.binder.paramNames(pref.paramNum)) ~ lambdaHash(pref.binder) case tp: RecThis => val idx = openRecs.reverse.indexOf(tp.binder) if (idx >= 0) selfRecName(idx + 1) @@ -334,6 +353,11 @@ class PlainPrinter(_ctx: Context) extends Printer { } } + def toTextCaptureRef(tp: Type): Text = + homogenize(tp) match + case tp: SingletonType => toTextRef(tp) + case _ => toText(tp) + protected def isOmittablePrefix(sym: Symbol): Boolean = defn.unqualifiedOwnerTypes.exists(_.symbol == sym) || isEmptyPrefix(sym) diff --git a/compiler/src/dotty/tools/dotc/printing/Printer.scala b/compiler/src/dotty/tools/dotc/printing/Printer.scala index 550bdb94af4f..b883b6be805b 100644 --- a/compiler/src/dotty/tools/dotc/printing/Printer.scala +++ b/compiler/src/dotty/tools/dotc/printing/Printer.scala @@ -6,7 +6,7 @@ import core._ import Texts._, ast.Trees._ import Types.{Type, SingletonType, LambdaParam}, Symbols.Symbol, Scopes.Scope, Constants.Constant, - Names.Name, Denotations._, Annotations.Annotation + Names.Name, Denotations._, Annotations.Annotation, Contexts.Context import typer.Implicits.SearchResult import util.SourcePosition import typer.ImportInfo @@ -104,6 +104,9 @@ abstract class Printer { /** Textual representation of a prefix of some reference, ending in `.` or `#` */ def toTextPrefix(tp: Type): Text + /** Textual representation of a reference in a capture set */ + def toTextCaptureRef(tp: Type): Text + /** Textual representation of symbol's declaration */ def dclText(sym: Symbol): Text @@ -182,6 +185,9 @@ abstract class Printer { /** A plain printer without any embellishments */ def plain: Printer + + /** The context in which this printer operates */ + def printerContext: Context } object Printer { diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index ef11ec9434ae..c7606acf3dd6 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -34,11 +34,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { /** A stack of enclosing DefDef, TypeDef, or ClassDef, or ModuleDefs nodes */ private var enclosingDef: untpd.Tree = untpd.EmptyTree - private var myCtx: Context = super.curCtx + private var myCtx: Context = super.printerContext private var printPos = ctx.settings.YprintPos.value private val printLines = ctx.settings.printLines.value - override protected def curCtx: Context = myCtx + override def printerContext: Context = myCtx def withEnclosingDef(enclDef: Tree[? >: Untyped])(op: => Text): Text = { val savedCtx = myCtx @@ -164,10 +164,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { changePrec(GlobalPrec) { "(" ~ keywordText("erased ").provided(info.isErasedMethod) - ~ ( if info.isParamDependent || info.isResultDependent - then paramsText(info) - else argsText(info.paramInfos) - ) + ~ paramsText(info) ~ ") " ~ arrow(info.isImplicitMethod) ~ " " @@ -245,9 +242,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty => // don't eta contract if the application would be printed specially toText(tycon) - case tp: RefinedType - if (defn.isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass)) - && !printDebug => + case tp: RefinedType if defn.isFunctionOrPolyType(tp) && !printDebug => toTextMethodAsFunction(tp.refinedInfo) case tp: TypeRef => if (tp.symbol.isAnonymousClass && !showUniqueIds) @@ -703,6 +698,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { val (prefix, postfix) = if isTermHole then ("{{{ ", " }}}") else ("[[[ ", " ]]]") val argsText = toTextGlobal(args, ", ") prefix ~~ idx.toString ~~ "|" ~~ argsText ~~ postfix + case CapturingTypeTree(refs, parent) => + changePrec(GlobalPrec)("{" ~ Text(refs.map(toText), ", ") ~ "} " ~ toText(parent)) case _ => tree.fallbackToText(this) } @@ -789,9 +786,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if mdef.hasType then Modifiers(mdef.symbol) else mdef.rawMods private def Modifiers(sym: Symbol): Modifiers = untpd.Modifiers( - sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags), + sym.flagsUNSAFE & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags), if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY, - sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)) + sym.annotationsUNSAFE.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)) protected def dropAnnotForModText(sym: Symbol): Boolean = sym == defn.BodyAnnot @@ -988,13 +985,13 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { else if (suppressKw) PrintableFlags(isType) &~ Private else PrintableFlags(isType) if (homogenizedView && mods.flags.isTypeFlags) flagMask &~= GivenOrImplicit // drop implicit/given from classes - val rawFlags = if (sym.exists) sym.flags else mods.flags + val rawFlags = if (sym.exists) sym.flagsUNSAFE else mods.flags if (rawFlags.is(Param)) flagMask = flagMask &~ Given &~ Erased val flags = rawFlags & flagMask var flagsText = toTextFlags(sym, flags) val annotTexts = if sym.exists then - sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText) + sym.annotationsUNSAFE.filterNot(ann => dropAnnotForModText(ann.symbol)).map(toText) else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol)).map(annotText(NoSymbol, _)) Text(annotTexts, " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw) diff --git a/compiler/src/dotty/tools/dotc/reporting/messages.scala b/compiler/src/dotty/tools/dotc/reporting/messages.scala index 063ba96410c8..eac14bda8d4b 100644 --- a/compiler/src/dotty/tools/dotc/reporting/messages.scala +++ b/compiler/src/dotty/tools/dotc/reporting/messages.scala @@ -286,7 +286,6 @@ import transform.SymUtils._ val treeStr = inTree.map(x => s"\nTree: ${x.show}").getOrElse("") treeStr + "\n" + super.explain - end TypeMismatch class NotAMember(site: Type, val name: Name, selected: String, addendum: => String = "")(using Context) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala index 6e198bbeada9..595094f3edd6 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala @@ -175,6 +175,7 @@ private class ExtractAPICollector(using Context) extends ThunkHolder { private val byNameMarker = marker("ByName") private val matchMarker = marker("Match") private val superMarker = marker("Super") + private val retainsMarker = marker("Retains") /** Extract the API representation of a source file */ def apiSource(tree: Tree): Seq[api.ClassLike] = { diff --git a/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala b/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala new file mode 100644 index 000000000000..9a287b2dd1d9 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/EmptyPhase.scala @@ -0,0 +1,19 @@ +package dotty.tools.dotc +package transform + +import core.* +import Contexts.Context +import Phases.Phase + +/** A phase that can be inserted directly after a phase that cannot + * be checked, to enable a -Ycheck as soon as possible afterwards + */ +class EmptyPhase extends Phase: + + def phaseName: String = "dummy" + + override def isEnabled(using Context) = prev.isEnabled + + override def run(using Context) = () + +end EmptyPhase \ No newline at end of file diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index c7e02a5c6837..048395e8dffa 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -289,7 +289,11 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase tree.fun, tree.args.mapConserve(arg => if (methType.isImplicitMethod && arg.span.isSynthetic) - PruneErasedDefs.trivialErasedTree(arg) + arg match + case _: RefTree | _: Apply | _: TypeApply if arg.symbol.is(Erased) => + dropInlines.transform(arg) + case _ => + PruneErasedDefs.trivialErasedTree(arg) else dropInlines.transform(arg))) else tree diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index 76f89cb65757..a61b736a9cc1 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -14,15 +14,22 @@ import typer.ErrorReporting.err import typer.ProtoTypes.* import typer.TypeAssigner.seqLitType import typer.ConstFold +import NamerOps.methodType import config.Printers.recheckr import util.Property import StdNames.nme import reporting.trace +object Recheck: + + /** Attachment key for rechecked types of TypeTrees */ + private val RecheckedType = Property.Key[Type] + abstract class Recheck extends Phase, IdentityDenotTransformer: thisPhase => import ast.tpd.* + import Recheck.* def preRecheckPhase = this.prev.asInstanceOf[PreRecheck] @@ -36,12 +43,17 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: override def widenSkolems = true def run(using Context): Unit = - newRechecker().checkUnit(ctx.compilationUnit) + val rechecker = newRechecker() + rechecker.transformTypes.traverse(ctx.compilationUnit.tpdTree) + rechecker.checkUnit(ctx.compilationUnit) def newRechecker()(using Context): Rechecker class Rechecker(ictx: Context): - val ta = ictx.typeAssigner + private val ta = ictx.typeAssigner + private val keepTypes = inContext(ictx) { + ictx.settings.Xprint.value.containsPhase(thisPhase) + } extension (sym: Symbol) def updateInfo(newInfo: Type)(using Context): Unit = if sym.info ne newInfo then @@ -53,23 +65,102 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: else sym.flags ).installAfter(preRecheckPhase) - /** Hook to be overridden */ - protected def reinfer(tp: Type)(using Context): Type = tp - - def reinferResult(info: Type)(using Context): Type = info match - case info: MethodOrPoly => - info.derivedLambdaType(resType = reinferResult(info.resultType)) - case _ => - reinfer(info) + extension (tpe: Type) def rememberFor(tree: Tree)(using Context): Unit = + if (tpe ne tree.tpe) && !tree.hasAttachment(RecheckedType) then + tree.putAttachment(RecheckedType, tpe) + + def knownType(tree: Tree) = + tree.attachmentOrElse(RecheckedType, tree.tpe) + + def isUpdated(sym: Symbol)(using Context) = + val symd = sym.denot + symd.validFor.firstPhaseId == thisPhase.id && (sym.originDenotation ne symd) + + def transformType(tp: Type, inferred: Boolean)(using Context): Type = tp + + object transformTypes extends TreeTraverser: + + // Substitute parameter symbols in `from` to paramRefs in corresponding + // method or poly types `to`. We use a single BiTypeMap to do everything. + class SubstParams(from: List[List[Symbol]], to: List[LambdaType])(using Context) + extends DeepTypeMap, BiTypeMap: + + def apply(t: Type): Type = t match + case t: NamedType => + val sym = t.symbol + def outer(froms: List[List[Symbol]], tos: List[LambdaType]): Type = + def inner(from: List[Symbol], to: List[ParamRef]): Type = + if from.isEmpty then outer(froms.tail, tos.tail) + else if sym eq from.head then to.head + else inner(from.tail, to.tail) + if tos.isEmpty then t + else inner(froms.head, tos.head.paramRefs) + outer(from, to) + case _ => + mapOver(t) + + def inverse(t: Type): Type = t match + case t: ParamRef => + def recur(from: List[LambdaType], to: List[List[Symbol]]): Type = + if from.isEmpty then t + else if t.binder eq from.head then to.head(t.paramNum).namedType + else recur(from.tail, to.tail) + recur(to, from) + case _ => + mapOver(t) + end SubstParams + + def traverse(tree: Tree)(using Context) = + traverseChildren(tree) + tree match - def enterDef(stat: Tree)(using Context): Unit = - val sym = stat.symbol - stat match - case stat: ValOrDefDef if stat.tpt.isInstanceOf[InferredTypeTree] => - sym.updateInfo(reinferResult(sym.info)) - case stat: Bind => - sym.updateInfo(reinferResult(sym.info)) - case _ => + case tree: TypeTree => + transformType(tree.tpe, tree.isInstanceOf[InferredTypeTree]).rememberFor(tree) + case tree: ValOrDefDef => + val sym = tree.symbol + + // replace an existing symbol info with inferred types + def integrateRT( + info: Type, // symbol info to replace + psymss: List[List[Symbol]], // the local (type and trem) parameter symbols corresponding to `info` + prevPsymss: List[List[Symbol]], // the local parameter symbols seen previously in reverse order + prevLambdas: List[LambdaType] // the outer method and polytypes generated previously in reverse order + ): Type = + info match + case mt: MethodOrPoly => + val psyms = psymss.head + mt.companion(mt.paramNames)( + mt1 => + if !psyms.exists(isUpdated) && !mt.isParamDependent && prevLambdas.isEmpty then + mt.paramInfos + else + val subst = SubstParams(psyms :: prevPsymss, mt1 :: prevLambdas) + psyms.map(psym => subst(psym.info).asInstanceOf[mt.PInfo]), + mt1 => + integrateRT(mt.resType, psymss.tail, psyms :: prevPsymss, mt1 :: prevLambdas) + ) + case info: ExprType => + info.derivedExprType(resType = + integrateRT(info.resType, psymss, prevPsymss, prevLambdas)) + case _ => + val restp = knownType(tree.tpt) + if prevLambdas.isEmpty then restp + else SubstParams(prevPsymss, prevLambdas)(restp) + + if tree.tpt.hasAttachment(RecheckedType) && !sym.isConstructor then + val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil) + .showing(i"update info $sym: ${sym.info} --> $result", recheckr) + if newInfo ne sym.info then + val completer = new LazyType: + def complete(denot: SymDenotation)(using Context) = + denot.info = newInfo + recheckDef(tree, sym) + sym.updateInfo(completer) + case tree: Bind => + val sym = tree.symbol + sym.updateInfo(transformType(sym.info, inferred = true)) + case _ => + end transformTypes def constFold(tree: Tree, tp: Type)(using Context): Type = val tree1 = tree.withType(tp) @@ -90,10 +181,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: excluded = if tree.symbol.is(Private) then EmptyFlags else Private ).suchThat(tree.symbol ==) constFold(tree, qualType.select(name, mbr)) + //.showing(i"recheck select $qualType . $name : ${mbr.symbol.info} = $result") def recheckBind(tree: Bind, pt: Type)(using Context): Type = tree match case Bind(name, body) => - enterDef(tree) recheck(body, pt) val sym = tree.symbol if sym.isType then sym.typeRef else sym.info @@ -104,16 +195,13 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: val exprType = recheck(expr, defn.UnitType) bindType - def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type = - if !tree.rhs.isEmpty then recheck(tree.rhs, tree.symbol.info) - sym.termRef + def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = + if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info) - def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Type = - tree.paramss.foreach(_.foreach(enterDef)) - val rhsCtx = linkConstructorParams(sym) + def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = + val rhsCtx = linkConstructorParams(sym).withOwner(sym) if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then - recheck(tree.rhs, tree.symbol.localReturnType)(using rhsCtx) - sym.termRef + inContext(rhsCtx) { recheck(tree.rhs, recheck(tree.tpt)) } def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type = recheck(tree.rhs) @@ -134,6 +222,11 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case _ => mapOver(t) formals.mapConserve(tm) + /** Hook for method type instantiation + */ + protected def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type = + mt.instantiate(argTypes) + def recheckApply(tree: Apply, pt: Type)(using Context): Type = recheck(tree.fun).widen match case fntpe: MethodType => @@ -153,7 +246,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: assert(formals.isEmpty) Nil val argTypes = recheckArgs(tree.args, formals, fntpe.paramRefs) - constFold(tree, fntpe.instantiate(argTypes)) + constFold(tree, instantiate(fntpe, argTypes, tree.fun.symbol)) def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type = recheck(tree.fun).widen match @@ -174,7 +267,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckBlock(stats: List[Tree], expr: Tree, pt: Type)(using Context): Type = recheckStats(stats) - val exprType = recheck(expr, pt.dropIfProto) + val exprType = recheck(expr) + // The expected type `pt` is not propagated. Doing so would allow variables in the + // expected type to contain references to local symbols of the block, so the + // local symbols could escape that way. TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm)) def recheckBlock(tree: Block, pt: Type)(using Context): Type = @@ -195,10 +291,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckMatch(tree: Match, pt: Type)(using Context): Type = val selectorType = recheck(tree.selector) - val casesTypes = tree.cases.map(recheck(_, selectorType.widen, pt)) + val casesTypes = tree.cases.map(recheckCase(_, selectorType.widen, pt)) TypeComparer.lub(casesTypes) - def recheck(tree: CaseDef, selType: Type, pt: Type)(using Context): Type = + def recheckCase(tree: CaseDef, selType: Type, pt: Type)(using Context): Type = recheck(tree.pat, selType) recheck(tree.guard, defn.BooleanType) recheck(tree.body, pt) @@ -214,7 +310,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: def recheckTry(tree: Try, pt: Type)(using Context): Type = val bodyType = recheck(tree.expr, pt) - val casesTypes = tree.cases.map(recheck(_, defn.ThrowableType, pt)) + val casesTypes = tree.cases.map(recheckCase(_, defn.ThrowableType, pt)) val finalizerType = recheck(tree.finalizer, defn.UnitType) TypeComparer.lub(bodyType :: casesTypes) @@ -227,9 +323,8 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: val elemTypes = tree.elems.map(recheck(_, elemProto)) seqLitType(tree, TypeComparer.lub(declaredElemType :: elemTypes)) - def recheckTypeTree(tree: TypeTree)(using Context): Type = tree match - case tree: InferredTypeTree => reinfer(tree.tpe) - case _ => tree.tpe + def recheckTypeTree(tree: TypeTree)(using Context): Type = + knownType(tree) def recheckAnnotated(tree: Annotated)(using Context): Type = tree.tpe match @@ -246,14 +341,20 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: NoType def recheckStats(stats: List[Tree])(using Context): Unit = - stats.foreach(enterDef) stats.foreach(recheck(_)) + def recheckDef(tree: ValOrDefDef, sym: Symbol)(using Context): Unit = + inContext(ctx.localContext(tree, sym)) { + tree match + case tree: ValDef => recheckValDef(tree, sym) + case tree: DefDef => recheckDefDef(tree, sym) + } + /** Recheck tree without adapting it, returning its new type. * @param tree the original tree * @param pt the expected result type */ - def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = trace(i"rechecking $tree with pt = $pt", recheckr, show = true) { + def recheckStart(tree: Tree, pt: Type = WildcardType)(using Context): Type = def recheckNamed(tree: NameTree, pt: Type)(using Context): Type = val sym = tree.symbol @@ -261,11 +362,12 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case tree: Ident => recheckIdent(tree) case tree: Select => recheckSelect(tree) case tree: Bind => recheckBind(tree, pt) - case tree: ValDef => + case tree: ValOrDefDef => if tree.isEmpty then NoType - else recheckValDef(tree, sym)(using ctx.localContext(tree, sym)) - case tree: DefDef => - recheckDefDef(tree, sym)(using ctx.localContext(tree, sym)) + else + if isUpdated(sym) then sym.ensureCompleted() + else recheckDef(tree, sym) + sym.termRef case tree: TypeDef => tree.rhs match case impl: Template => @@ -295,35 +397,61 @@ abstract class Recheck extends Phase, IdentityDenotTransformer: case tree: PackageDef => recheckPackageDef(tree) case tree: Thicket => defn.NothingType - try - val result = tree match - case tree: NameTree => recheckNamed(tree, pt) - case tree => recheckUnnamed(tree, pt) - checkConforms(result, pt, tree) - result - catch case ex: Exception => - println(i"error while rechecking $tree") - throw ex - } - end recheck + tree match + case tree: NameTree => recheckNamed(tree, pt) + case tree => recheckUnnamed(tree, pt) + end recheckStart + + def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type = + checkConforms(tpe, pt, tree) + if keepTypes then tpe.rememberFor(tree) + tpe + + def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = + trace(i"rechecking $tree with pt = $pt", recheckr, show = true) { + try recheckFinish(recheckStart(tree, pt), tree, pt) + catch case ex: Exception => + println(i"error while rechecking $tree") + throw ex + } + + private val debugSuccesses = false def checkConforms(tpe: Type, pt: Type, tree: Tree)(using Context): Unit = tree match - case _: DefTree | EmptyTree | _: TypeTree => + case _: DefTree | EmptyTree | _: TypeTree | _: Closure => + // Don't report closure nodes, since their span is a point; wait instead + // for enclosing block to preduce an error case _ => val actual = tpe.widenExpr val expected = pt.widenExpr + //println(i"check conforms $actual <:< $expected") val isCompatible = actual <:< expected || expected.isRepeatedParam && actual <:< expected.translateFromRepeated(toArray = tree.tpe.isRef(defn.ArrayClass)) if !isCompatible then - println(i"err at ${ctx.phase}") - err.typeMismatch(tree.withType(tpe), pt) + err.typeMismatch(tree.withType(tpe), expected) + else if debugSuccesses then + tree match + case _: Ident => + println(i"SUCCESS $tree:\n${TypeComparer.explained(_.isSubType(actual, expected))}") + case _ => def checkUnit(unit: CompilationUnit)(using Context): Unit = recheck(unit.tpdTree) end Rechecker + + override def show(tree: untpd.Tree)(using Context): String = + val addRecheckedTypes = new TreeMap: + override def transform(tree: Tree)(using Context): Tree = + val tree1 = super.transform(tree) + tree.getAttachment(RecheckedType) match + case Some(tpe) => tree1.withType(tpe) + case None => tree1 + atPhase(thisPhase) { + super.show(addRecheckedTypes.transform(tree.asInstanceOf[tpd.Tree])) + } end Recheck class TestRecheck extends Recheck: diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index 29fd1adb6688..044ea11eb27e 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -375,14 +375,14 @@ class TreeChecker extends Phase with SymTransformer { val tpe = tree.typeOpt // Polymorphic apply methods stay structural until Erasure - val isPolyFunctionApply = (tree.name eq nme.apply) && (tree.qualifier.typeOpt <:< defn.PolyFunctionType) + val isPolyFunctionApply = (tree.name eq nme.apply) && tree.qualifier.typeOpt.derivesFrom(defn.PolyFunctionClass) // Outer selects are pickled specially so don't require a symbol val isOuterSelect = tree.name.is(OuterSelectName) val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name) if !(tree.isType || isPolyFunctionApply || isOuterSelect || isPrimitiveArrayOp) then val denot = tree.denot assert(denot.exists, i"Selection $tree with type $tpe does not have a denotation") - assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol") + assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol, qualifier type = ${tree.qualifier.typeOpt}") val sym = tree.symbol val symIsFixed = tpe match { diff --git a/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala b/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala index 6be58352e6dc..26bea001d1eb 100644 --- a/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala +++ b/compiler/src/dotty/tools/dotc/transform/TryCatchPatterns.scala @@ -70,7 +70,7 @@ class TryCatchPatterns extends MiniPhase { case _ => isDefaultCase(cdef) } - private def isSimpleThrowable(tp: Type)(using Context): Boolean = tp.stripAnnots match { + private def isSimpleThrowable(tp: Type)(using Context): Boolean = tp.stripped match { case tp @ TypeRef(pre, _) => (pre == NoPrefix || pre.typeSymbol.isStatic) && // Does not require outer class check !tp.symbol.is(Flags.Trait) && // Traits not supported by JVM diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala index 8ffe2198c4d9..7c5d34126bd9 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala @@ -148,7 +148,7 @@ object TypeTestsCasts { } case AndType(tp1, tp2) => recur(X, tp1) && recur(X, tp2) case OrType(tp1, tp2) => recur(X, tp1) && recur(X, tp2) - case AnnotatedType(t, _) => recur(X, t) + case tp: AnnotatedType => recur(X, tp.parent) case _: RefinedType => false case _ => true }) @@ -217,7 +217,7 @@ object TypeTestsCasts { * can be true in some cases. Issues a warning or an error otherwise. */ def checkSensical(foundClasses: List[Symbol])(using Context): Boolean = - def exprType = i"type ${expr.tpe.widen.stripAnnots}" + def exprType = i"type ${expr.tpe.widen.stripped}" def check(foundCls: Symbol): Boolean = if (!isCheckable(foundCls)) true else if (!foundCls.derivesFrom(testCls)) { diff --git a/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala new file mode 100644 index 000000000000..1415016fea26 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala @@ -0,0 +1,468 @@ +package dotty.tools +package dotc +package cc + +import core._ +import Phases.*, DenotTransformers.*, SymDenotations.* +import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.* +import Types._ +import Symbols._ +import StdNames._ +import Decorators._ +import config.Printers.{capt, recheckr} +import ast.{tpd, untpd, Trees} +import NameKinds.{DocArtifactName, OuterSelectName, DefaultGetterName} +import Trees._ +import scala.util.control.NonFatal +import typer.ErrorReporting._ +import typer.RefChecks +import util.Spans.Span +import util.{SimpleIdentitySet, EqHashMap, SrcPos} +import util.Chars.* +import transform.* +import transform.SymUtils.* +import scala.collection.mutable +import reporting._ +import dotty.tools.backend.jvm.DottyBackendInterface.symExtensions +import CaptureSet.{CompareResult, withCaptureSetsExplained} + +object CheckCaptures: + import ast.tpd.* + + case class Env(owner: Symbol, captured: CaptureSet, isBoxed: Boolean, outer: Env): + def isOpen = !captured.isAlwaysEmpty && !isBoxed + + final class SubstParamsMap(from: BindingType, to: List[Type])(using Context) + extends ApproximatingTypeMap: + def apply(tp: Type): Type = tp match + case tp: ParamRef => + if tp.binder == from then to(tp.paramNum) else tp + case tp: NamedType => + if tp.prefix `eq` NoPrefix then tp + else tp.derivedSelect(apply(tp.prefix)) + case _: ThisType => + tp + case _ => + mapOver(tp) + + /** Check that a @retains annotation only mentions references that can be tracked + * This check is performed at Typer. + */ + def checkWellformed(ann: Tree)(using Context): Unit = + for elem <- retainedElems(ann) do + elem.tpe match + case ref: CaptureRef => + if !ref.canBeTracked then + report.error(em"$elem cannot be tracked since it is not a parameter or a local variable", elem.srcPos) + case tpe => + report.error(em"$tpe is not a legal type for a capture set", elem.srcPos) + + /** If `tp` is a capturing type, check that all references it mentions have non-empty + * capture sets. + * This check is performed after capture sets are computed in phase cc. + */ + def checkWellformedPost(tp: Type, pos: SrcPos)(using Context): Unit = tp match + case CapturingType(parent, refs, _) => + for ref <- refs.elems do + if ref.captureSetOfInfo.elems.isEmpty then + report.error(em"$ref cannot be tracked since its capture set is empty", pos) + else if parent.captureSet.accountsFor(ref) then + report.warning(em"redundant capture: $parent already accounts for $ref", pos) + case _ => + + def checkWellformedPost(ann: Tree)(using Context): Unit = + /** The lists `elems(i) :: prev.reerse :: elems(0),...,elems(i-1),elems(i+1),elems(n)` + * where `n == elems.length-1`, i <- 0..n`. + */ + def choices(prev: List[Tree], elems: List[Tree]): List[List[Tree]] = elems match + case Nil => Nil + case elem :: elems => + List(elem :: (prev reverse_::: elems)) ++ choices(elem :: prev, elems) + for case first :: others <- choices(Nil, retainedElems(ann)) do + val firstRef = first.toCaptureRef + val remaining = CaptureSet(others.map(_.toCaptureRef)*) + if remaining.accountsFor(firstRef) then + report.warning(em"redundant capture: $remaining already accounts for $firstRef", ann.srcPos) + + private inline val disallowGlobal = true + +class CheckCaptures extends Recheck: + thisPhase => + + import ast.tpd.* + import CheckCaptures.* + + def phaseName: String = "cc" + override def isEnabled(using Context) = ctx.settings.Ycc.value + + def newRechecker()(using Context) = CaptureChecker(ctx) + + override def run(using Context): Unit = + checkOverrides.traverse(ctx.compilationUnit.tpdTree) + super.run + + def checkOverrides = new TreeTraverser: + def traverse(t: Tree)(using Context) = + t match + case t: Template => + // ^^^ TODO: Can we avoid doing overrides checks twice? + // We need to do them here since only at this phase CaptureTypes are relevant + // But maybe we can then elide the check during the RefChecks phase if -Ycc is set? + RefChecks.checkAllOverrides(ctx.owner.asClass) + case _ => + traverseChildren(t) + + class CaptureChecker(ictx: Context) extends Rechecker(ictx): + import ast.tpd.* + + override def transformType(tp: Type, inferred: Boolean)(using Context): Type = + + def addInnerVars(tp: Type): Type = tp match + case tp @ AppliedType(tycon, args) => + tp.derivedAppliedType(tycon, args.map(addVars(_, boxed = true))) + case tp @ RefinedType(core, rname, rinfo) => + val rinfo1 = addVars(rinfo) + if defn.isFunctionType(tp) then + rinfo1.toFunctionType(isJava = false, alwaysDependent = true) + else + tp.derivedRefinedType(addInnerVars(core), rname, rinfo1) + case tp: MethodType => + tp.derivedLambdaType( + paramInfos = tp.paramInfos.mapConserve(addVars(_)), + resType = addVars(tp.resType)) + case tp: PolyType => + tp.derivedLambdaType( + resType = addVars(tp.resType)) + case tp: ExprType => + tp.derivedExprType(resType = addVars(tp.resType)) + case _ => + tp + + def addFunctionRefinements(tp: Type): Type = tp match + case tp @ AppliedType(tycon, args) => + if defn.isNonRefinedFunction(tp) then + MethodType.companion( + isContextual = defn.isContextFunctionClass(tycon.classSymbol), + isErased = defn.isErasedFunctionClass(tycon.classSymbol) + )(args.init, addFunctionRefinements(args.last)) + .toFunctionType(isJava = false, alwaysDependent = true) + .showing(i"add function refinement $tp --> $result", capt) + else + tp.derivedAppliedType(tycon, args.map(addFunctionRefinements(_))) + case tp @ RefinedType(core, rname, rinfo) if !defn.isFunctionType(tp) => + tp.derivedRefinedType( + addFunctionRefinements(core), rname, addFunctionRefinements(rinfo)) + case tp: MethodOrPoly => + tp.derivedLambdaType(resType = addFunctionRefinements(tp.resType)) + case tp: ExprType => + tp.derivedExprType(resType = addFunctionRefinements(tp.resType)) + case _ => + tp + + /** Refine a possibly applied class type C where the class has tracked parameters + * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n } + * where CV_1, ..., CV_n are fresh capture sets. + */ + def addCaptureRefinements(tp: Type): Type = tp.stripped match + case _: TypeRef | _: AppliedType if tp.typeSymbol.isClass => + val cls = tp.typeSymbol.asClass + cls.paramGetters.foldLeft(tp) { (core, getter) => + if getter.termRef.isTracked then + val getterType = tp.memberInfo(getter).strippedDealias + RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false)) + .showing(i"add capture refinement $tp --> $result", capt) + else + core + } + case _ => + tp + + def addVars(tp: Type, boxed: Boolean = false): Type = + var tp1 = addInnerVars(tp) + val tp2 = addCaptureRefinements(tp1) + if tp1.canHaveInferredCapture + then CapturingType(tp2, CaptureSet.Var(), boxed) + else tp2 + + if inferred then + val cleanup = new TypeMap: + def apply(t: Type) = t match + case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot => + apply(parent) + case _ => + mapOver(t) + addVars(addFunctionRefinements(cleanup(tp))) + .showing(i"reinfer $tp --> $result", capt) + else + val addBoxes = new TypeTraverser: + def setBoxed(t: Type) = t match + case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot => + annot.tree.setBoxedCapturing() + case _ => + + def traverse(t: Type) = + t match + case AppliedType(tycon, args) if !defn.isNonRefinedFunction(t) => + args.foreach(setBoxed) + case TypeBounds(lo, hi) => + setBoxed(lo); setBoxed(hi) + case _ => + traverseChildren(t) + end addBoxes + + addBoxes.traverse(tp) + tp + end transformType + + private def interpolator(using Context) = new TypeTraverser: + override def traverse(t: Type) = + t match + case CapturingType(parent, refs: CaptureSet.Var, _) => + if variance < 0 then capt.println(i"solving $t") + refs.solve(variance) + traverse(parent) + case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionOrPolyType(t) => + traverse(rinfo) + case tp: TypeVar => + case tp: TypeRef => + traverse(tp.prefix) + case _ => + traverseChildren(t) + + private def interpolateVarsIn(tpt: Tree)(using Context): Unit = + if tpt.isInstanceOf[InferredTypeTree] then + interpolator.traverse(knownType(tpt)) + .showing(i"solved vars in ${knownType(tpt)}", capt) + + private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, false, null) + + private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap() + def capturedVars(sym: Symbol)(using Context) = + myCapturedVars.getOrElseUpdate(sym, + if sym.ownersIterator.exists(_.isTerm) then CaptureSet.Var() + else CaptureSet.empty) + + def markFree(sym: Symbol, pos: SrcPos)(using Context): Unit = + if sym.exists then + val ref = sym.termRef + def recur(env: Env): Unit = + if env.isOpen && env.owner != sym.enclosure then + capt.println(i"Mark $sym with cs ${ref.captureSet} free in ${env.owner}") + checkElem(ref, env.captured, pos) + if env.owner.isConstructor then + if env.outer.owner != sym.enclosure then recur(env.outer.outer) + else recur(env.outer) + if ref.isTracked then recur(curEnv) + + def includeCallCaptures(sym: Symbol, pos: SrcPos)(using Context): Unit = + if curEnv.isOpen then + val ownEnclosure = ctx.owner.enclosingMethodOrClass + var targetSet = capturedVars(sym) + if !targetSet.isAlwaysEmpty && sym.enclosure == ownEnclosure then + targetSet = targetSet.filter { + case ref: TermRef => ref.symbol.enclosure != ownEnclosure + case _ => true + } + checkSubset(targetSet, curEnv.captured, pos) + + def includeBoxedCaptures(tp: Type, pos: SrcPos)(using Context): Unit = + if curEnv.isOpen then + val ownEnclosure = ctx.owner.enclosingMethodOrClass + val targetSet = tp.boxedCaptured.filter { + case ref: TermRef => ref.symbol.enclosure != ownEnclosure + case _ => true + } + checkSubset(targetSet, curEnv.captured, pos) + + def assertSub(cs1: CaptureSet, cs2: CaptureSet)(using Context) = + assert(cs1.subCaptures(cs2, frozen = false).isOK, i"$cs1 is not a subset of $cs2") + + def checkElem(elem: CaptureRef, cs: CaptureSet, pos: SrcPos)(using Context) = + val res = elem.singletonCaptureSet.subCaptures(cs, frozen = false) + if !res.isOK then + report.error(i"$elem cannot be referenced here; it is not included in allowed capture set ${res.blocking}", pos) + + def checkSubset(cs1: CaptureSet, cs2: CaptureSet, pos: SrcPos)(using Context) = + val res = cs1.subCaptures(cs2, frozen = false) + if !res.isOK then + report.error(i"references $cs1 are not all included in allowed capture set ${res.blocking}", pos) + + override def recheckClosure(tree: Closure, pt: Type)(using Context): Type = + val cs = capturedVars(tree.meth.symbol) + recheckr.println(i"typing closure $tree with cvs $cs") + super.recheckClosure(tree, pt).capturing(cs) + .showing(i"rechecked $tree, $result", capt) + + override def recheckIdent(tree: Ident)(using Context): Type = + markFree(tree.symbol, tree.srcPos) + if tree.symbol.is(Method) then includeCallCaptures(tree.symbol, tree.srcPos) + super.recheckIdent(tree) + + override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = + try super.recheckValDef(tree, sym) + finally + if !sym.is(Param) then + // parameters with inferred types belong to anonymous methods. We need to wait + // for more info from the context, so we cannot interpolate. Note that we cannot + // expect to have all necessary info available at the point where the anonymous + // function is compiled since we do not propagate expected types into blocks. + interpolateVarsIn(tree.tpt) + + override def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Unit = + val saved = curEnv + val localSet = capturedVars(sym) + if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, false, curEnv) + try super.recheckDefDef(tree, sym) + finally + interpolateVarsIn(tree.tpt) + curEnv = saved + + override def recheckClassDef(tree: TypeDef, impl: Template, cls: ClassSymbol)(using Context): Type = + for param <- cls.paramGetters do + if param.is(Private) && !param.info.captureSet.isAlwaysEmpty then + report.error( + "Implementation restriction: Class parameter with non-empty capture set must be a `val`", + param.srcPos) + val saved = curEnv + val localSet = capturedVars(cls) + if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, false, curEnv) + try super.recheckClassDef(tree, impl, cls) + finally curEnv = saved + + /** First half: Refine the type of a constructor call `new C(t_1, ..., t_n)` + * to C{val x_1: T_1, ..., x_m: T_m} where x_1, ..., x_m are the tracked + * parameters of C and T_1, ..., T_m are the types of the corresponding arguments. + * + * Second half: union of all capture sets of arguments to tracked parameters. + */ + private def addParamArgRefinements(core: Type, argTypes: List[Type], cls: ClassSymbol)(using Context): (Type, CaptureSet) = + cls.paramGetters.lazyZip(argTypes).foldLeft((core, CaptureSet.empty: CaptureSet)) { (acc, refine) => + val (core, allCaptures) = acc + val (getter, argType) = refine + if getter.termRef.isTracked then + (RefinedType(core, getter.name, argType), allCaptures ++ argType.captureSet) + else + (core, allCaptures) + } + + /** Handle an application of method `sym` with type `mt` to arguments of types `argTypes`. + * This means: + * - Instantiate result type with actual arguments + * - If call is to a constructor: + * - remember types of arguments corresponding to tracked + * parameters in refinements. + * - add capture set of instantiated class to capture set of result type. + */ + override def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type = + val ownType = + if mt.isResultDependent then SubstParamsMap(mt, argTypes)(mt.resType) + else mt.resType + if sym.isConstructor then + val cls = sym.owner.asClass + val (refined, cs) = addParamArgRefinements(ownType, argTypes, cls) + refined.capturing(cs ++ capturedVars(cls) ++ capturedVars(sym)) + .showing(i"constr type $mt with $argTypes%, % in $cls = $result", capt) + else ownType + + def recheckByNameArg(tree: Tree, pt: Type)(using Context): Type = + val closureDef(mdef) = tree + val arg = mdef.rhs + val localSet = CaptureSet.Var() + curEnv = Env(mdef.symbol, localSet, isBoxed = false, curEnv) + val result = + try + inContext(ctx.withOwner(mdef.symbol)) { + recheckStart(arg, pt).capturing(localSet) + } + finally curEnv = curEnv.outer + recheckFinish(result, arg, pt) + + override def recheckApply(tree: Apply, pt: Type)(using Context): Type = + if tree.symbol == defn.cbnArg then + recheckByNameArg(tree.args(0), pt) + else + includeCallCaptures(tree.symbol, tree.srcPos) + super.recheckApply(tree, pt) + + override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type = + val res = super.recheck(tree, pt) + if tree.isTerm then + includeBoxedCaptures(res, tree.srcPos) + res + + override def checkUnit(unit: CompilationUnit)(using Context): Unit = + withCaptureSetsExplained { + super.checkUnit(unit) + PostRefinerCheck.traverse(unit.tpdTree) + if ctx.settings.YccDebug.value then + show(unit.tpdTree) // this dows not print tree, but makes its variables visible for dependency printing + } + + def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit = + if disallowGlobal then + tree match + case LambdaTypeTree(_, restpt) => + checkNotGlobal(restpt, allArgs*) + case _ => + for ref <- knownType(tree).captureSet.elems do + val isGlobal = ref match + case ref: TermRef => + ref.isRootCapability || ref.prefix != NoPrefix && ref.symbol.hasAnnotation(defn.AbilityAnnot) + case _ => false + val what = if ref.isRootCapability then "universal" else "global" + if isGlobal then + val notAllowed = i" is not allowed to capture the $what capability $ref" + def msg = tree match + case tree: InferredTypeTree => + i"""inferred type argument ${knownType(tree)}$notAllowed + | + |The inferred arguments are: [${allArgs.map(knownType)}%, %]""" + case _ => s"type argument$notAllowed" + report.error(msg, tree.srcPos) + + object PostRefinerCheck extends TreeTraverser: + def traverse(tree: Tree)(using Context) = + tree match + case _: InferredTypeTree => + case tree: TypeTree if !tree.span.isZeroExtent => + knownType(tree).foreachPart( + checkWellformedPost(_, tree.srcPos)) + knownType(tree).foreachPart { + case AnnotatedType(_, annot) => + checkWellformedPost(annot.tree) + case _ => + } + case tree1 @ TypeApply(fn, args) if disallowGlobal => + for arg <- args do + //println(i"checking $arg in $tree: ${knownType(tree).captureSet}") + checkNotGlobal(arg, args*) + case t: ValOrDefDef if t.tpt.isInstanceOf[InferredTypeTree] => + val sym = t.symbol + val isLocal = + sym.ownersIterator.exists(_.isTerm) + || sym.accessBoundary(defn.RootClass).isContainedIn(sym.topLevelClass) + + // The following classes of definitions need explicit capture types ... + if !isLocal // ... since external capture types are not inferred + || sym.owner.is(Trait) // ... since we do OverridingPairs checking before capture inference + || sym.allOverriddenSymbols.nonEmpty // ... since we do override checking before capture inference + then + val inferred = knownType(t.tpt) + def checkPure(tp: Type) = tp match + case CapturingType(_, refs, _) if !refs.elems.isEmpty => + val resultStr = if t.isInstanceOf[DefDef] then " result" else "" + report.error( + em"""Non-local $sym cannot have an inferred$resultStr type + |$inferred + |with non-empty capture set $refs. + |The type needs to be declared explicitly.""", t.srcPos) + case _ => + inferred.foreachPart(checkPure, StopAt.Static) + case _ => + traverseChildren(tree) + + def postRefinerCheck(tree: tpd.Tree)(using Context): Unit = + PostRefinerCheck.traverse(tree) + + end CaptureChecker +end CheckCaptures diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index 3b743906fd51..116c8ff9bbfc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -74,9 +74,8 @@ object Checking { } for (arg, which, bound) <- TypeOps.boundsViolations(args, boundss, instantiate, app) do report.error( - showInferred(DoesNotConformToBound(arg.tpe, which, bound), - app, tpt), - arg.srcPos.focus) + showInferred(DoesNotConformToBound(arg.tpe, which, bound), app, tpt), + arg.srcPos.focus) /** Check that type arguments `args` conform to corresponding bounds in `tl` * Note: This does not check the bounds of AppliedTypeTrees. These @@ -310,6 +309,7 @@ object Checking { case AndType(tp1, tp2) => isInteresting(tp1) || isInteresting(tp2) case OrType(tp1, tp2) => isInteresting(tp1) && isInteresting(tp2) case _: RefinedOrRecType | _: AppliedType => true + case tp: AnnotatedType => isInteresting(tp.parent) case _ => false } diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 7654b98995ff..de44dd0efb18 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -14,6 +14,7 @@ import Decorators._ import config.Printers.{gadts, typr, debug} import annotation.tailrec import reporting._ +import cc.{CapturingType, derivedCapturingType} import collection.mutable import scala.annotation.internal.sharable @@ -126,8 +127,8 @@ object Inferencing { couldInstantiateTypeVar(parent) case tp: AndOrType => couldInstantiateTypeVar(tp.tp1) || couldInstantiateTypeVar(tp.tp2) - case AnnotatedType(tp, _) => - couldInstantiateTypeVar(tp) + case tp: AnnotatedType => + couldInstantiateTypeVar(tp.parent) case _ => false @@ -527,6 +528,7 @@ object Inferencing { case tp: RefinedType => tp.derivedRefinedType(captureWildcards(tp.parent), tp.refinedName, tp.refinedInfo) case tp: RecType => tp.derivedRecType(captureWildcards(tp.parent)) case tp: LazyRef => captureWildcards(tp.ref) + case CapturingType(parent, refs, _) => tp.derivedCapturingType(captureWildcards(parent), refs) case tp: AnnotatedType => tp.derivedAnnotatedType(captureWildcards(tp.parent), tp.annot) case _ => tp } @@ -696,6 +698,7 @@ trait Inferencing { this: Typer => if !argType.isSingleton then argType = SkolemType(argType) argType <:< tvar case _ => + () // scala-meta complains if this is missing, but I could not mimimize further end constrainIfDependentParamRef } @@ -710,4 +713,3 @@ trait Inferencing { this: Typer => enum IfBottom: case ok, fail, flip - diff --git a/compiler/src/dotty/tools/dotc/typer/RefChecks.scala b/compiler/src/dotty/tools/dotc/typer/RefChecks.scala index 4a77573a8386..aa9b84428426 100644 --- a/compiler/src/dotty/tools/dotc/typer/RefChecks.scala +++ b/compiler/src/dotty/tools/dotc/typer/RefChecks.scala @@ -224,7 +224,7 @@ object RefChecks { * TODO This still needs to be cleaned up; the current version is a straight port of what was there * before, but it looks too complicated and method bodies are far too large. */ - private def checkAllOverrides(clazz: ClassSymbol)(using Context): Unit = { + def checkAllOverrides(clazz: ClassSymbol)(using Context): Unit = { val self = clazz.thisType val upwardsSelf = upwardsThisType(clazz) var hasErrors = false diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 2a5a9ca284ac..cc72918f8040 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -15,6 +15,7 @@ import ProtoTypes._ import collection.mutable import reporting._ import Checking.{checkNoPrivateLeaks, checkNoWildcard} +import cc.CaptureSet trait TypeAssigner { import tpd.* @@ -191,6 +192,14 @@ trait TypeAssigner { if tpe.isError then tpe else errorType(ex"$whatCanNot be accessed as a member of $pre$where.$whyNot", pos) + def processAppliedType(tree: untpd.Tree, tp: Type)(using Context): Type = tp match + case AppliedType(tycon, args) => + val constr = tycon.typeSymbol + if constr == defn.andType then AndType(args(0), args(1)) + else if constr == defn.orType then OrType(args(0), args(1), soft = false) + else tp + case _ => tp + /** Type assignment method. Each method takes as parameters * - an untpd.Tree to which it assigns a type, * - typed child trees it needs to access to cpmpute that type, @@ -288,8 +297,12 @@ trait TypeAssigner { val ownType = fn.tpe.widen match { case fntpe: MethodType => if (sameLength(fntpe.paramInfos, args) || ctx.phase.prev.relaxedTyping) - if (fntpe.isResultDependent) safeSubstParams(fntpe.resultType, fntpe.paramRefs, args.tpes) - else fntpe.resultType + if fntpe.isCaptureDependent then + fntpe.resultType.substParams(fntpe, args.tpes) + else if fntpe.isResultDependent then + safeSubstParams(fntpe.resultType, fntpe.paramRefs, args.tpes) + else + fntpe.resultType else errorType(i"wrong number of arguments at ${ctx.phase.prev} for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos) case t => @@ -461,11 +474,10 @@ trait TypeAssigner { assert(!hasNamedArg(args) || ctx.reporter.errorsReported, tree) val tparams = tycon.tpe.typeParams val ownType = - if (sameLength(tparams, args)) - if (tycon.symbol == defn.andType) AndType(args(0).tpe, args(1).tpe) - else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe, soft = false) - else tycon.tpe.appliedTo(args.tpes) - else wrongNumberOfTypeArgs(tycon.tpe, tparams, args, tree.srcPos) + if !sameLength(tparams, args) then + wrongNumberOfTypeArgs(tycon.tpe, tparams, args, tree.srcPos) + else + processAppliedType(tree, tycon.tpe.appliedTo(args.tpes)) tree.withType(ownType) } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 8883b492e2d9..39aabe6ce6c4 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -49,6 +49,7 @@ import transform.TypeUtils._ import reporting._ import Nullables._ import NullOpsDecorator._ +import cc.CheckCaptures import config.Config object Typer { @@ -1134,8 +1135,8 @@ class Typer extends Namer */ private def decomposeProtoFunction(pt: Type, defaultArity: Int, pos: SrcPos)(using Context): (List[Type], untpd.Tree) = { def typeTree(tp: Type) = tp match { - case _: WildcardType => untpd.TypeTree() - case _ => untpd.TypeTree(tp) + case _: WildcardType => new untpd.InferredTypeTree() + case _ => untpd.InferredTypeTree(tp) } def interpolateWildcards = new TypeMap { def apply(t: Type): Type = t match @@ -1144,7 +1145,7 @@ class Typer extends Namer case _ => mapOver(t) } - val pt1 = pt.stripTypeVar.dealias + val pt1 = pt.strippedDealias if (pt1 ne pt1.dropDependentRefinement) && defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType) then @@ -2560,6 +2561,8 @@ class Typer extends Namer registerNowarn(annot1, tree) val arg1 = typed(tree.arg, pt) if (ctx.mode is Mode.Type) { + if annot1.symbol.maybeOwner == defn.RetainsAnnot then + CheckCaptures.checkWellformed(annot1) if arg1.isType then assignType(cpy.Annotated(tree)(arg1, annot1), arg1, annot1) else diff --git a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala index ffca320d53d3..1fac0dac0913 100644 --- a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala +++ b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala @@ -12,8 +12,10 @@ abstract class SimpleIdentitySet[+Elem <: AnyRef] { def contains[E >: Elem <: AnyRef](x: E): Boolean def foreach(f: Elem => Unit): Unit def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A def toList: List[Elem] + def iterator: Iterator[Elem] final def isEmpty: Boolean = size == 0 @@ -55,8 +57,10 @@ object SimpleIdentitySet { def contains[E <: AnyRef](x: E): Boolean = false def foreach(f: Nothing => Unit): Unit = () def exists[E <: AnyRef](p: E => Boolean): Boolean = false + def map[B <: AnyRef](f: Nothing => B): SimpleIdentitySet[B] = empty def /: [A, E <: AnyRef](z: A)(f: (A, E) => A): A = z def toList = Nil + def iterator = Iterator.empty } private class Set1[+Elem <: AnyRef](x0: AnyRef) extends SimpleIdentitySet[Elem] { @@ -69,9 +73,12 @@ object SimpleIdentitySet { def foreach(f: Elem => Unit): Unit = f(x0.asInstanceOf[Elem]) def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + Set1(f(x0.asInstanceOf[Elem])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(z, x0.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: Nil + def iterator = Iterator.single(x0.asInstanceOf[Elem]) } private class Set2[+Elem <: AnyRef](x0: AnyRef, x1: AnyRef) extends SimpleIdentitySet[Elem] { @@ -86,9 +93,15 @@ object SimpleIdentitySet { def foreach(f: Elem => Unit): Unit = { f(x0.asInstanceOf[Elem]); f(x1.asInstanceOf[Elem]) } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) || p(x1.asInstanceOf[E]) + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + Set2(f(x0.asInstanceOf[Elem]), f(x1.asInstanceOf[Elem])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: Nil + def iterator = Iterator.tabulate(2) { + case 0 => x0.asInstanceOf[Elem] + case 1 => x1.asInstanceOf[Elem] + } } private class Set3[+Elem <: AnyRef](x0: AnyRef, x1: AnyRef, x2: AnyRef) extends SimpleIdentitySet[Elem] { @@ -114,9 +127,16 @@ object SimpleIdentitySet { } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) || p(x1.asInstanceOf[E]) || p(x2.asInstanceOf[E]) + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + Set3(f(x0.asInstanceOf[Elem]), f(x1.asInstanceOf[Elem]), f(x2.asInstanceOf[Elem])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]), x2.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: x2.asInstanceOf[Elem] :: Nil + def iterator = Iterator.tabulate(3) { + case 0 => x0.asInstanceOf[Elem] + case 1 => x1.asInstanceOf[Elem] + case 2 => x2.asInstanceOf[Elem] + } } private class SetN[+Elem <: AnyRef](val xs: Array[AnyRef]) extends SimpleIdentitySet[Elem] { @@ -156,6 +176,8 @@ object SimpleIdentitySet { } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = xs.asInstanceOf[Array[E]].exists(p) + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + SetN(xs.map(x => f(x.asInstanceOf[Elem]).asInstanceOf[AnyRef])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = xs.asInstanceOf[Array[E]].foldLeft(z)(f) def toList: List[Elem] = { @@ -163,6 +185,7 @@ object SimpleIdentitySet { foreach(buf += _) buf.toList } + def iterator = xs.iterator.asInstanceOf[Iterator[Elem]] override def ++ [E >: Elem <: AnyRef](that: SimpleIdentitySet[E]): SimpleIdentitySet[E] = that match { case that: SetN[?] => diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index 7ddbe5832e5f..13f99620a449 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -39,6 +39,7 @@ class CompilationTests { compileFilesInDir("tests/pos-special/isInstanceOf", allowDeepSubtypes.and("-Xfatal-warnings")), compileFilesInDir("tests/new", defaultOptions), compileFilesInDir("tests/pos-scala2", scala2CompatMode), + compileFilesInDir("tests/pos-custom-args/captures", defaultOptions.and("-Ycc")), compileFilesInDir("tests/pos-custom-args/erased", defaultOptions.and("-language:experimental.erasedDefinitions")), compileFilesInDir("tests/pos", defaultOptions.and("-Ysafe-init")), compileFilesInDir("tests/pos-deep-subtype", allowDeepSubtypes), @@ -136,6 +137,7 @@ class CompilationTests { compileFilesInDir("tests/neg-custom-args/allow-deep-subtypes", allowDeepSubtypes), compileFilesInDir("tests/neg-custom-args/explicit-nulls", defaultOptions.and("-Yexplicit-nulls")), compileFilesInDir("tests/neg-custom-args/no-experimental", defaultOptions.and("-Yno-experimental")), + compileFilesInDir("tests/neg-custom-args/captures", defaultOptions.and("-Ycc")), compileDir("tests/neg-custom-args/impl-conv", defaultOptions.and("-Xfatal-warnings", "-feature")), compileFile("tests/neg-custom-args/implicit-conversions.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), compileFile("tests/neg-custom-args/implicit-conversions-old.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), @@ -180,6 +182,7 @@ class CompilationTests { compileFile("tests/neg-custom-args/i7314.scala", defaultOptions.and("-Xfatal-warnings", "-source", "future")), compileFile("tests/neg-custom-args/feature-shadowing.scala", defaultOptions.and("-Xfatal-warnings", "-feature")), compileDir("tests/neg-custom-args/hidden-type-errors", defaultOptions.and("-explain")), + compileFile("tests/neg-custom-args/capt-wf.scala", defaultOptions.and("-Ycc", "-Xfatal-warnings")), ).checkExpectedErrors() } diff --git a/library/src-bootstrapped/scala/Retains.scala b/library/src-bootstrapped/scala/Retains.scala new file mode 100644 index 000000000000..f3bfa282a012 --- /dev/null +++ b/library/src-bootstrapped/scala/Retains.scala @@ -0,0 +1,6 @@ +package scala + +/** An annotation that indicates capture + */ +class retains(xs: Any*) extends annotation.StaticAnnotation + diff --git a/library/src-bootstrapped/scala/annotation/ability.scala b/library/src-bootstrapped/scala/annotation/ability.scala new file mode 100644 index 000000000000..8b327a2f8b02 --- /dev/null +++ b/library/src-bootstrapped/scala/annotation/ability.scala @@ -0,0 +1,9 @@ +package scala.annotation + +/** An annotation inidcating that a val should be tracked as its own ability. + * Example: + * + * @ability erased val canThrow: * = ??? + * ^^^ rename to capability + */ +class ability extends StaticAnnotation \ No newline at end of file diff --git a/library/src/scala/runtime/stdLibPatches/Predef.scala b/library/src/scala/runtime/stdLibPatches/Predef.scala index 13dfc77ac60b..387096ab55c5 100644 --- a/library/src/scala/runtime/stdLibPatches/Predef.scala +++ b/library/src/scala/runtime/stdLibPatches/Predef.scala @@ -47,4 +47,5 @@ object Predef: */ extension [T](x: T | Null) inline def nn: x.type & T = scala.runtime.Scala3RunTime.nn(x) + end Predef diff --git a/tests/disabled/neg-custom-args/captures/capt-wf.scala b/tests/disabled/neg-custom-args/captures/capt-wf.scala new file mode 100644 index 000000000000..54fe545f443b --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/capt-wf.scala @@ -0,0 +1,19 @@ +// No longer valid +class C +type Cap = C @retains(*) +type Top = Any @retains(*) + +type T = (x: Cap) => List[String @retains(x)] => Unit // error +val x: (x: Cap) => Array[String @retains(x)] = ??? // error +val y = x + +def test: Unit = + def f(x: Cap) = // ok + val g = (xs: List[String @retains(x)]) => () + g + def f2(x: Cap)(xs: List[String @retains(x)]) = () + val x = f // error + val x2 = f2 // error + val y = f(C()) // ok + val y2 = f2(C()) // ok + () diff --git a/tests/disabled/neg-custom-args/captures/try2.check b/tests/disabled/neg-custom-args/captures/try2.check new file mode 100644 index 000000000000..c7b20d0f7c5e --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/try2.check @@ -0,0 +1,38 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:31:32 ----------------------------------------- +31 | (x: CanThrow[Exception]) => () => raise(new Exception)(using x) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => Nothing + | Required: () => Nothing + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:45:2 ------------------------------------------ +45 | yy // error + | ^^ + | Found: (yy : List[(xx : (() => Int) retains canThrow)]) + | Required: List[() => Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try2.scala:52:2 ------------------------------------------ +47 |val global = handle { +48 | (x: CanThrow[Exception]) => +49 | () => +50 | raise(new Exception)(using x) +51 | 22 +52 |} { // error + | ^ + | Found: (() => Int) retains canThrow + | Required: () => Int +53 | (ex: Exception) => () => 22 +54 |} + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/try2.scala:24:28 -------------------------------------------------------------- +24 | val a = handle[Exception, CanThrow[Exception]] { // error + | ^^^^^^^^^^^^^^^^^^^ + | type argument is not allowed to capture the global capability (canThrow : *) +-- Error: tests/neg-custom-args/captures/try2.scala:36:11 -------------------------------------------------------------- +36 | val xx = handle { // error + | ^^^^^^ + |inferred type argument ((() => Int) retains canThrow) is not allowed to capture the global capability (canThrow : *) + | + |The inferred arguments are: [Exception, ((() => Int) retains canThrow)] diff --git a/tests/disabled/neg-custom-args/captures/try2.scala b/tests/disabled/neg-custom-args/captures/try2.scala new file mode 100644 index 000000000000..dd3cc890a197 --- /dev/null +++ b/tests/disabled/neg-custom-args/captures/try2.scala @@ -0,0 +1,55 @@ +// Retains syntax for classes not (yet?) supported +import language.experimental.erasedDefinitions +import annotation.ability + +@ability erased val canThrow: * = ??? + +class CanThrow[E <: Exception] extends Retains[canThrow.type] +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: CanThrow[E] => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test: List[() => Int] = + val a = handle[Exception, CanThrow[Exception]] { // error + (x: CanThrow[Exception]) => x + }{ + (ex: Exception) => ??? + } + + val b = handle[Exception, () => Nothing] { + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) // error + } { + (ex: Exception) => ??? + } + + val xx = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 + } { + (ex: Exception) => () => 22 + } + val yy = xx :: Nil + yy // error + +val global = handle { + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 +} { // error + (ex: Exception) => () => 22 +} diff --git a/tests/disabled/pos/lazylist.scala b/tests/disabled/pos/lazylist.scala new file mode 100644 index 000000000000..be628113d2d8 --- /dev/null +++ b/tests/disabled/pos/lazylist.scala @@ -0,0 +1,51 @@ +package lazylists + +abstract class LazyList[+T]: + this: ({*} LazyList[T]) => + + def isEmpty: Boolean + def head: T + def tail: LazyList[T] + + def map[U](f: {*} T => U): {f, this} LazyList[U] = + if isEmpty then LazyNil + else LazyCons(f(head), () => tail.map(f)) + + def concat[U >: T](that: {*} LazyList[U]): {this, that} LazyList[U] + +// def flatMap[U](f: {*} T => LazyList[U]): {f, this} LazyList[U] + +class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: + def isEmpty = false + def head = x + def tail: {*} LazyList[T] = xs() + def concat[U >: T](that: {*} LazyList[U]): {this, that} LazyList[U] = + LazyCons(x, () => xs().concat(that)) +// def flatMap[U](f: {*} T => LazyList[U]): {f, this} LazyList[U] = +// f(x).concat(xs().flatMap(f)) + +object LazyNil extends LazyList[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + def concat[U](that: {*} LazyList[U]): {that} LazyList[U] = that +// def flatMap[U](f: {*} Nothing => LazyList[U]): LazyList[U] = LazyNil + +def map[A, B](xs: {*} LazyList[A], f: {*} A => B): {f, xs} LazyList[B] = + xs.map(f) + +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap, cap3: Cap) = + def f[T](x: LazyList[T]): LazyList[T] = if cap1 == cap1 then x else LazyNil + def g(x: Int) = if cap2 == cap2 then x else 0 + def h(x: Int) = if cap3 == cap3 then x else 0 + val ref1 = LazyCons(1, () => f(LazyNil)) + val ref1c: {cap1} LazyList[Int] = ref1 + val ref2 = map(ref1, g) + val ref2c: {cap2, ref1} LazyList[Int] = ref2 + val ref3 = ref1.map(g) + val ref3c: {cap2, ref1} LazyList[Int] = ref3 + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(h) + val ref4c: {cap1, cap2, cap3} LazyList[Int] = ref4 \ No newline at end of file diff --git a/tests/neg/i9325.scala b/tests/neg-custom-args/allow-deep-subtypes/i9325.scala similarity index 100% rename from tests/neg/i9325.scala rename to tests/neg-custom-args/allow-deep-subtypes/i9325.scala diff --git a/tests/neg-custom-args/capt-wf.scala b/tests/neg-custom-args/capt-wf.scala new file mode 100644 index 000000000000..dc4d6a0d4bff --- /dev/null +++ b/tests/neg-custom-args/capt-wf.scala @@ -0,0 +1,35 @@ +class C +type Cap = {*} C + +object foo + +def test(c: Cap, other: String): Unit = + val x1: {*} C = ??? // OK + val x2: {other} C = ??? // error: cs is empty + val s1 = () => "abc" + val x3: {s1} C = ??? // error: cs is empty + val x3a: () => String = s1 + val s2 = () => if x1 == null then "" else "abc" + val x4: {s2} C = ??? // OK + val x5: {c, c} C = ??? // error: redundant + val x6: {c} {c} C = ??? // error: redundant + val x7: {c} Cap = ??? // error: redundant + val x8: {*} {c} C = ??? // OK + val x9: {c, *} C = ??? // error: redundant + val x10: {*, c} C = ??? // error: redundant + + def even(n: Int): Boolean = if n == 0 then true else odd(n - 1) + def odd(n: Int): Boolean = if n == 1 then true else even(n - 1) + val e1 = even + val o1 = odd + + val y1: {e1} String = ??? // error cs is empty + val y2: {o1} String = ??? // error cs is empty + + lazy val ev: (Int => Boolean) = (n: Int) => + lazy val od: (Int => Boolean) = (n: Int) => + if n == 1 then true else ev(n - 1) + if n == 0 then true else od(n - 1) + val y3: {ev} String = ??? // error cs is empty + + () \ No newline at end of file diff --git a/tests/neg-custom-args/captures/bounded.scala b/tests/neg-custom-args/captures/bounded.scala new file mode 100644 index 000000000000..dc2621e95a65 --- /dev/null +++ b/tests/neg-custom-args/captures/bounded.scala @@ -0,0 +1,14 @@ +class CC +type Cap = {*} CC + +def test(c: Cap) = + class B[X <: {c} Object](x: X): + def elem = x + def lateElem = () => x + + def f(x: Int): Int = if c == c then x else 0 + val b = new B(f) + val r1 = b.elem + val r1c: {c} Int => Int = r1 + val r2 = b.lateElem + val r2c: () => {c} Int => Int = r2 // error \ No newline at end of file diff --git a/tests/neg-custom-args/captures/boxmap.check b/tests/neg-custom-args/captures/boxmap.check new file mode 100644 index 000000000000..406077077af5 --- /dev/null +++ b/tests/neg-custom-args/captures/boxmap.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/boxmap.scala:14:2 ---------------------------------------- +14 | () => b[Box[B]]((x: A) => box(f(x))) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {f} () => ? Box[B] + | Required: () => Box[B] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/boxmap.scala b/tests/neg-custom-args/captures/boxmap.scala new file mode 100644 index 000000000000..e335320ef9d4 --- /dev/null +++ b/tests/neg-custom-args/captures/boxmap.scala @@ -0,0 +1,14 @@ +type Top = Any @retains(*) + +infix type ==> [A, B] = (A => B) @retains(*) + +type Box[+T <: Top] = ([K <: Top] => (T ==> K) => K) + +def box[T <: Top](x: T): Box[T] = + [K <: Top] => (k: T ==> K) => k(x) + +def map[A <: Top, B <: Top](b: Box[A])(f: A ==> B): Box[B] = + b[Box[B]]((x: A) => box(f(x))) + +def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B): () => Box[B] = + () => b[Box[B]]((x: A) => box(f(x))) // error diff --git a/tests/neg-custom-args/captures/byname.scala b/tests/neg-custom-args/captures/byname.scala new file mode 100644 index 000000000000..526cdc50952f --- /dev/null +++ b/tests/neg-custom-args/captures/byname.scala @@ -0,0 +1,10 @@ +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap) = + def f() = if cap1 == cap1 then g else g + def g(x: Int) = if cap2 == cap2 then 1 else x + def h(ff: => {cap2} Int => Int) = ff + h(f()) // error + + diff --git a/tests/neg-custom-args/captures/capt-box-env.scala b/tests/neg-custom-args/captures/capt-box-env.scala new file mode 100644 index 000000000000..e9743054076e --- /dev/null +++ b/tests/neg-custom-args/captures/capt-box-env.scala @@ -0,0 +1,12 @@ +class C +type Cap = {*} C + +class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + +def test(c: Cap) = + def f(x: Cap): Unit = if c == x then () + val p = Pair(f, f) + val g = () => p.fst == p.snd + val gc: () => Boolean = g // error diff --git a/tests/neg-custom-args/captures/capt-box.scala b/tests/neg-custom-args/captures/capt-box.scala new file mode 100644 index 000000000000..317fc064ec0b --- /dev/null +++ b/tests/neg-custom-args/captures/capt-box.scala @@ -0,0 +1,13 @@ +//import scala.retains +class C +type Cap = {*} C + +def test(x: Cap) = + + def foo(y: Cap) = if x == y then println() + + val x1 = foo + + val x2 = identity(x1) + + val x3: Cap => Unit = x2 // error \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt-depfun.scala b/tests/neg-custom-args/captures/capt-depfun.scala new file mode 100644 index 000000000000..6b0beb92b313 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-depfun.scala @@ -0,0 +1,7 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => String @retains(x) => String @retains(x)) = ??? + val dc: (({y, z} String) => {y, z} String) = ac(g()) // error diff --git a/tests/neg-custom-args/captures/capt-depfun2.scala b/tests/neg-custom-args/captures/capt-depfun2.scala new file mode 100644 index 000000000000..874d753b048d --- /dev/null +++ b/tests/neg-custom-args/captures/capt-depfun2.scala @@ -0,0 +1,10 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => Array[String @retains(x)]) = ??? + val dc = ac(g()) // error: Needs explicit type Array[? >: String <: {y, z} String] + // This is a shortcoming of rechecking since the originally inferred + // type is `Array[String]` and the actual type after rechecking + // cannot be expressed as `Array[C String]` for any capture set C \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt-env.scala b/tests/neg-custom-args/captures/capt-env.scala new file mode 100644 index 000000000000..84b4b57a7930 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-env.scala @@ -0,0 +1,13 @@ +class C +type Cap = {*} C + +class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + +def test(c: Cap) = + def f(x: Cap): Unit = if c == x then () + val p = Pair(f, f) + val g = () => p.fst == p.snd + val gc: () => Boolean = g // error + diff --git a/tests/neg-custom-args/captures/capt-test.scala b/tests/neg-custom-args/captures/capt-test.scala new file mode 100644 index 000000000000..0c536a280f5c --- /dev/null +++ b/tests/neg-custom-args/captures/capt-test.scala @@ -0,0 +1,26 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: (CanThrow[E]) => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test: Unit = + val b = handle[Exception, () => Nothing] { // error + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) + } { + (ex: Exception) => ??? + } diff --git a/tests/neg-custom-args/captures/capt-wf-typer.scala b/tests/neg-custom-args/captures/capt-wf-typer.scala new file mode 100644 index 000000000000..5120e2b288d5 --- /dev/null +++ b/tests/neg-custom-args/captures/capt-wf-typer.scala @@ -0,0 +1,10 @@ +class C +type Cap = {*} C + +object foo + +def test(c: Cap, other: String): Unit = + val x7: {c} String = ??? // OK + val x8: String @retains(x7 + x7) = ??? // error + val x9: String @retains(foo) = ??? // error + () \ No newline at end of file diff --git a/tests/neg-custom-args/captures/capt1.check b/tests/neg-custom-args/captures/capt1.check new file mode 100644 index 000000000000..ce7c4833bf9c --- /dev/null +++ b/tests/neg-custom-args/captures/capt1.check @@ -0,0 +1,46 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:3:2 ------------------------------------------ +3 | () => if x == null then y else y // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => ? C + | Required: () => C + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:6:2 ------------------------------------------ +6 | () => if x == null then y else y // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: {x} () => ? C + | Required: Matchable + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:13:2 ----------------------------------------- +13 | def f(y: Int) = if x == null then y else y // error + | ^ + | Found: {x} Int => Int + | Required: Matchable +14 | f + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:20:2 ----------------------------------------- +20 | class F(y: Int) extends A: // error + | ^ + | Found: {x} A + | Required: A +21 | def m() = if x == null then y else y +22 | F(22) + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:25:2 ----------------------------------------- +25 | new A: // error + | ^ + | Found: {x} A + | Required: A +26 | def m() = if x == null then y else y + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:31:24 ---------------------------------------- +31 | val z2 = h[() => Cap](() => x)(() => C()) // error + | ^^^^^^^ + | Found: {x} () => ? Cap + | Required: () => Cap + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/capt1.scala b/tests/neg-custom-args/captures/capt1.scala new file mode 100644 index 000000000000..4da49c5f4f1e --- /dev/null +++ b/tests/neg-custom-args/captures/capt1.scala @@ -0,0 +1,34 @@ +class C +def f(x: C @retains(*), y: C): () => C = + () => if x == null then y else y // error + +def g(x: C @retains(*), y: C): Matchable = + () => if x == null then y else y // error + +def h1(x: C @retains(*), y: C): Any = + def f() = if x == null then y else y + () => f() // ok + +def h2(x: C @retains(*)): Matchable = + def f(y: Int) = if x == null then y else y // error + f + +class A +type Cap = C @retains(*) + +def h3(x: Cap): A = + class F(y: Int) extends A: // error + def m() = if x == null then y else y + F(22) + +def h4(x: Cap, y: Int): A = + new A: // error + def m() = if x == null then y else y + +def foo() = + val x: C @retains(*) = ??? + def h[X](a: X)(b: X) = a + val z2 = h[() => Cap](() => x)(() => C()) // error + val z3 = h[(() => Cap) @retains(x)](() => x)(() => C()) // ok + val z4 = h[(() => Cap) @retains(x)](() => x)(() => C()) // what was inferred for z3 + diff --git a/tests/neg-custom-args/captures/capt2.scala b/tests/neg-custom-args/captures/capt2.scala new file mode 100644 index 000000000000..1eee53463f6d --- /dev/null +++ b/tests/neg-custom-args/captures/capt2.scala @@ -0,0 +1,9 @@ +//import scala.retains +class C +type Cap = {*} C + +def f1(c: Cap): (() => {c} C) = () => c // error, but would be OK under capture abbreciations for funciton types +def f2(c: Cap): ({c} () => C) = () => c // error + +def h5(x: Cap): () => C = + f1(x) // error diff --git a/tests/neg-custom-args/captures/capt3.scala b/tests/neg-custom-args/captures/capt3.scala new file mode 100644 index 000000000000..80b937276f73 --- /dev/null +++ b/tests/neg-custom-args/captures/capt3.scala @@ -0,0 +1,26 @@ +class C +type Cap = C @retains(*) + +def test1() = + val x: Cap = C() + val y = () => { x; () } + val z = y + z: (() => Unit) // error + +def test2() = + val x: Cap = C() + def y = () => { x; () } + def z = y + z: (() => Unit) // error + +def test3() = + val x: Cap = C() + def y = () => { x; () } + val z = y + z: (() => Unit) // error + +def test4() = + val x: Cap = C() + val y = () => { x; () } + def z = y + z: (() => Unit) // error diff --git a/tests/neg-custom-args/captures/cc1.scala b/tests/neg-custom-args/captures/cc1.scala new file mode 100644 index 000000000000..ebd983c58fe9 --- /dev/null +++ b/tests/neg-custom-args/captures/cc1.scala @@ -0,0 +1,4 @@ +object Test: + + def f[A <: Matchable @retains(*)](x: A): Matchable = x // error + diff --git a/tests/neg-custom-args/captures/classes.scala b/tests/neg-custom-args/captures/classes.scala new file mode 100644 index 000000000000..b87d21913d4e --- /dev/null +++ b/tests/neg-custom-args/captures/classes.scala @@ -0,0 +1,12 @@ +class B +type Cap = {*} B +class C0(n: Cap) // error: class parameter must be a `val`. + +class C(val n: Cap): + def foo(): {n} B = n + +def test(x: Cap, y: Cap) = + val c0 = C(x) + val c1: C = c0 // error + val c2 = if ??? then C(x) else /*identity*/(C(y)) // TODO: uncomment + val c3: {x} C { val n: {x, y} B } = c2 // error diff --git a/tests/neg-custom-args/captures/io.scala b/tests/neg-custom-args/captures/io.scala new file mode 100644 index 000000000000..17c22a2111e4 --- /dev/null +++ b/tests/neg-custom-args/captures/io.scala @@ -0,0 +1,22 @@ +sealed trait IO: + def puts(msg: Any): Unit = println(msg) + +def test1 = + val IO : IO @retains(*) = new IO {} + def foo = {IO; IO.puts("hello") } + val x : () => Unit = () => foo // error: Found: (() => Unit) retains IO; Required: () => Unit + +def test2 = + val IO : IO @retains(*) = new IO {} + def puts(msg: Any, io: IO @retains(*)) = println(msg) + def foo() = puts("hello", IO) + val x : () => Unit = () => foo() // error: Found: (() => Unit) retains IO; Required: () => Unit + +type Capability[T] = T @retains(*) + +def test3 = + val IO : Capability[IO] = new IO {} + def puts(msg: Any, io: Capability[IO]) = println(msg) + def foo() = puts("hello", IO) + val x : () => Unit = () => foo() // error: Found: (() => Unit) retains IO; Required: () => Unit + diff --git a/tests/neg-custom-args/captures/lazylist.check b/tests/neg-custom-args/captures/lazylist.check new file mode 100644 index 000000000000..3a80de9bdf16 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylist.check @@ -0,0 +1,42 @@ +-- [E163] Declaration Error: tests/neg-custom-args/captures/lazylist.scala:22:6 ---------------------------------------- +22 | def tail: {*} LazyList[Nothing] = ??? // error overriding + | ^ + | error overriding method tail in class LazyList of type => lazylists.LazyList[Nothing]; + | method tail of type => {*} lazylists.LazyList[Nothing] has incompatible type + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:35:29 ------------------------------------- +35 | val ref1c: LazyList[Int] = ref1 // error + | ^^^^ + | Found: (ref1 : {cap1} lazylists.LazyCons[Int]{xs: {cap1} () => {*} lazylists.LazyList[Int]}) + | Required: lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:37:36 ------------------------------------- +37 | val ref2c: {ref1} LazyList[Int] = ref2 // error + | ^^^^ + | Found: (ref2 : {cap2, ref1} lazylists.LazyList[Int]) + | Required: {ref1} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:39:36 ------------------------------------- +39 | val ref3c: {cap2} LazyList[Int] = ref3 // error + | ^^^^ + | Found: (ref3 : {cap2, ref1} lazylists.LazyList[Int]) + | Required: {cap2} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazylist.scala:41:48 ------------------------------------- +41 | val ref4c: {cap1, ref3, cap3} LazyList[Int] = ref4 // error + | ^^^^ + | Found: (ref4 : {cap3, cap2, ref1, cap1} lazylists.LazyList[Int]) + | Required: {cap1, ref3, cap3} lazylists.LazyList[Int] + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/lazylist.scala:17:6 ----------------------------------------------------------- +17 | def tail = xs() // error: cannot have an inferred type + | ^^^^^^^^^^^^^^^ + | Non-local method tail cannot have an inferred result type + | {*} lazylists.LazyList[T] + | with non-empty capture set {*}. + | The type needs to be declared explicitly. diff --git a/tests/neg-custom-args/captures/lazylist.scala b/tests/neg-custom-args/captures/lazylist.scala new file mode 100644 index 000000000000..f7be43e8dc27 --- /dev/null +++ b/tests/neg-custom-args/captures/lazylist.scala @@ -0,0 +1,41 @@ +package lazylists + +abstract class LazyList[+T]: + this: ({*} LazyList[T]) => + + def isEmpty: Boolean + def head: T + def tail: LazyList[T] + + def map[U](f: {*} T => U): {f, this} LazyList[U] = + if isEmpty then LazyNil + else LazyCons(f(head), () => tail.map(f)) + +class LazyCons[+T](val x: T, val xs: {*} () => {*} LazyList[T]) extends LazyList[T]: + def isEmpty = false + def head = x + def tail = xs() // error: cannot have an inferred type + +object LazyNil extends LazyList[Nothing]: + def isEmpty = true + def head = ??? + def tail: {*} LazyList[Nothing] = ??? // error overriding + +def map[A, B](xs: {*} LazyList[A], f: {*} A => B): {f, xs} LazyList[B] = + xs.map(f) + +class CC +type Cap = {*} CC + +def test(cap1: Cap, cap2: Cap, cap3: Cap) = + def f[T](x: LazyList[T]): LazyList[T] = if cap1 == cap1 then x else LazyNil + def g(x: Int) = if cap2 == cap2 then x else 0 + def h(x: Int) = if cap3 == cap3 then x else 0 + val ref1 = LazyCons(1, () => f(LazyNil)) + val ref1c: LazyList[Int] = ref1 // error + val ref2 = map(ref1, g) + val ref2c: {ref1} LazyList[Int] = ref2 // error + val ref3 = ref1.map(g) + val ref3c: {cap2} LazyList[Int] = ref3 // error + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(h) + val ref4c: {cap1, ref3, cap3} LazyList[Int] = ref4 // error diff --git a/tests/neg-custom-args/captures/lazyref.check b/tests/neg-custom-args/captures/lazyref.check new file mode 100644 index 000000000000..2affed020dec --- /dev/null +++ b/tests/neg-custom-args/captures/lazyref.check @@ -0,0 +1,28 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:19:28 -------------------------------------- +19 | val ref1c: LazyRef[Int] = ref1 // error + | ^^^^ + | Found: (ref1 : {cap1} LazyRef[Int]{elem: {cap1} () => Int}) + | Required: LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:21:35 -------------------------------------- +21 | val ref2c: {cap2} LazyRef[Int] = ref2 // error + | ^^^^ + | Found: (ref2 : {cap2, ref1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {cap2} LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:23:35 -------------------------------------- +23 | val ref3c: {ref1} LazyRef[Int] = ref3 // error + | ^^^^ + | Found: (ref3 : {cap2, ref1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {ref1} LazyRef[Int] + +longer explanation available when compiling with `-explain` +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/lazyref.scala:25:35 -------------------------------------- +25 | val ref4c: {cap1} LazyRef[Int] = ref4 // error + | ^^^^ + | Found: (ref4 : {cap2, cap1} LazyRef[Int]{elem: {*} () => Int}) + | Required: {cap1} LazyRef[Int] + +longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/lazyref.scala b/tests/neg-custom-args/captures/lazyref.scala new file mode 100644 index 000000000000..1002f9685675 --- /dev/null +++ b/tests/neg-custom-args/captures/lazyref.scala @@ -0,0 +1,25 @@ +class CC +type Cap = {*} CC + +class LazyRef[T](val elem: {*} () => T): + val get = elem + def map[U](f: {*} T => U): {f, this} LazyRef[U] = + new LazyRef(() => f(elem())) + +def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] = + new LazyRef(() => f(ref.elem())) + +def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] = + (ref1, f1) => map[A, B](ref1, f1) + +def test(cap1: Cap, cap2: Cap) = + def f(x: Int) = if cap1 == cap1 then x else 0 + def g(x: Int) = if cap2 == cap2 then x else 0 + val ref1 = LazyRef(() => f(0)) + val ref1c: LazyRef[Int] = ref1 // error + val ref2 = map(ref1, g) + val ref2c: {cap2} LazyRef[Int] = ref2 // error + val ref3 = ref1.map(g) + val ref3c: {ref1} LazyRef[Int] = ref3 // error + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(g) + val ref4c: {cap1} LazyRef[Int] = ref4 // error diff --git a/tests/neg-custom-args/captures/try.check b/tests/neg-custom-args/captures/try.check new file mode 100644 index 000000000000..bd95835c6525 --- /dev/null +++ b/tests/neg-custom-args/captures/try.check @@ -0,0 +1,25 @@ +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:28:43 ------------------------------------------ +28 | val b = handle[Exception, () => Nothing] { // error + | ^ + | Found: ? (x: CanThrow[Exception]) => {x} () => ? Nothing + | Required: CanThrow[Exception] => () => Nothing +29 | (x: CanThrow[Exception]) => () => raise(new Exception)(using x) +30 | } { + +longer explanation available when compiling with `-explain` +-- Error: tests/neg-custom-args/captures/try.scala:22:28 --------------------------------------------------------------- +22 | val a = handle[Exception, CanThrow[Exception]] { // error + | ^^^^^^^^^^^^^^^^^^^ + | type argument is not allowed to capture the universal capability (* : Any) +-- Error: tests/neg-custom-args/captures/try.scala:34:11 --------------------------------------------------------------- +34 | val xx = handle { // error + | ^^^^^^ + | inferred type argument {*} () => Int is not allowed to capture the universal capability (* : Any) + | + | The inferred arguments are: [? Exception, {*} () => Int] +-- Error: tests/neg-custom-args/captures/try.scala:46:13 --------------------------------------------------------------- +46 |val global = handle { // error + | ^^^^^^ + | inferred type argument {*} () => Int is not allowed to capture the universal capability (* : Any) + | + | The inferred arguments are: [? Exception, {*} () => Int] diff --git a/tests/neg-custom-args/captures/try.scala b/tests/neg-custom-args/captures/try.scala new file mode 100644 index 000000000000..804a16192be0 --- /dev/null +++ b/tests/neg-custom-args/captures/try.scala @@ -0,0 +1,53 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) +type Top = Any @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R <: Top](op: CanThrow[E] => R)(handler: E => R): R = + val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +def test = + val a = handle[Exception, CanThrow[Exception]] { // error + (x: CanThrow[Exception]) => x + }{ + (ex: Exception) => ??? + } + + val b = handle[Exception, () => Nothing] { // error + (x: CanThrow[Exception]) => () => raise(new Exception)(using x) + } { + (ex: Exception) => ??? + } + + val xx = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 + } { + (ex: Exception) => () => 22 + } + val yy = xx :: Nil + yy // OK + + +val global = handle { // error + (x: CanThrow[Exception]) => + () => + raise(new Exception)(using x) + 22 +} { + (ex: Exception) => () => 22 +} \ No newline at end of file diff --git a/tests/neg-custom-args/captures/try3.scala b/tests/neg-custom-args/captures/try3.scala new file mode 100644 index 000000000000..4fbb980b9e03 --- /dev/null +++ b/tests/neg-custom-args/captures/try3.scala @@ -0,0 +1,27 @@ +import java.io.IOException + +class CT[E] +type CanThrow[E] = {*} CT[E] +type Top = {*} Any + +def handle[E <: Exception, T <: Top](op: CanThrow[E] ?=> T)(handler: E => T): T = + val x: CanThrow[E] = ??? + try op(using x) + catch case ex: E => handler(ex) + +def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing = + throw ex + +@main def Test: Int = + def f(a: Boolean) = + handle { // error + if !a then raise(IOException()) + (b: Boolean) => + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => -1 + } + val g = f(true) + g(false) // would raise an uncaught exception + f(true)(false) // would raise an uncaught exception diff --git a/tests/neg/multiLineOps.scala b/tests/neg/multiLineOps.scala index 8499cc9fe710..08a0a3925fd1 100644 --- a/tests/neg/multiLineOps.scala +++ b/tests/neg/multiLineOps.scala @@ -5,7 +5,7 @@ val x = 1 val b1 = { 22 * 22 // ok - */*one more*/22 // error: end of statement expected // error: not found: * + */*one more*/22 // error: end of statement expected } val b2: Boolean = { diff --git a/tests/neg/polymorphic-functions1.check b/tests/neg/polymorphic-functions1.check new file mode 100644 index 000000000000..86492e96dab5 --- /dev/null +++ b/tests/neg/polymorphic-functions1.check @@ -0,0 +1,7 @@ +-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 --------------------------------------------- +1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error + | ^ + | Found: [T] => (x: Int) => Int + | Required: [T] => (x: T) => x.type + +longer explanation available when compiling with `-explain` diff --git a/tests/neg/polymorphic-functions1.scala b/tests/neg/polymorphic-functions1.scala new file mode 100644 index 000000000000..de887f3b8c50 --- /dev/null +++ b/tests/neg/polymorphic-functions1.scala @@ -0,0 +1 @@ +val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error diff --git a/tests/pos-custom-args/captures/bounded.scala b/tests/pos-custom-args/captures/bounded.scala new file mode 100644 index 000000000000..fad0b50c2137 --- /dev/null +++ b/tests/pos-custom-args/captures/bounded.scala @@ -0,0 +1,14 @@ +class CC +type Cap = {*} CC + +def test(c: Cap) = + class B[X <: {c} Object](x: X): + def elem = x + def lateElem = () => x + + def f(x: Int): Int = if c == c then x else 0 + val b = new B(f) + val r1 = b.elem + val r1c: {c} Int => Int = r1 + val r2 = b.lateElem + val r2c: {c} () => {c} Int => Int = r2 \ No newline at end of file diff --git a/tests/pos-custom-args/captures/boxmap-paper.scala b/tests/pos-custom-args/captures/boxmap-paper.scala new file mode 100644 index 000000000000..ed8c648526d1 --- /dev/null +++ b/tests/pos-custom-args/captures/boxmap-paper.scala @@ -0,0 +1,38 @@ +infix type ==> [A, B] = {*} (A => B) + +type Cell[+T] = [K] => (T ==> K) => K + +def cell[T](x: T): Cell[T] = + [K] => (k: T ==> K) => k(x) + +def get[T](c: Cell[T]): T = c[T](identity) + +def map[A, B](c: Cell[A])(f: A ==> B): Cell[B] + = c[Cell[B]]((x: A) => cell(f(x))) + +def pureMap[A, B](c: Cell[A])(f: A => B): Cell[B] + = c[Cell[B]]((x: A) => cell(f(x))) + +def lazyMap[A, B](c: Cell[A])(f: A ==> B): {f} () => Cell[B] + = () => c[Cell[B]]((x: A) => cell(f(x))) + +trait IO: + def print(s: String): Unit + +def test(io: {*} IO) = + + val loggedOne: {io} () => Int = () => { io.print("1"); 1 } + + val c: Cell[{io} () => Int] + = cell[{io} () => Int](loggedOne) + + val g = (f: {io} () => Int) => + val x = f(); io.print(" + ") + val y = f(); io.print(s" = ${x + y}") + + val r = lazyMap[{io} () => Int, Unit](c)(f => g(f)) + val r2 = lazyMap[{io} () => Int, Unit](c)(g) + val r3 = lazyMap(c)(g) + val _ = r() + val _ = r2() + val _ = r3() diff --git a/tests/pos-custom-args/captures/boxmap.scala b/tests/pos-custom-args/captures/boxmap.scala new file mode 100644 index 000000000000..a0dcade2b179 --- /dev/null +++ b/tests/pos-custom-args/captures/boxmap.scala @@ -0,0 +1,20 @@ +type Top = Any @retains(*) + +infix type ==> [A, B] = (A => B) @retains(*) + +type Box[+T <: Top] = ([K <: Top] => (T ==> K) => K) + +def box[T <: Top](x: T): Box[T] = + [K <: Top] => (k: T ==> K) => k(x) + +def map[A <: Top, B <: Top](b: Box[A])(f: A ==> B): Box[B] = + b[Box[B]]((x: A) => box(f(x))) + +def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B): (() => Box[B]) @retains(f) = + () => b[Box[B]]((x: A) => box(f(x))) + +def test[A <: Top, B <: Top] = + def lazymap[A <: Top, B <: Top](b: Box[A])(f: A ==> B) = + () => b[Box[B]]((x: A) => box(f(x))) + val x: (b: Box[A]) => (f: A ==> B) => (() => Box[B]) @retains(f) = lazymap[A, B] + () diff --git a/tests/pos-custom-args/captures/byname.scala b/tests/pos-custom-args/captures/byname.scala new file mode 100644 index 000000000000..917154079b36 --- /dev/null +++ b/tests/pos-custom-args/captures/byname.scala @@ -0,0 +1,10 @@ +class CC +type Cap = {*} CC + +class I + +def test(cap1: Cap, cap2: Cap): {cap1} I = + def f() = if cap1 == cap1 then I() else I() + def h(x: => {cap1} I) = x + h(f()) + diff --git a/tests/pos-custom-args/captures/capt-depfun.scala b/tests/pos-custom-args/captures/capt-depfun.scala new file mode 100644 index 000000000000..6b99eff32692 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-depfun.scala @@ -0,0 +1,18 @@ +class C +type Cap = C @retains(*) + +type T = (x: Cap) => String @retains(x) + +val aa: ((x: Cap) => String @retains(x)) = (x: Cap) => "" + +def f(y: Cap, z: Cap): String @retains(*) = + val a: ((x: Cap) => String @retains(x)) = (x: Cap) => "" + val b = a(y) + val c: String @retains(y) = b + def g(): C @retains(y, z) = ??? + val d = a(g()) + + val ac: ((x: Cap) => String @retains(x) => String @retains(x)) = ??? + val bc: (({y} String) => {y} String) = ac(y) + val dc: (String => {y, z} String) = ac(g()) + c diff --git a/tests/pos-custom-args/captures/capt-depfun2.scala b/tests/pos-custom-args/captures/capt-depfun2.scala new file mode 100644 index 000000000000..17f98b4a1554 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-depfun2.scala @@ -0,0 +1,8 @@ +class C +type Cap = C @retains(*) + +def f(y: Cap, z: Cap) = + def g(): C @retains(y, z) = ??? + val ac: ((x: Cap) => Array[String @retains(x)]) = ??? + val dc: Array[? >: String <: {y, z} String] = ac(g()) // needs to be inferred + val ec = ac(y) diff --git a/tests/pos-custom-args/captures/capt-test.scala b/tests/pos-custom-args/captures/capt-test.scala new file mode 100644 index 000000000000..f40bd2ff1746 --- /dev/null +++ b/tests/pos-custom-args/captures/capt-test.scala @@ -0,0 +1,35 @@ +abstract class LIST[+T]: + def isEmpty: Boolean + def head: T + def tail: LIST[T] + def map[U](f: {*} T => U): LIST[U] = + if isEmpty then NIL + else CONS(f(head), tail.map(f)) + +class CONS[+T](x: T, xs: LIST[T]) extends LIST[T]: + def isEmpty = false + def head = x + def tail = xs +object NIL extends LIST[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + +def map[A, B](f: {*} A => B)(xs: LIST[A]): LIST[B] = + xs.map(f) + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val y = f + val ys = CONS(y, NIL) + val zs = + val z = g + CONS(z, ys) + val zsc: LIST[{d, y} Cap => Unit] = zs + + val a4 = zs.map(identity) + val a4c: LIST[{d, y} Cap => Unit] = a4 diff --git a/tests/pos-custom-args/captures/capt0.scala b/tests/pos-custom-args/captures/capt0.scala new file mode 100644 index 000000000000..c8ff8a102856 --- /dev/null +++ b/tests/pos-custom-args/captures/capt0.scala @@ -0,0 +1,7 @@ +object Test: + + def test() = + val x: {*} Any = "abc" + val y: Object @scala.retains(x) = ??? + val z: Object @scala.retains(x, *) = y: Object @scala.retains(x) + diff --git a/tests/pos-custom-args/captures/capt1.scala b/tests/pos-custom-args/captures/capt1.scala new file mode 100644 index 000000000000..14c0855544d4 --- /dev/null +++ b/tests/pos-custom-args/captures/capt1.scala @@ -0,0 +1,27 @@ +class C +type Cap = {*} C +def f1(c: Cap): {c} () => c.type = () => c // ok + +def f2: Int = + val g: {*} Boolean => Int = ??? + val x = g(true) + x + +def f3: Int = + def g: {*} Boolean => Int = ??? + def h = g + val x = g.apply(true) + x + +def foo() = + val x: {*} C = ??? + val y: {x} C = x + val x2: {x} () => C = ??? + val y2: {x} () => {x} C = x2 + + val z1: {*} () => Cap = f1(x) + def h[X](a: X)(b: X) = a + + val z2 = + if x == null then () => x else () => C() + x \ No newline at end of file diff --git a/tests/pos-custom-args/captures/capt2.scala b/tests/pos-custom-args/captures/capt2.scala new file mode 100644 index 000000000000..e3d4cd67b30c --- /dev/null +++ b/tests/pos-custom-args/captures/capt2.scala @@ -0,0 +1,20 @@ +import scala.retains +class C +type Cap = C @retains(*) + +def test1() = + val y: {*} String = "" + def x: Object @retains(y) = y + +def test2() = + val x: Cap = C() + val y = () => { x; () } + def z: (() => Unit) @retains(x) = y + z: (() => Unit) @retains(x) + def z2: (() => Unit) @retains(y) = y + z2: (() => Unit) @retains(y) + val p: {*} () => String = () => "abc" + val q: {p} C = ??? + p: ({p} () => String) + + diff --git a/tests/pos-custom-args/captures/cc-expand.scala b/tests/pos-custom-args/captures/cc-expand.scala new file mode 100644 index 000000000000..eedc95554b17 --- /dev/null +++ b/tests/pos-custom-args/captures/cc-expand.scala @@ -0,0 +1,21 @@ +object Test: + + class A + class B + class C + class CTC + type CT = CTC @retains(*) + + def test(ct: CT, dt: CT) = + + def x0: A => {ct} B = ??? + + def x1: A => B @retains(ct) = ??? + def x2: A => B => C @retains(ct) = ??? + def x3: A => () => B => C @retains(ct) = ??? + + def x4: (x: A @retains(ct)) => B => C = ??? + + def x5: A => (x: B @retains(ct)) => () => C @retains(dt) = ??? + def x6: A => (x: B @retains(ct)) => (() => C @retains(dt)) @retains(x, dt) = ??? + def x7: A => (x: B @retains(ct)) => (() => C @retains(dt)) @retains(x) = ??? \ No newline at end of file diff --git a/tests/pos-custom-args/captures/classes.scala b/tests/pos-custom-args/captures/classes.scala new file mode 100644 index 000000000000..f3d6e44b27ca --- /dev/null +++ b/tests/pos-custom-args/captures/classes.scala @@ -0,0 +1,34 @@ +class B +type Cap = {*} B +class C(val n: Cap): + this: ({n} C) => + def foo(): {n} B = n + + +def test(x: Cap, y: Cap, z: Cap) = + val c0 = C(x) + val c1: {x} C {val n: {x} B} = c0 + val d = c1.foo() + d: ({x} B) + + val c2 = if ??? then C(x) else C(y) + val c2a = identity(c2) + val c3: {x, y} C { val n: {x, y} B } = c2 + val d1 = c3.foo() + d1: B @retains(x, y) + + class Local: + + def this(a: Cap) = + this() + if a == z then println("?") + + val f = y + def foo = x + end Local + + val l = Local() + val l1: {x, y} Local = l + val l2 = Local(x) + val l3: {x, y, z} Local = l2 + diff --git a/tests/pos-custom-args/captures/iterators.scala b/tests/pos-custom-args/captures/iterators.scala new file mode 100644 index 000000000000..dd1067bcdc72 --- /dev/null +++ b/tests/pos-custom-args/captures/iterators.scala @@ -0,0 +1,23 @@ +package cctest + +abstract class Iterator[T]: + thisIterator => + + def hasNext: Boolean + def next: T + def map(f: {*} T => T): {f} Iterator[T] = new Iterator: + def hasNext = thisIterator.hasNext + def next = f(thisIterator.next) +end Iterator + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap, e: Cap) = + val it = new Iterator[Int]: + private var ctr = 0 + def hasNext = ctr < 10 + def next = { ctr += 1; ctr } + + def f(x: Int): Int = if c == d then x else 10 + val it2 = it.map(f) diff --git a/tests/pos-custom-args/captures/lazyref.scala b/tests/pos-custom-args/captures/lazyref.scala new file mode 100644 index 000000000000..39748b00506b --- /dev/null +++ b/tests/pos-custom-args/captures/lazyref.scala @@ -0,0 +1,25 @@ +class CC +type Cap = {*} CC + +class LazyRef[T](val elem: {*} () => T): + val get = elem + def map[U](f: {*} T => U): {f, this} LazyRef[U] = + new LazyRef(() => f(elem())) + +def map[A, B](ref: {*} LazyRef[A], f: {*} A => B): {f, ref} LazyRef[B] = + new LazyRef(() => f(ref.elem())) + +def mapc[A, B]: (ref: {*} LazyRef[A], f: {*} A => B) => {f, ref} LazyRef[B] = + (ref1, f1) => map[A, B](ref1, f1) + +def test(cap1: Cap, cap2: Cap) = + def f(x: Int) = if cap1 == cap1 then x else 0 + def g(x: Int) = if cap2 == cap2 then x else 0 + val ref1 = LazyRef(() => f(0)) + val ref1c: {cap1} LazyRef[Int] = ref1 + val ref2 = map(ref1, g) + val ref2c: {cap2, ref1} LazyRef[Int] = ref2 + val ref3 = ref1.map(g) + val ref3c: {cap2, ref1} LazyRef[Int] = ref3 + val ref4 = (if cap1 == cap2 then ref1 else ref2).map(g) + val ref4c: {cap1, cap2} LazyRef[Int] = ref4 diff --git a/tests/pos-custom-args/captures/list-encoding.scala b/tests/pos-custom-args/captures/list-encoding.scala new file mode 100644 index 000000000000..74bc8bd2b099 --- /dev/null +++ b/tests/pos-custom-args/captures/list-encoding.scala @@ -0,0 +1,23 @@ +package listEncoding + +class Cap + +type Op[T, C] = + {*} (v: T) => {*} (s: C) => C + +type List[T] = + [C] => (op: Op[T, C]) => {op} (s: C) => C + +def nil[T]: List[T] = + [C] => (op: Op[T, C]) => (s: C) => s + +def cons[T](hd: T, tl: List[T]): List[T] = + [C] => (op: Op[T, C]) => (s: C) => op(hd)(tl(op)(s)) + +def foo(c: {*} Cap) = + def f(x: String @retains(c), y: String @retains(c)) = + cons(x, cons(y, nil)) + def g(x: String @retains(c), y: Any) = + cons(x, cons(y, nil)) + def h(x: String, y: Any @retains(c)) = + cons(x, cons(y, nil)) diff --git a/tests/pos-custom-args/captures/lists.scala b/tests/pos-custom-args/captures/lists.scala new file mode 100644 index 000000000000..139f885ec87a --- /dev/null +++ b/tests/pos-custom-args/captures/lists.scala @@ -0,0 +1,91 @@ +abstract class LIST[+T]: + def isEmpty: Boolean + def head: T + def tail: LIST[T] + def map[U](f: {*} T => U): LIST[U] = + if isEmpty then NIL + else CONS(f(head), tail.map(f)) + +class CONS[+T](x: T, xs: LIST[T]) extends LIST[T]: + def isEmpty = false + def head = x + def tail = xs +object NIL extends LIST[Nothing]: + def isEmpty = true + def head = ??? + def tail = ??? + +def map[A, B](f: {*} A => B)(xs: LIST[A]): LIST[B] = + xs.map(f) + +class C +type Cap = {*} C + +def test(c: Cap, d: Cap, e: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val y = f + val ys = CONS(y, NIL) + val zs = + val z = g + CONS(z, ys) + val zsc: LIST[{d, y} Cap => Unit] = zs + val z1 = zs.head + val z1c: {y, d} Cap => Unit = z1 + val ys1 = zs.tail + val y1 = ys1.head + + + def m1[A, B] = + (f: {*} A => B) => (xs: LIST[A]) => xs.map(f) + + def m1c: (f: {*} String => Int) => {f} LIST[String] => LIST[Int] = m1[String, Int] + + def m2 = [A, B] => + (f: {*} A => B) => (xs: LIST[A]) => xs.map(f) + + def m2c: [A, B] => (f: {*} A => B) => {f} LIST[A] => LIST[B] = m2 + + def eff[A](x: A) = if x == e then x else x + + val eff2 = [A] => (x: A) => if x == e then x else x + + val a0 = identity[{d, y} Cap => Unit] + val a0c: ({d, y} Cap => Unit) => {d, y} Cap => Unit = a0 + val a1 = zs.map[{d, y} Cap => Unit](a0) + val a1c: LIST[{d, y} Cap => Unit] = a1 + val a2 = zs.map[{d, y} Cap => Unit](identity[{d, y} Cap => Unit]) + val a2c: LIST[{d, y} Cap => Unit] = a2 + val a3 = zs.map(identity[{d, y} Cap => Unit]) + val a3c: LIST[{d, y} Cap => Unit] = a3 + val a4 = zs.map(identity) + val a4c: LIST[{d, c} Cap => Unit] = a4 + val a5 = map[{d, y} Cap => Unit, {d, y} Cap => Unit](identity)(zs) + val a5c: LIST[{d, c} Cap => Unit] = a5 + val a6 = m1[{d, y} Cap => Unit, {d, y} Cap => Unit](identity)(zs) + val a6c: LIST[{d, c} Cap => Unit] = a6 + + val b0 = eff[{d, y} Cap => Unit] + val b0c: {e} ({d, y} Cap => Unit) => {d, y} Cap => Unit = b0 + val b1 = zs.map[{d, y} Cap => Unit](a0) + val b1c: {e} LIST[{d, y} Cap => Unit] = b1 + val b2 = zs.map[{d, y} Cap => Unit](eff[{d, y} Cap => Unit]) + val b2c: {e} LIST[{d, y} Cap => Unit] = b2 + val b3 = zs.map(eff[{d, y} Cap => Unit]) + val b3c: {e} LIST[{d, y} Cap => Unit] = b3 + val b4 = zs.map(eff) + val b4c: {e} LIST[{d, c} Cap => Unit] = b4 + val b5 = map[{d, y} Cap => Unit, {d, y} Cap => Unit](eff)(zs) + val b5c: {e} LIST[{d, c} Cap => Unit] = b5 + val b6 = m1[{d, y} Cap => Unit, {d, y} Cap => Unit](eff)(zs) + val b6c: {e} LIST[{d, c} Cap => Unit] = b6 + + val c0 = eff2[{d, y} Cap => Unit] + val c0c: {e} ({d, y} Cap => Unit) => {d, y} Cap => Unit = c0 + val c1 = zs.map[{d, y} Cap => Unit](a0) + val c1c: {e} LIST[{d, y} Cap => Unit] = c1 + val c2 = zs.map[{d, y} Cap => Unit](eff2[{d, y} Cap => Unit]) + val c2c: {e} LIST[{d, y} Cap => Unit] = c2 + val c3 = zs.map(eff2[{d, y} Cap => Unit]) + val c3c: {e} LIST[{d, y} Cap => Unit] = c3 + diff --git a/tests/pos-custom-args/captures/pairs.scala b/tests/pos-custom-args/captures/pairs.scala new file mode 100644 index 000000000000..4f23a086a075 --- /dev/null +++ b/tests/pos-custom-args/captures/pairs.scala @@ -0,0 +1,33 @@ + +class C +type Cap = {*} C + +object Generic: + + class Pair[+A, +B](x: A, y: B): + def fst: A = x + def snd: B = y + + def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val p = Pair(f, g) + val x1 = p.fst + val x1c: {c} Cap => Unit = x1 + val y1 = p.snd + val y1c: {d} Cap => Unit = y1 + +object Monomorphic: + + class Pair(val x: {*} Cap => Unit, val y: {*} Cap => Unit): + def fst = x + def snd = y + + def test(c: Cap, d: Cap) = + def f(x: Cap): Unit = if c == x then () + def g(x: Cap): Unit = if d == x then () + val p = Pair(f, g) + val x1 = p.fst + val x1c: {c} Cap => Unit = x1 + val y1 = p.snd + val y1c: {d} Cap => Unit = y1 diff --git a/tests/pos-custom-args/captures/try.scala b/tests/pos-custom-args/captures/try.scala new file mode 100644 index 000000000000..a50eeabfb3a3 --- /dev/null +++ b/tests/pos-custom-args/captures/try.scala @@ -0,0 +1,26 @@ +import language.experimental.erasedDefinitions + +class CT[E <: Exception] +type CanThrow[E <: Exception] = CT[E] @retains(*) + +infix type throws[R, E <: Exception] = (erased CanThrow[E]) ?=> R + +class Fail extends Exception + +def raise[E <: Exception](e: E): Nothing throws E = throw e + +def foo(x: Boolean): Int throws Fail = + if x then 1 else raise(Fail()) + +def handle[E <: Exception, R](op: (erased CanThrow[E]) => R)(handler: E => R): R = + erased val x: CanThrow[E] = ??? + try op(x) + catch case ex: E => handler(ex) + +val _ = handle { (erased x) => + if true then + raise(new Exception)(using x) + 22 + else + 11 + } \ No newline at end of file diff --git a/tests/pos-custom-args/captures/try3.scala b/tests/pos-custom-args/captures/try3.scala new file mode 100644 index 000000000000..074517d8a9e5 --- /dev/null +++ b/tests/pos-custom-args/captures/try3.scala @@ -0,0 +1,51 @@ +import language.experimental.erasedDefinitions +import annotation.ability +import java.io.IOException + +class CT[-E] // variance is needed for correct rechecking inference +type CanThrow[E] = {*} CT[E] + +def handle[E <: Exception, T](op: CanThrow[E] ?=> T)(handler: E => T): T = + val x: CanThrow[E] = ??? + try op(using x) + catch case ex: E => handler(ex) + +def raise[E <: Exception](ex: E)(using CanThrow[E]): Nothing = + throw ex + +def test1: Int = + def f(a: Boolean): Boolean => CanThrow[IOException] ?=> Int = + handle { + if !a then raise(IOException()) + (b: Boolean) => (_: CanThrow[IOException]) ?=> + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => (_: CanThrow[IOException]) ?=> -1 + } + handle { + val g = f(true) + g(false) // can raise an exception + f(true)(false) // can raise an exception + } { + ex => -1 + } +/* +def test2: Int = + def f(a: Boolean): Boolean => CanThrow[IOException] ?=> Int = + handle { // error + if !a then raise(IOException()) + (b: Boolean) => + if !b then raise(IOException()) + 0 + } { + ex => (b: Boolean) => -1 + } + handle { + val g = f(true) + g(false) // would raise an uncaught exception + f(true)(false) // would raise an uncaught exception + } { + ex => -1 + } +*/ \ No newline at end of file