Skip to content

Commit 3a2451f

Browse files
committed
feat(CancellationSource): add cooperative cancellation to wait method
1 parent 3f3dd36 commit 3a2451f

File tree

2 files changed

+98
-19
lines changed

2 files changed

+98
-19
lines changed

Sources/AsyncObjects/CancellationSource/CancellationSource.swift

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Foundation
2+
import AsyncAlgorithms
23

34
/// An object that controls cooperative cancellation of multiple registered tasks and linked object registered tasks.
45
///
@@ -36,16 +37,19 @@ public struct CancellationSource: AsyncObject, Cancellable, Loggable {
3637
internal typealias Continuation = GlobalContinuation<Void, Error>
3738
/// The cancellable work with invocation context.
3839
internal typealias WorkItem = (
39-
Cancellable, id: UUID, file: String, function: String, line: UInt
40+
any Cancellable, id: UUID, file: String, function: String, line: UInt
4041
)
4142

4243
/// The lifetime task that is cancelled when
4344
/// `CancellationSource` is cancelled.
4445
@usableFromInline
45-
var lifetime: Task<Void, Error>!
46+
let lifetime: Task<Void, Error>
4647
/// The stream continuation used to register work items
4748
/// for cooperative cancellation.
48-
var pipe: AsyncStream<WorkItem>.Continuation!
49+
let pipe: AsyncStream<WorkItem>.Continuation
50+
/// The channel that controls waiting on the `CancellationSource`.
51+
/// Once `CancellationSource` is cancelled, channel finishes.
52+
let waiter: AsyncChannel<Void>
4953

5054
/// A Boolean value that indicates whether cancellation is already
5155
/// invoked on the source.
@@ -61,24 +65,57 @@ public struct CancellationSource: AsyncObject, Cancellable, Loggable {
6165
///
6266
/// - Returns: The newly created cancellation source.
6367
public init() {
64-
let stream = AsyncStream<WorkItem> { self.pipe = $0 }
65-
self.lifetime = Task.detached {
66-
try await withThrowingTaskGroup(of: Void.self) { group in
67-
for await item in stream {
68-
group.addTask {
69-
try? await waitHandlingCancelation(
70-
for: item.0, associatedId: item.id,
71-
file: item.file,
72-
function: item.function,
73-
line: item.line
74-
)
68+
var continuation: AsyncStream<WorkItem>.Continuation!
69+
let stream = AsyncStream<WorkItem> { continuation = $0 }
70+
let channel = AsyncChannel<Void>()
71+
self.pipe = continuation
72+
self.waiter = channel
73+
74+
func lifetime() -> Task<Void, Error> {
75+
return Task.detached {
76+
await withThrowingTaskGroup(of: Void.self) { group in
77+
for await item in stream {
78+
group.addTask {
79+
try? await waitHandlingCancelation(
80+
for: item.0, associatedId: item.id,
81+
file: item.file,
82+
function: item.function,
83+
line: item.line
84+
)
85+
}
7586
}
87+
88+
group.cancelAll()
7689
}
90+
channel.finish()
91+
}
92+
}
7793

78-
group.cancelAll()
79-
try await group.waitForAll()
94+
#if swift(>=5.8)
95+
if #available(macOS 13.3, iOS 16.4, tvOS 16.4, watchOS 9.4, *) {
96+
self.lifetime = Task.detached {
97+
await withDiscardingTaskGroup { group in
98+
for await item in stream {
99+
group.addTask {
100+
try? await waitHandlingCancelation(
101+
for: item.0, associatedId: item.id,
102+
file: item.file,
103+
function: item.function,
104+
line: item.line
105+
)
106+
}
107+
}
108+
109+
group.cancelAll()
110+
}
111+
channel.finish()
80112
}
113+
} else {
114+
self.lifetime = lifetime()
81115
}
116+
#else
117+
self.lifetime = lifetime()
118+
#endif
82119
}
83120

84121
/// Register cancellable work for cooperative cancellation
@@ -163,11 +200,17 @@ public struct CancellationSource: AsyncObject, Cancellable, Loggable {
163200
file: String = #fileID,
164201
function: String = #function,
165202
line: UInt = #line
166-
) async {
203+
) async throws {
167204
let id = UUID()
168205
log("Waiting", id: id, file: file, function: function, line: line)
169-
let _ = await lifetime.result
170-
log("Completed", id: id, file: file, function: function, line: line)
206+
await waiter.send(())
207+
do {
208+
try Task.checkCancellation()
209+
log("Completed", id: id, file: file, function: function, line: line)
210+
} catch {
211+
log("Cancelled", id: id, file: file, function: function, line: line)
212+
throw error
213+
}
171214
}
172215
}
173216

Tests/AsyncObjectsTests/CancellationSourceTests.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,39 @@ class CancellationSourceInitializationTests: XCTestCase {
184184
XCTAssertTrue(task.isCancelled)
185185
}
186186
}
187+
188+
@MainActor
189+
class CancellationSourceWaitTests: XCTestCase {
190+
191+
func testWithoutCancellation() async throws {
192+
let source = CancellationSource()
193+
let task = Task.detached {
194+
try await Task.sleep(seconds: 10)
195+
XCTFail("Unexpected task progression")
196+
}
197+
source.register(task: task)
198+
do {
199+
try await source.wait(forSeconds: 3)
200+
XCTFail("Unexpected task progression")
201+
} catch is DurationTimeoutError {}
202+
XCTAssertFalse(source.isCancelled)
203+
XCTAssertFalse(task.isCancelled)
204+
}
205+
206+
func testCooperativeCancellation() async throws {
207+
let source = CancellationSource()
208+
Task.detached(cancellationSource: source) {
209+
try await Task.sleep(seconds: 20)
210+
XCTFail("Unexpected task progression")
211+
}
212+
let task = Task.detached {
213+
do {
214+
try await source.wait(forSeconds: 5)
215+
XCTFail("Unexpected task progression")
216+
} catch is CancellationError {}
217+
}
218+
task.cancel()
219+
try await task.value
220+
XCTAssertFalse(source.isCancelled)
221+
}
222+
}

0 commit comments

Comments
 (0)