From ed6f151d3607c67e9f741b391602b8ed7bc60e96 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Fri, 17 Nov 2023 17:02:15 +0100 Subject: [PATCH] Refine `withReachCaptures` --- .../src/dotty/tools/dotc/cc/CaptureOps.scala | 30 +++++++++++++++++-- .../captures/reach-captures.scala | 14 +++++++++ .../captures/i15749a.scala | 2 +- .../pos-custom-args/captures/pair-reach.scala | 12 ++++++++ 4 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 tests/neg-custom-args/captures/reach-captures.scala rename tests/{neg-custom-args => pos-custom-args}/captures/i15749a.scala (87%) create mode 100644 tests/pos-custom-args/captures/pair-reach.scala diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index a9e4ce42087d..6994d35b6a04 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -245,14 +245,40 @@ extension (tp: Type) def withReachCaptures(ref: Type)(using Context): Type = object narrowCaps extends TypeMap: var ok = true + + private var hasImpureParam: Boolean = false + + def withImpureParamIf[T](cond: Boolean)(op: => T): T = + val saved = hasImpureParam + try + hasImpureParam ||= cond + op + finally hasImpureParam = saved + + def mapFunction(args: List[Type], restpe: Type, reconstruct: (List[Type], Type) => Type): Type = + def isImpureParam(param: Type): Boolean = param match + case CapturingType(_, cs) if cs.isUniversal => true + case _ => false + val args1 = atVariance(-variance): + args.mapConserve(this) + val restpe1 = withImpureParamIf(args.exists(isImpureParam) && variance > 0): + this(restpe) + reconstruct(args1, restpe1) + def apply(t: Type) = t.dealias match case t1 @ CapturingType(p, cs) if cs.isUniversal => if variance > 0 then + if hasImpureParam then + ok = false t1.derivedCapturingType(apply(p), ref.reach.singletonCaptureSet) else - ok = false - t + t1.derivedCapturingType(apply(p), cs) case _ => t match + case t: TermLambda => mapFunction(t.paramInfos, t.resultType, derivedLambdaType(t)(_, _)) + case defn.FunctionOf(args, restpe, isCtx) => + mapFunction(args, restpe, (args1, restpe1) => + if (args1 eq args) && (restpe1 eq restpe) then t + else defn.FunctionOf(args1, restpe1, isCtx)) case t @ CapturingType(p, cs) => t.derivedCapturingType(apply(p), cs) // don't map capture set variables case t => diff --git a/tests/neg-custom-args/captures/reach-captures.scala b/tests/neg-custom-args/captures/reach-captures.scala new file mode 100644 index 000000000000..8f16a404c4bd --- /dev/null +++ b/tests/neg-custom-args/captures/reach-captures.scala @@ -0,0 +1,14 @@ +import language.experimental.captureChecking +trait IO +def test1(): Unit = + val id: IO^ -> IO^ = x => x + val id1: IO^ -> IO^{id*} = id // error + +def test2(): Unit = + val f: (IO^ => Unit) => Unit = ??? + val f1: (IO^{f*} => Unit) ->{f*} Unit = f // ok + +def test3(): Unit = + val f: IO^ -> (IO^ => Unit) => Unit = ??? + val f1: IO^ -> (IO^{f*} => Unit) => Unit = f // error + val f2: IO^ -> (IO^ => Unit) ->{f*} Unit = f // error diff --git a/tests/neg-custom-args/captures/i15749a.scala b/tests/pos-custom-args/captures/i15749a.scala similarity index 87% rename from tests/neg-custom-args/captures/i15749a.scala rename to tests/pos-custom-args/captures/i15749a.scala index 0158928f4e39..23cfc8a27577 100644 --- a/tests/neg-custom-args/captures/i15749a.scala +++ b/tests/pos-custom-args/captures/i15749a.scala @@ -19,4 +19,4 @@ def test = def forceWrapper[A](mx: Wrapper[Unit ->{cap} A]): Wrapper[A] = // Γ ⊢ mx: Wrapper[□ {cap} Unit => A] // `force` should be typed as ∀(□ {cap} Unit -> A) A, but it can not - strictMap[Unit ->{mx*} A, A](mx)(t => force[A](t)) // error // should work + strictMap[Unit ->{mx*} A, A](mx)(t => force[A](t)) diff --git a/tests/pos-custom-args/captures/pair-reach.scala b/tests/pos-custom-args/captures/pair-reach.scala new file mode 100644 index 000000000000..9b0cd34c4db0 --- /dev/null +++ b/tests/pos-custom-args/captures/pair-reach.scala @@ -0,0 +1,12 @@ +import language.experimental.captureChecking + +trait IO + +type Pair[+T, +U] = [R] -> (op: (T, U) => R) -> R +def cons[T, U](a: T, b: U): Pair[T, U] = [R] => op => op(a, b) +def car[T, U](p: Pair[T, U]): T = p((a, b) => a) +def cdr[T, U](p: Pair[T, U]): U = p((a, b) => b) + +def foo(p: Pair[IO^, IO^]): Unit = + var x: IO^{p*} = null + x = car[IO^{p*}, IO^{p*}](p)