From a9a02a4dc2752a9b38c46a9544de987474395077 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 11 Jul 2023 16:09:02 +0200 Subject: [PATCH] Merge provides all elements from the subsequences on cancellation On cancellation, merge currently does not yield all elements. This leads to situations in which the final elements of AsyncStreams are not forwarded to the user. This patch ensures, that only the underlying Task is cancelled and all subsequences' elements are forwarded to the user. --- .../Merge/MergeStateMachine.swift | 90 ++++++++++--------- .../AsyncAlgorithms/Merge/MergeStorage.swift | 11 +-- Tests/AsyncAlgorithmsTests/TestMerge.swift | 69 ++++++++++++++ 3 files changed, 118 insertions(+), 52 deletions(-) diff --git a/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift b/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift index bb832ada..5c6fd8fe 100644 --- a/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift +++ b/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift @@ -41,7 +41,8 @@ struct MergeStateMachine< buffer: Deque, upstreamContinuations: [UnsafeContinuation], upstreamsFinished: Int, - downstreamContinuation: UnsafeContinuation? + downstreamContinuation: UnsafeContinuation?, + cancelled: Bool ) /// The state once any of the upstream sequences threw an `Error`. @@ -100,11 +101,11 @@ struct MergeStateMachine< // Nothing to do here. No demand was signalled until now return .none - case .merging(_, _, _, _, .some): + case .merging(_, _, _, _, .some, _): // An iterator was deinitialized while we have a suspended continuation. preconditionFailure("Internal inconsistency current state \(self.state) and received iteratorDeinitialized()") - case let .merging(task, _, upstreamContinuations, _, .none): + case let .merging(task, _, upstreamContinuations, _, .none, _): // The iterator was dropped which signals that the consumer is finished. // We can transition to finished now and need to clean everything up. state = .finished @@ -142,7 +143,8 @@ struct MergeStateMachine< buffer: .init(), upstreamContinuations: [], // This should reserve capacity in the variadic generics case upstreamsFinished: 0, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: false ) case .merging, .upstreamFailure, .finished: @@ -175,11 +177,11 @@ struct MergeStateMachine< // Child tasks are only created after we transitioned to `merging` preconditionFailure("Internal inconsistency current state \(self.state) and received childTaskSuspended()") - case .merging(_, _, _, _, .some): + case .merging(_, _, _, _, .some, _): // We have outstanding demand so request the next element return .resumeContinuation(upstreamContinuation: continuation) - case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none): + case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none, let cancelled): // There is no outstanding demand from the downstream // so we are storing the continuation and resume it once there is demand. state = .modifying @@ -191,7 +193,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .none @@ -236,7 +239,7 @@ struct MergeStateMachine< // Child tasks that are producing elements are only created after we transitioned to `merging` preconditionFailure("Internal inconsistency current state \(self.state) and received elementProduced()") - case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation)): + case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation), cancelled): // We produced an element and have an outstanding downstream continuation // this means we can go right ahead and resume the continuation with that element precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") @@ -246,7 +249,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .resumeContinuation( @@ -254,7 +258,7 @@ struct MergeStateMachine< element: element ) - case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none): + case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none, let cancelled): // There is not outstanding downstream continuation so we must buffer the element // This happens if we race our upstream sequences to produce elements // and the _losers_ are signalling their produced element @@ -267,7 +271,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .none @@ -310,7 +315,7 @@ struct MergeStateMachine< case .initial: preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamFinished()") - case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation)): + case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation), let cancelled): // One of the upstreams finished precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") @@ -335,13 +340,14 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: downstreamContinuation + downstreamContinuation: downstreamContinuation, + cancelled: cancelled ) return .none } - case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none): + case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none, let cancelled): // First we increment our counter of finished upstreams upstreamsFinished += 1 @@ -350,7 +356,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) if upstreamsFinished == self.numberOfUpstreamSequences { @@ -402,7 +409,7 @@ struct MergeStateMachine< case .initial: preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamThrew()") - case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation)): + case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation), _): // An upstream threw an error and we have a downstream continuation. // We just need to resume the downstream continuation with the error and cancel everything precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") @@ -417,7 +424,7 @@ struct MergeStateMachine< upstreamContinuations: upstreamContinuations ) - case let .merging(task, buffer, upstreamContinuations, _, .none): + case let .merging(task, buffer, upstreamContinuations, _, .none, _): // An upstream threw an error and we don't have a downstream continuation. // We need to store the error and wait for the downstream to consume the // rest of the buffer and the error. However, we can already cancel the task @@ -454,10 +461,7 @@ struct MergeStateMachine< upstreamContinuations: [UnsafeContinuation] ) /// Indicates that the task and the upstream continuations should be cancelled. - case cancelTaskAndUpstreamContinuations( - task: Task, - upstreamContinuations: [UnsafeContinuation] - ) + case cancelTask(Task) /// Indicates that nothing should be done. case none } @@ -471,26 +475,21 @@ struct MergeStateMachine< return .none - case let .merging(task, _, upstreamContinuations, _, .some(downstreamContinuation)): - // The downstream Task got cancelled so we need to cancel our upstream Task - // and resume all continuations. We can also transition to finished. - state = .finished + case let .merging(task, buffer, upstreamContinuations, upstreamFinished, downstreamContinuation, cancelled): + guard !cancelled else { + return .none + } - return .resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations( - downstreamContinuation: downstreamContinuation, + self.state = .merging( task: task, - upstreamContinuations: upstreamContinuations + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamFinished, + downstreamContinuation: downstreamContinuation, + cancelled: true ) - case let .merging(task, _, upstreamContinuations, _, .none): - // The downstream Task got cancelled so we need to cancel our upstream Task - // and resume all continuations. We can also transition to finished. - state = .finished - - return .cancelTaskAndUpstreamContinuations( - task: task, - upstreamContinuations: upstreamContinuations - ) + return .cancelTask(task) case .upstreamFailure: // An upstream already threw and we cancelled everything already. @@ -531,11 +530,11 @@ struct MergeStateMachine< // We are transitioning to merging in the taskStarted method. return .startTaskAndSuspendDownstreamTask(base1, base2, base3) - case .merging(_, _, _, _, .some): + case .merging(_, _, _, _, .some, _): // We have multiple AsyncIterators iterating the sequence preconditionFailure("Internal inconsistency current state \(self.state) and received next()") - case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none): + case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none, let cancelled): state = .modifying if let element = buffer.popFirst() { @@ -545,7 +544,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .returnElement(.success(element)) @@ -556,7 +556,8 @@ struct MergeStateMachine< buffer: buffer, upstreamContinuations: upstreamContinuations, upstreamsFinished: upstreamsFinished, - downstreamContinuation: nil + downstreamContinuation: nil, + cancelled: cancelled ) return .suspendDownstreamTask @@ -601,21 +602,22 @@ struct MergeStateMachine< mutating func next(for continuation: UnsafeContinuation) -> NextForAction { switch state { case .initial, - .merging(_, _, _, _, .some), + .merging(_, _, _, _, .some, _), .upstreamFailure, .finished: // All other states are handled by `next` already so we should never get in here with // any of those preconditionFailure("Internal inconsistency current state \(self.state) and received next(for:)") - case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none): + case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none, cancelled): // We suspended the task and need signal the upstreams state = .merging( task: task, buffer: buffer, upstreamContinuations: [], // TODO: don't alloc new array here upstreamsFinished: upstreamsFinished, - downstreamContinuation: continuation + downstreamContinuation: continuation, + cancelled: cancelled ) return .resumeUpstreamContinuations( diff --git a/Sources/AsyncAlgorithms/Merge/MergeStorage.swift b/Sources/AsyncAlgorithms/Merge/MergeStorage.swift index 9dedee76..443c95cd 100644 --- a/Sources/AsyncAlgorithms/Merge/MergeStorage.swift +++ b/Sources/AsyncAlgorithms/Merge/MergeStorage.swift @@ -128,12 +128,7 @@ final class MergeStorage< downstreamContinuation.resume(returning: nil) - case let .cancelTaskAndUpstreamContinuations( - task, - upstreamContinuations - ): - upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } - + case let .cancelTask(task): task.cancel() case .none: @@ -262,8 +257,8 @@ final class MergeStorage< task, upstreamContinuations ): - upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } task.cancel() + upstreamContinuations.forEach { $0.resume() } downstreamContinuation.resume(returning: nil) @@ -273,8 +268,8 @@ final class MergeStorage< task, upstreamContinuations ): - upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } task.cancel() + upstreamContinuations.forEach { $0.resume() } break loop case .none: diff --git a/Tests/AsyncAlgorithmsTests/TestMerge.swift b/Tests/AsyncAlgorithmsTests/TestMerge.swift index c8d5e1ce..293e74a8 100644 --- a/Tests/AsyncAlgorithmsTests/TestMerge.swift +++ b/Tests/AsyncAlgorithmsTests/TestMerge.swift @@ -201,6 +201,38 @@ final class TestMerge2: XCTestCase { } t.cancel() } + + func testAsyncStreamElementsThatAreInjectedOnCancellationAreDelivered() async { + let (stream1, continuation1) = AsyncStream.makeStream(of: Int.self) + let (stream2, continuation2) = AsyncStream.makeStream(of: Int.self) + continuation1.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation1.yield(1) + } + continuation2.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation2.yield(2) + } + continuation1.yield(0) // initial + let merge = merge(stream1, stream2) + let finished = expectation(description: "finished") + let iterated = expectation(description: "iterated") + let task = Task { + var count = 0 + for await _ in merge { + if count == 0 { iterated.fulfill() } + count += 1 + } + finished.fulfill() + XCTAssertEqual(count, 3) + } + // ensure the other task actually starts + await fulfillment(of: [iterated], timeout: 1.0) + // cancellation should ensure the loop finishes + // without regards to the remaining underlying sequence + task.cancel() + await fulfillment(of: [finished], timeout: 1.0) + } } final class TestMerge3: XCTestCase { @@ -555,4 +587,41 @@ final class TestMerge3: XCTestCase { iterator = nil } + + func testAsyncStreamElementsThatAreInjectedOnCancellationAreDelivered() async { + let (stream1, continuation1) = AsyncStream.makeStream(of: Int.self) + let (stream2, continuation2) = AsyncStream.makeStream(of: Int.self) + let (stream3, continuation3) = AsyncStream.makeStream(of: Int.self) + continuation1.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation1.yield(1) + } + continuation2.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation2.yield(2) + } + continuation3.onTermination = { reason in + XCTAssertEqual(reason, .cancelled) + continuation3.yield(3) + } + continuation1.yield(0) // initial + let merge = merge(stream1, stream2, stream3) + let finished = expectation(description: "finished") + let iterated = expectation(description: "iterated") + let task = Task { + var count = 0 + for await _ in merge { + if count == 0 { iterated.fulfill() } + count += 1 + } + finished.fulfill() + XCTAssertEqual(count, 4) + } + // ensure the other task actually starts + await fulfillment(of: [iterated], timeout: 1.0) + // cancellation should ensure the loop finishes + // without regards to the remaining underlying sequence + task.cancel() + await fulfillment(of: [finished], timeout: 1.0) + } }