Skip to content

Commit 6a8aec2

Browse files
authored
#131 Cancel sends as if they were terminal events on task cancellation (#132)
1 parent 9d03d90 commit 6a8aec2

File tree

3 files changed

+146
-58
lines changed

3 files changed

+146
-58
lines changed

Sources/AsyncAlgorithms/AsyncChannel.swift

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -155,40 +155,69 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
155155
}
156156
}
157157

158+
func cancelSend() {
159+
let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>], Set<Awaiting>) in
160+
if state.terminal {
161+
return ([], [])
162+
}
163+
state.terminal = true
164+
switch state.emission {
165+
case .idle:
166+
return ([], [])
167+
case .pending(let nexts):
168+
state.emission = .idle
169+
return (nexts, [])
170+
case .awaiting(let nexts):
171+
state.emission = .idle
172+
return ([], nexts)
173+
}
174+
}
175+
for send in sends {
176+
send.resume(returning: nil)
177+
}
178+
for next in nexts {
179+
next.continuation?.resume(returning: nil)
180+
}
181+
}
182+
158183
func _send(_ result: Result<Element?, Never>) async {
159-
let continuation: UnsafeContinuation<Element?, Never>? = await withUnsafeContinuation { continuation in
160-
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
161-
if state.terminal {
162-
return UnsafeResumption(continuation: continuation, success: nil)
163-
}
164-
switch result {
165-
case .success(let value):
166-
if value == nil {
184+
await withTaskCancellationHandler {
185+
cancelSend()
186+
} operation: {
187+
let continuation: UnsafeContinuation<Element?, Never>? = await withUnsafeContinuation { continuation in
188+
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
189+
if state.terminal {
190+
return UnsafeResumption(continuation: continuation, success: nil)
191+
}
192+
switch result {
193+
case .success(let value):
194+
if value == nil {
195+
state.terminal = true
196+
}
197+
case .failure:
167198
state.terminal = true
168199
}
169-
case .failure:
170-
state.terminal = true
171-
}
172-
switch state.emission {
173-
case .idle:
174-
state.emission = .pending([continuation])
175-
return nil
176-
case .pending(var sends):
177-
sends.append(continuation)
178-
state.emission = .pending(sends)
179-
return nil
180-
case .awaiting(var nexts):
181-
let next = nexts.removeFirst().continuation
182-
if nexts.count == 0 {
183-
state.emission = .idle
184-
} else {
185-
state.emission = .awaiting(nexts)
200+
switch state.emission {
201+
case .idle:
202+
state.emission = .pending([continuation])
203+
return nil
204+
case .pending(var sends):
205+
sends.append(continuation)
206+
state.emission = .pending(sends)
207+
return nil
208+
case .awaiting(var nexts):
209+
let next = nexts.removeFirst().continuation
210+
if nexts.count == 0 {
211+
state.emission = .idle
212+
} else {
213+
state.emission = .awaiting(nexts)
214+
}
215+
return UnsafeResumption(continuation: continuation, success: next)
186216
}
187-
return UnsafeResumption(continuation: continuation, success: next)
188-
}
189-
}?.resume()
217+
}?.resume()
218+
}
219+
continuation?.resume(with: result)
190220
}
191-
continuation?.resume(with: result)
192221
}
193222

194223
/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made.

Sources/AsyncAlgorithms/AsyncThrowingChannel.swift

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -153,40 +153,69 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
153153
}
154154
}
155155

156+
func cancelSend() {
157+
let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>], Set<Awaiting>) in
158+
if state.terminal {
159+
return ([], [])
160+
}
161+
state.terminal = true
162+
switch state.emission {
163+
case .idle:
164+
return ([], [])
165+
case .pending(let nexts):
166+
state.emission = .idle
167+
return (nexts, [])
168+
case .awaiting(let nexts):
169+
state.emission = .idle
170+
return ([], nexts)
171+
}
172+
}
173+
for send in sends {
174+
send.resume(returning: nil)
175+
}
176+
for next in nexts {
177+
next.continuation?.resume(returning: nil)
178+
}
179+
}
180+
156181
func _send(_ result: Result<Element?, Error>) async {
157-
let continuation: UnsafeContinuation<Element?, Error>? = await withUnsafeContinuation { continuation in
158-
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
159-
if state.terminal {
160-
return UnsafeResumption(continuation: continuation, success: nil)
161-
}
162-
switch result {
163-
case .success(let value):
164-
if value == nil {
182+
await withTaskCancellationHandler {
183+
cancelSend()
184+
} operation: {
185+
let continuation: UnsafeContinuation<Element?, Error>? = await withUnsafeContinuation { continuation in
186+
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
187+
if state.terminal {
188+
return UnsafeResumption(continuation: continuation, success: nil)
189+
}
190+
switch result {
191+
case .success(let value):
192+
if value == nil {
193+
state.terminal = true
194+
}
195+
case .failure:
165196
state.terminal = true
166197
}
167-
case .failure:
168-
state.terminal = true
169-
}
170-
switch state.emission {
171-
case .idle:
172-
state.emission = .pending([continuation])
173-
return nil
174-
case .pending(var sends):
175-
sends.append(continuation)
176-
state.emission = .pending(sends)
177-
return nil
178-
case .awaiting(var nexts):
179-
let next = nexts.removeFirst().continuation
180-
if nexts.count == 0 {
181-
state.emission = .idle
182-
} else {
183-
state.emission = .awaiting(nexts)
198+
switch state.emission {
199+
case .idle:
200+
state.emission = .pending([continuation])
201+
return nil
202+
case .pending(var sends):
203+
sends.append(continuation)
204+
state.emission = .pending(sends)
205+
return nil
206+
case .awaiting(var nexts):
207+
let next = nexts.removeFirst().continuation
208+
if nexts.count == 0 {
209+
state.emission = .idle
210+
} else {
211+
state.emission = .awaiting(nexts)
212+
}
213+
return UnsafeResumption(continuation: continuation, success: next)
184214
}
185-
return UnsafeResumption(continuation: continuation, success: next)
186-
}
187-
}?.resume()
215+
}?.resume()
216+
}
217+
continuation?.resume(with: result)
188218
}
189-
continuation?.resume(with: result)
190219
}
191220

192221
/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made.

Tests/AsyncAlgorithmsTests/TestChannel.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,36 @@ final class TestChannel: XCTestCase {
117117
XCTAssertNil(value)
118118
}
119119

120+
func test_sendCancellation() async {
121+
let channel = AsyncChannel<Int>()
122+
let notYetDone = expectation(description: "not yet done")
123+
notYetDone.isInverted = true
124+
let done = expectation(description: "done")
125+
let task = Task {
126+
await channel.send(1)
127+
notYetDone.fulfill()
128+
done.fulfill()
129+
}
130+
wait(for: [notYetDone], timeout: 0.1)
131+
task.cancel()
132+
wait(for: [done], timeout: 1.0)
133+
}
134+
135+
func test_sendCancellation_throwing() async {
136+
let channel = AsyncThrowingChannel<Int, Error>()
137+
let notYetDone = expectation(description: "not yet done")
138+
notYetDone.isInverted = true
139+
let done = expectation(description: "done")
140+
let task = Task {
141+
await channel.send(1)
142+
notYetDone.fulfill()
143+
done.fulfill()
144+
}
145+
wait(for: [notYetDone], timeout: 0.1)
146+
task.cancel()
147+
wait(for: [done], timeout: 1.0)
148+
}
149+
120150
func test_cancellation_throwing() async throws {
121151
let channel = AsyncThrowingChannel<String, Error>()
122152
let ready = expectation(description: "ready")

0 commit comments

Comments
 (0)