diff --git a/Sources/AsyncAlgorithms/AsyncMerge3Sequence.swift b/Sources/AsyncAlgorithms/AsyncMerge3Sequence.swift deleted file mode 100644 index efbbf9a8..00000000 --- a/Sources/AsyncAlgorithms/AsyncMerge3Sequence.swift +++ /dev/null @@ -1,284 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Async Algorithms open source project -// -// Copyright (c) 2022 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// -//===----------------------------------------------------------------------===// - -/// Creates an asynchronous sequence of elements from three underlying asynchronous sequences -public func merge(_ base1: Base1, _ base2: Base2, _ base3: Base3) -> AsyncMerge3Sequence -where - Base1.Element == Base2.Element, - Base2.Element == Base3.Element, - Base1: Sendable, Base2: Sendable, Base3: Sendable, - Base1.Element: Sendable, - Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable, Base3.AsyncIterator: Sendable { - return AsyncMerge3Sequence(base1, base2, base3) -} - -/// An asynchronous sequence of elements from three underlying asynchronous sequences -/// -/// In a `AsyncMerge3Sequence` instance, the *i*th element is the *i*th element -/// resolved in sequential order out of the two underlying asynchronous sequences. -/// Use the `merge(_:_:_:)` function to create an `AsyncMerge3Sequence`. -public struct AsyncMerge3Sequence: AsyncSequence, Sendable -where - Base1.Element == Base2.Element, - Base2.Element == Base3.Element, - Base1: Sendable, Base2: Sendable, Base3: Sendable, - Base1.Element: Sendable, - Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable, Base3.AsyncIterator: Sendable { - public typealias Element = Base1.Element - /// An iterator for `AsyncMerge3Sequence` - public struct Iterator: AsyncIteratorProtocol, Sendable { - enum Partial: @unchecked Sendable { - case first(Result, Base1.AsyncIterator) - case second(Result, Base2.AsyncIterator) - case third(Result, Base3.AsyncIterator) - } - - var state: (PartialIteration, PartialIteration, PartialIteration) - - init(_ iterator1: Base1.AsyncIterator, _ iterator2: Base2.AsyncIterator, _ iterator3: Base3.AsyncIterator) { - state = (.idle(iterator1), .idle(iterator2), .idle(iterator3)) - } - - mutating func apply(_ task1: Task?, _ task2: Task?, _ task3: Task?) async rethrows -> Element? { - switch await Task.select([task1, task2, task3].compactMap { $0 }).value { - case .first(let result, let iterator): - do { - guard let value = try state.0.resolve(result, iterator) else { - return try await next() - } - return value - } catch { - state.1.cancel() - state.2.cancel() - throw error - } - case .second(let result, let iterator): - do { - guard let value = try state.1.resolve(result, iterator) else { - return try await next() - } - return value - } catch { - state.0.cancel() - state.2.cancel() - throw error - } - case .third(let result, let iterator): - do { - guard let value = try state.2.resolve(result, iterator) else { - return try await next() - } - return value - } catch { - state.0.cancel() - state.1.cancel() - throw error - } - } - } - - func first(_ iterator1: Base1.AsyncIterator) -> Task { - Task { - var iter = iterator1 - do { - let value = try await iter.next() - return .first(.success(value), iter) - } catch { - return .first(.failure(error), iter) - } - } - } - - func second(_ iterator2: Base2.AsyncIterator) -> Task { - Task { - var iter = iterator2 - do { - let value = try await iter.next() - return .second(.success(value), iter) - } catch { - return .second(.failure(error), iter) - } - } - } - - func third(_ iterator3: Base3.AsyncIterator) -> Task { - Task { - var iter = iterator3 - do { - let value = try await iter.next() - return .third(.success(value), iter) - } catch { - return .third(.failure(error), iter) - } - } - } - - public mutating func next() async rethrows -> Element? { - // state must have either all terminal or at least 1 idle iterator - // state may not have a saturation of pending tasks - switch state { - // three idle - case (.idle(let iterator1), .idle(let iterator2), .idle(let iterator3)): - let task1 = first(iterator1) - let task2 = second(iterator2) - let task3 = third(iterator3) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - // two idle - case (.idle(let iterator1), .idle(let iterator2), .pending(let task3)): - let task1 = first(iterator1) - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - case (.idle(let iterator1), .pending(let task2), .idle(let iterator3)): - let task1 = first(iterator1) - let task3 = third(iterator3) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - case (.pending(let task1), .idle(let iterator2), .idle(let iterator3)): - let task2 = second(iterator2) - let task3 = third(iterator3) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - - // 1 idle - case (.idle(let iterator1), .pending(let task2), .pending(let task3)): - let task1 = first(iterator1) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - case (.pending(let task1), .idle(let iterator2), .pending(let task3)): - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - case (.pending(let task1), .pending(let task2), .idle(let iterator3)): - let task3 = third(iterator3) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - - // terminal degradations - // 1 terminal - case (.terminal, .idle(let iterator2), .idle(let iterator3)): - let task2 = second(iterator2) - let task3 = third(iterator3) - state = (.terminal, .pending(task2), .pending(task3)) - return try await apply(nil, task2, task3) - case (.terminal, .idle(let iterator2), .pending(let task3)): - let task2 = second(iterator2) - state = (.terminal, .pending(task2), .pending(task3)) - return try await apply(nil, task2, task3) - case (.terminal, .pending(let task2), .idle(let iterator3)): - let task3 = third(iterator3) - state = (.terminal, .pending(task2), .pending(task3)) - return try await apply(nil, task2, task3) - case (.idle(let iterator1), .terminal, .idle(let iterator3)): - let task1 = first(iterator1) - let task3 = third(iterator3) - state = (.pending(task1), .terminal, .pending(task3)) - return try await apply(task1, nil, task3) - case (.idle(let iterator1), .terminal, .pending(let task3)): - let task1 = first(iterator1) - state = (.pending(task1), .terminal, .pending(task3)) - return try await apply(task1, nil, task3) - case (.pending(let task1), .terminal, .idle(let iterator3)): - let task3 = third(iterator3) - state = (.pending(task1), .terminal, .pending(task3)) - return try await apply(task1, nil, task3) - case (.idle(let iterator1), .idle(let iterator2), .terminal): - let task1 = first(iterator1) - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2), .terminal) - return try await apply(task1, task2, nil) - case (.idle(let iterator1), .pending(let task2), .terminal): - let task1 = first(iterator1) - state = (.pending(task1), .pending(task2), .terminal) - return try await apply(task1, task2, nil) - case (.pending(let task1), .idle(let iterator2), .terminal): - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2), .terminal) - return try await apply(task1, task2, nil) - - // 2 terminal - // these can be permuted in place since they don't need to run two or more tasks at once - case (.terminal, .terminal, .idle(var iterator3)): - do { - if let value = try await iterator3.next() { - state = (.terminal, .terminal, .idle(iterator3)) - return value - } else { - state = (.terminal, .terminal, .terminal) - return nil - } - } catch { - state = (.terminal, .terminal, .terminal) - throw error - } - case (.terminal, .idle(var iterator2), .terminal): - do { - if let value = try await iterator2.next() { - state = (.terminal, .idle(iterator2), .terminal) - return value - } else { - state = (.terminal, .terminal, .terminal) - return nil - } - } catch { - state = (.terminal, .terminal, .terminal) - throw error - } - case (.idle(var iterator1), .terminal, .terminal): - do { - if let value = try await iterator1.next() { - state = (.idle(iterator1), .terminal, .terminal) - return value - } else { - state = (.terminal, .terminal, .terminal) - return nil - } - } catch { - state = (.terminal, .terminal, .terminal) - throw error - } - // 3 terminal - case (.terminal, .terminal, .terminal): - return nil - // partials - case (.pending(let task1), .pending(let task2), .pending(let task3)): - return try await apply(task1, task2, task3) - case (.pending(let task1), .pending(let task2), .terminal): - return try await apply(task1, task2, nil) - case (.pending(let task1), .terminal, .pending(let task3)): - return try await apply(task1, nil, task3) - case (.terminal, .pending(let task2), .pending(let task3)): - return try await apply(nil, task2, task3) - case (.pending(let task1), .terminal, .terminal): - return try await apply(task1, nil, nil) - case (.terminal, .pending(let task2), .terminal): - return try await apply(nil, task2, nil) - case (.terminal, .terminal, .pending(let task3)): - return try await apply(nil, nil, task3) - } - } - } - - let base1: Base1 - let base2: Base2 - let base3: Base3 - - init(_ base1: Base1, _ base2: Base2, _ base3: Base3) { - self.base1 = base1 - self.base2 = base2 - self.base3 = base3 - } - - public func makeAsyncIterator() -> Iterator { - return Iterator(base1.makeAsyncIterator(), base2.makeAsyncIterator(), base3.makeAsyncIterator()) - } -} diff --git a/Sources/AsyncAlgorithms/Locking.swift b/Sources/AsyncAlgorithms/Locking.swift index eedad1ee..74396080 100644 --- a/Sources/AsyncAlgorithms/Locking.swift +++ b/Sources/AsyncAlgorithms/Locking.swift @@ -87,6 +87,27 @@ internal struct Lock { func unlock() { Lock.unlock(platformLock) } + + /// Acquire the lock for the duration of the given block. + /// + /// This convenience method should be preferred to `lock` and `unlock` in + /// most situations, as it ensures that the lock will be released regardless + /// of how `body` exits. + /// + /// - Parameter body: The block to execute while holding the lock. + /// - Returns: The value returned by the block. + func withLock(_ body: () throws -> T) rethrows -> T { + self.lock() + defer { + self.unlock() + } + return try body() + } + + // specialise Void return (for performance) + func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + try self.withLock(body) + } } struct ManagedCriticalState { diff --git a/Sources/AsyncAlgorithms/Merge/AsyncMerge2Sequence.swift b/Sources/AsyncAlgorithms/Merge/AsyncMerge2Sequence.swift new file mode 100644 index 00000000..d1a4b4ec --- /dev/null +++ b/Sources/AsyncAlgorithms/Merge/AsyncMerge2Sequence.swift @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +import DequeModule + +/// Creates an asynchronous sequence of elements from two underlying asynchronous sequences +public func merge(_ base1: Base1, _ base2: Base2) -> AsyncMerge2Sequence + where + Base1.Element == Base2.Element, + Base1: Sendable, Base2: Sendable, + Base1.Element: Sendable +{ + return AsyncMerge2Sequence(base1, base2) +} + +/// An ``Swift/AsyncSequence`` that takes two upstream ``Swift/AsyncSequence``s and combines their elements. +public struct AsyncMerge2Sequence< + Base1: AsyncSequence, + Base2: AsyncSequence +>: Sendable where + Base1.Element == Base2.Element, + Base1: Sendable, Base2: Sendable, + Base1.Element: Sendable +{ + public typealias Element = Base1.Element + + private let base1: Base1 + private let base2: Base2 + + /// Initializes a new ``AsyncMerge2Sequence``. + /// + /// - Parameters: + /// - base1: The first upstream ``Swift/AsyncSequence``. + /// - base2: The second upstream ``Swift/AsyncSequence``. + public init( + _ base1: Base1, + _ base2: Base2 + ) { + self.base1 = base1 + self.base2 = base2 + } +} + +extension AsyncMerge2Sequence: AsyncSequence { + public func makeAsyncIterator() -> AsyncIterator { + let storage = MergeStorage( + base1: base1, + base2: base2, + base3: nil + ) + return AsyncIterator(storage: storage) + } +} + +extension AsyncMerge2Sequence { + public struct AsyncIterator: AsyncIteratorProtocol { + /// This class is needed to hook the deinit to observe once all references to the ``AsyncIterator`` are dropped. + /// + /// If we get move-only types we should be able to drop this class and use the `deinit` of the ``AsyncIterator`` struct itself. + final class InternalClass: Sendable { + private let storage: MergeStorage + + fileprivate init(storage: MergeStorage) { + self.storage = storage + } + + deinit { + self.storage.iteratorDeinitialized() + } + + func next() async rethrows -> Element? { + try await storage.next() + } + } + + let internalClass: InternalClass + + fileprivate init(storage: MergeStorage) { + internalClass = InternalClass(storage: storage) + } + + public mutating func next() async rethrows -> Element? { + try await internalClass.next() + } + } +} diff --git a/Sources/AsyncAlgorithms/Merge/AsyncMerge3Sequence.swift b/Sources/AsyncAlgorithms/Merge/AsyncMerge3Sequence.swift new file mode 100644 index 00000000..579e0743 --- /dev/null +++ b/Sources/AsyncAlgorithms/Merge/AsyncMerge3Sequence.swift @@ -0,0 +1,105 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +import DequeModule + +/// Creates an asynchronous sequence of elements from two underlying asynchronous sequences +public func merge< + Base1: AsyncSequence, + Base2: AsyncSequence, + Base3: AsyncSequence +>(_ base1: Base1, _ base2: Base2, _ base3: Base3) -> AsyncMerge3Sequence + where + Base1.Element == Base2.Element, + Base1.Element == Base3.Element, + Base1: Sendable, Base2: Sendable, Base3: Sendable, + Base1.Element: Sendable +{ + return AsyncMerge3Sequence(base1, base2, base3) +} + +/// An ``Swift/AsyncSequence`` that takes three upstream ``Swift/AsyncSequence``s and combines their elements. +public struct AsyncMerge3Sequence< + Base1: AsyncSequence, + Base2: AsyncSequence, + Base3: AsyncSequence +>: Sendable where + Base1.Element == Base2.Element, + Base1.Element == Base3.Element, + Base1: Sendable, Base2: Sendable, Base3: Sendable, + Base1.Element: Sendable +{ + public typealias Element = Base1.Element + + private let base1: Base1 + private let base2: Base2 + private let base3: Base3 + + /// Initializes a new ``AsyncMerge2Sequence``. + /// + /// - Parameters: + /// - base1: The first upstream ``Swift/AsyncSequence``. + /// - base2: The second upstream ``Swift/AsyncSequence``. + /// - base3: The third upstream ``Swift/AsyncSequence``. + public init( + _ base1: Base1, + _ base2: Base2, + _ base3: Base3 + ) { + self.base1 = base1 + self.base2 = base2 + self.base3 = base3 + } +} + +extension AsyncMerge3Sequence: AsyncSequence { + public func makeAsyncIterator() -> AsyncIterator { + let storage = MergeStorage( + base1: base1, + base2: base2, + base3: base3 + ) + return AsyncIterator(storage: storage) + } +} + +public extension AsyncMerge3Sequence { + struct AsyncIterator: AsyncIteratorProtocol { + /// This class is needed to hook the deinit to observe once all references to the ``AsyncIterator`` are dropped. + /// + /// If we get move-only types we should be able to drop this class and use the `deinit` of the ``AsyncIterator`` struct itself. + final class InternalClass: Sendable { + private let storage: MergeStorage + + fileprivate init(storage: MergeStorage) { + self.storage = storage + } + + deinit { + self.storage.iteratorDeinitialized() + } + + func next() async rethrows -> Element? { + try await storage.next() + } + } + + let internalClass: InternalClass + + fileprivate init(storage: MergeStorage) { + internalClass = InternalClass(storage: storage) + } + + public mutating func next() async rethrows -> Element? { + try await internalClass.next() + } + } +} diff --git a/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift b/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift new file mode 100644 index 00000000..ba0f2940 --- /dev/null +++ b/Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift @@ -0,0 +1,627 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +import DequeModule + +/// The state machine for any of the `merge` operator. +/// +/// Right now this state machine supports 3 upstream `AsyncSequences`; however, this can easily be extended. +/// Once variadic generic land we should migrate this to use them instead. +struct MergeStateMachine< + Base1: AsyncSequence, + Base2: AsyncSequence, + Base3: AsyncSequence +> where + Base1.Element == Base2.Element, + Base1.Element == Base3.Element, + Base1: Sendable, Base2: Sendable, Base3: Sendable, + Base1.Element: Sendable +{ + typealias Element = Base1.Element + + private enum State { + /// The initial state before a call to `makeAsyncIterator` happened. + case initial( + base1: Base1, + base2: Base2, + base3: Base3? + ) + + /// The state after `makeAsyncIterator` was called and we created our `Task` to consume the upstream. + case merging( + task: Task, + buffer: Deque, + upstreamContinuations: [UnsafeContinuation], + upstreamsFinished: Int, + downstreamContinuation: UnsafeContinuation? + ) + + /// The state once any of the upstream sequences threw an `Error`. + case upstreamFailure( + buffer: Deque, + error: Error + ) + + /// The state once all upstream sequences finished or the downstream consumer stopped, i.e. by dropping all references + /// or by getting their `Task` cancelled. + case finished + + /// Internal state to avoid CoW. + case modifying + } + + /// The state machine's current state. + private var state: State + + private let numberOfUpstreamSequences: Int + + /// Initializes a new `StateMachine`. + init( + base1: Base1, + base2: Base2, + base3: Base3? + ) { + state = .initial( + base1: base1, + base2: base2, + base3: base3 + ) + + if base3 == nil { + self.numberOfUpstreamSequences = 2 + } else { + self.numberOfUpstreamSequences = 3 + } + } + + /// Actions returned by `iteratorDeinitialized()`. + enum IteratorDeinitializedAction { + /// Indicates that the `Task` needs to be cancelled and + /// all upstream continuations need to be resumed with a `CancellationError`. + case cancelTaskAndUpstreamContinuations( + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + /// Indicates that nothing should be done. + case none + } + + mutating func iteratorDeinitialized() -> IteratorDeinitializedAction { + switch state { + case .initial: + // Nothing to do here. No demand was signalled until now + return .none + + case .merging(_, _, _, _, .some): + // An iterator was deinitialized while we have a suspended continuation. + preconditionFailure("Internal inconsistency current state \(self.state) and received iteratorDeinitialized()") + + case let .merging(task, _, upstreamContinuations, _, .none): + // The iterator was dropped which signals that the consumer is finished. + // We can transition to finished now and need to clean everything up. + state = .finished + + return .cancelTaskAndUpstreamContinuations( + task: task, + upstreamContinuations: upstreamContinuations + ) + + case .upstreamFailure: + // The iterator was dropped which signals that the consumer is finished. + // We can transition to finished now. The cleanup already happened when we + // transitioned to `upstreamFailure`. + state = .finished + + return .none + + case .finished: + // We are already finished so there is nothing left to clean up. + // This is just the references dropping afterwards. + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func taskStarted(_ task: Task) { + switch state { + case .initial: + // The user called `makeAsyncIterator` and we are starting the `Task` + // to consume the upstream sequences + state = .merging( + task: task, + buffer: .init(), + upstreamContinuations: [], // This should reserve capacity in the variadic generics case + upstreamsFinished: 0, + downstreamContinuation: nil + ) + + case .merging, .upstreamFailure, .finished: + // We only a single iterator to be created so this must never happen. + preconditionFailure("Internal inconsistency current state \(self.state) and received taskStarted()") + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `childTaskSuspended()`. + enum ChildTaskSuspendedAction { + /// Indicates that the continuation should be resumed which will lead to calling `next` on the upstream. + case resumeContinuation( + upstreamContinuation: UnsafeContinuation + ) + /// Indicates that the continuation should be resumed with an Error because another upstream sequence threw. + case resumeContinuationWithError( + upstreamContinuation: UnsafeContinuation, + error: Error + ) + /// Indicates that nothing should be done. + case none + } + + mutating func childTaskSuspended(_ continuation: UnsafeContinuation) -> ChildTaskSuspendedAction { + switch state { + case .initial: + // Child tasks are only created after we transitioned to `merging` + preconditionFailure("Internal inconsistency current state \(self.state) and received childTaskSuspended()") + + case .merging(_, _, _, _, .some): + // We have outstanding demand so request the next element + return .resumeContinuation(upstreamContinuation: continuation) + + case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none): + // There is no outstanding demand from the downstream + // so we are storing the continuation and resume it once there is demand. + state = .modifying + + upstreamContinuations.append(continuation) + + state = .merging( + task: task, + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamsFinished, + downstreamContinuation: nil + ) + + return .none + + case .upstreamFailure: + // Another upstream already threw so we just need to throw from this continuation + // which will end the consumption of the upstream. + + return .resumeContinuationWithError( + upstreamContinuation: continuation, + error: CancellationError() + ) + + case .finished: + // Since cancellation is cooperative it might be that child tasks are still getting + // suspended even though we already cancelled them. We must tolerate this and just resume + // the continuation with an error. + return .resumeContinuationWithError( + upstreamContinuation: continuation, + error: CancellationError() + ) + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `elementProduced()`. + enum ElementProducedAction { + /// Indicates that the downstream continuation should be resumed with the element. + case resumeContinuation( + downstreamContinuation: UnsafeContinuation, + element: Element + ) + /// Indicates that nothing should be done. + case none + } + + mutating func elementProduced(_ element: Element) -> ElementProducedAction { + switch state { + case .initial: + // Child tasks that are producing elements are only created after we transitioned to `merging` + preconditionFailure("Internal inconsistency current state \(self.state) and received elementProduced()") + + case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation)): + // We produced an element and have an outstanding downstream continuation + // this means we can go right ahead and resume the continuation with that element + precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") + + state = .merging( + task: task, + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamsFinished, + downstreamContinuation: nil + ) + + return .resumeContinuation( + downstreamContinuation: downstreamContinuation, + element: element + ) + + case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none): + // There is not outstanding downstream continuation so we must buffer the element + // This happens if we race our upstream sequences to produce elements + // and the _losers_ are signalling their produced element + state = .modifying + + buffer.append(element) + + state = .merging( + task: task, + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamsFinished, + downstreamContinuation: nil + ) + + return .none + + case .upstreamFailure: + // Another upstream already produced an error so we just drop the new element + return .none + + case .finished: + // Since cancellation is cooperative it might be that child tasks + // are still producing elements after we finished. + // We are just going to drop them since there is nothing we can do + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `upstreamFinished()`. + enum UpstreamFinishedAction { + /// Indicates that the task and the upstream continuations should be cancelled. + case cancelTaskAndUpstreamContinuations( + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + /// Indicates that the downstream continuation should be resumed with `nil` and + /// the task and the upstream continuations should be cancelled. + case resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: UnsafeContinuation, + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + /// Indicates that nothing should be done. + case none + } + + mutating func upstreamFinished() -> UpstreamFinishedAction { + switch state { + case .initial: + preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamFinished()") + + case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation)): + // One of the upstreams finished + precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") + + // First we increment our counter of finished upstreams + upstreamsFinished += 1 + + if upstreamsFinished == self.numberOfUpstreamSequences { + // All of our upstreams have finished and we can transition to finished now + // We also need to cancel the tasks and any outstanding continuations + state = .finished + + return .resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: downstreamContinuation, + task: task, + upstreamContinuations: upstreamContinuations + ) + } else { + // There are still upstreams that haven't finished so we are just storing our new + // counter of finished upstreams + state = .merging( + task: task, + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamsFinished, + downstreamContinuation: downstreamContinuation + ) + + return .none + } + + case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none): + // First we increment our counter of finished upstreams + upstreamsFinished += 1 + + state = .merging( + task: task, + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamsFinished, + downstreamContinuation: nil + ) + + if upstreamsFinished == self.numberOfUpstreamSequences { + // All of our upstreams have finished; however, we are only transitioning to + // finished once our downstream calls `next` again. + return .cancelTaskAndUpstreamContinuations( + task: task, + upstreamContinuations: upstreamContinuations + ) + } else { + // There are still upstreams that haven't finished. + return .none + } + + case .upstreamFailure: + // Another upstream threw already so we can just ignore this finish + return .none + + case .finished: + // This is just everything finishing up, nothing to do here + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `upstreamThrew()`. + enum UpstreamThrewAction { + /// Indicates that the task and the upstream continuations should be cancelled. + case cancelTaskAndUpstreamContinuations( + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + /// Indicates that the downstream continuation should be resumed with the `error` and + /// the task and the upstream continuations should be cancelled. + case resumeContinuationWithErrorAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: UnsafeContinuation, + error: Error, + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + /// Indicates that nothing should be done. + case none + } + + mutating func upstreamThrew(_ error: Error) -> UpstreamThrewAction { + switch state { + case .initial: + preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamThrew()") + + case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation)): + // An upstream threw an error and we have a downstream continuation. + // We just need to resume the downstream continuation with the error and cancel everything + precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty") + + // We can transition to finished right away because we are returning the error + state = .finished + + return .resumeContinuationWithErrorAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: downstreamContinuation, + error: error, + task: task, + upstreamContinuations: upstreamContinuations + ) + + case let .merging(task, buffer, upstreamContinuations, _, .none): + // An upstream threw an error and we don't have a downstream continuation. + // We need to store the error and wait for the downstream to consume the + // rest of the buffer and the error. However, we can already cancel the task + // and the other upstream continuations since we won't need any more elements. + state = .upstreamFailure( + buffer: buffer, + error: error + ) + return .cancelTaskAndUpstreamContinuations( + task: task, + upstreamContinuations: upstreamContinuations + ) + + case .upstreamFailure: + // Another upstream threw already so we can just ignore this error + return .none + + case .finished: + // This is just everything finishing up, nothing to do here + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `cancelled()`. + enum CancelledAction { + /// Indicates that the downstream continuation needs to be resumed and + /// task and the upstream continuations should be cancelled. + case resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: UnsafeContinuation, + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + /// Indicates that the task and the upstream continuations should be cancelled. + case cancelTaskAndUpstreamContinuations( + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + /// Indicates that nothing should be done. + case none + } + + mutating func cancelled() -> CancelledAction { + switch state { + case .initial: + // Since we are transitioning to `merging` before we return from `makeAsyncIterator` + // this can never happen + preconditionFailure("Internal inconsistency current state \(self.state) and received cancelled()") + + case let .merging(task, _, upstreamContinuations, _, .some(downstreamContinuation)): + // The downstream Task got cancelled so we need to cancel our upstream Task + // and resume all continuations. We can also transition to finished. + state = .finished + + return .resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: downstreamContinuation, + task: task, + upstreamContinuations: upstreamContinuations + ) + + case let .merging(task, _, upstreamContinuations, _, .none): + // The downstream Task got cancelled so we need to cancel our upstream Task + // and resume all continuations. We can also transition to finished. + state = .finished + + return .cancelTaskAndUpstreamContinuations( + task: task, + upstreamContinuations: upstreamContinuations + ) + + case .upstreamFailure: + // An upstream already threw and we cancelled everything already. + // We can just transition to finished now + state = .finished + + return .none + + case .finished: + // We are already finished so nothing to do here: + state = .finished + + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `next()`. + enum NextAction { + /// Indicates that a new `Task` should be created that consumes the sequence and the downstream must be supsended + case startTaskAndSuspendDownstreamTask(Base1, Base2, Base3?) + /// Indicates that the `element` should be returned. + case returnElement(Result) + /// Indicates that `nil` should be returned. + case returnNil + /// Indicates that the `error` should be thrown. + case throwError(Error) + /// Indicates that the downstream task should be suspended. + case suspendDownstreamTask + } + + mutating func next() -> NextAction { + switch state { + case .initial(let base1, let base2, let base3): + // This is the first time we got demand signalled. We need to start the task now + // We are transitioning to merging in the taskStarted method. + return .startTaskAndSuspendDownstreamTask(base1, base2, base3) + + case .merging(_, _, _, _, .some): + // We have multiple AsyncIterators iterating the sequence + preconditionFailure("Internal inconsistency current state \(self.state) and received next()") + + case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none): + state = .modifying + + if let element = buffer.popFirst() { + // We have an element buffered already so we can just return that. + state = .merging( + task: task, + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamsFinished, + downstreamContinuation: nil + ) + + return .returnElement(.success(element)) + } else { + // There was nothing in the buffer so we have to suspend the downstream task + state = .merging( + task: task, + buffer: buffer, + upstreamContinuations: upstreamContinuations, + upstreamsFinished: upstreamsFinished, + downstreamContinuation: nil + ) + + return .suspendDownstreamTask + } + + case .upstreamFailure(var buffer, let error): + state = .modifying + + if let element = buffer.popFirst() { + // There was still a left over element that we need to return + state = .upstreamFailure( + buffer: buffer, + error: error + ) + + return .returnElement(.success(element)) + } else { + // The buffer is empty and we can now throw the error + // that an upstream produced + state = .finished + + return .throwError(error) + } + + case .finished: + // We are already finished so we are just returning `nil` + return .returnNil + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `next(for)`. + enum NextForAction { + /// Indicates that the upstream continuations should be resumed to demand new elements. + case resumeUpstreamContinuations( + upstreamContinuations: [UnsafeContinuation] + ) + } + + mutating func next(for continuation: UnsafeContinuation) -> NextForAction { + switch state { + case .initial, + .merging(_, _, _, _, .some), + .upstreamFailure, + .finished: + // All other states are handled by `next` already so we should never get in here with + // any of those + preconditionFailure("Internal inconsistency current state \(self.state) and received next(for:)") + + case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none): + // We suspended the task and need signal the upstreams + state = .merging( + task: task, + buffer: buffer, + upstreamContinuations: [], // TODO: don't alloc new array here + upstreamsFinished: upstreamsFinished, + downstreamContinuation: continuation + ) + + return .resumeUpstreamContinuations( + upstreamContinuations: upstreamContinuations + ) + + case .modifying: + preconditionFailure("Invalid state") + } + } +} diff --git a/Sources/AsyncAlgorithms/Merge/MergeStorage.swift b/Sources/AsyncAlgorithms/Merge/MergeStorage.swift new file mode 100644 index 00000000..de4c72b8 --- /dev/null +++ b/Sources/AsyncAlgorithms/Merge/MergeStorage.swift @@ -0,0 +1,449 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +final class MergeStorage< + Base1: AsyncSequence, + Base2: AsyncSequence, + Base3: AsyncSequence +>: @unchecked Sendable where + Base1.Element == Base2.Element, + Base1.Element == Base3.Element, + Base1: Sendable, Base2: Sendable, Base3: Sendable, + Base1.Element: Sendable +{ + typealias Element = Base1.Element + + /// The lock that protects our state. + private let lock = Lock.allocate() + /// The state machine. + private var stateMachine: MergeStateMachine + + init( + base1: Base1, + base2: Base2, + base3: Base3? + ) { + stateMachine = .init(base1: base1, base2: base2, base3: base3) + } + + deinit { + self.lock.deinitialize() + } + + func iteratorDeinitialized() { + let action = lock.withLock { self.stateMachine.iteratorDeinitialized() } + + switch action { + case let .cancelTaskAndUpstreamContinuations( + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + + task.cancel() + + case .none: + break + } + } + + func next() async rethrows -> Element? { + // We need to handle cancellation here because we are creating a continuation + // and because we need to cancel the `Task` we created to consume the upstream + try await withTaskCancellationHandler { + self.lock.lock() + let action = self.stateMachine.next() + + switch action { + case .startTaskAndSuspendDownstreamTask(let base1, let base2, let base3): + self.startTask( + stateMachine: &self.stateMachine, + base1: base1, + base2: base2, + base3: base3 + ) + // It is safe to hold the lock across this method + // since the closure is guaranteed to be run straight away + return try await withUnsafeThrowingContinuation { continuation in + let action = self.stateMachine.next(for: continuation) + self.lock.unlock() + + switch action { + case let .resumeUpstreamContinuations(upstreamContinuations): + // This is signalling the child tasks that are consuming the upstream + // sequences to signal demand. + upstreamContinuations.forEach { $0.resume(returning: ()) } + } + } + + + case let .returnElement(element): + self.lock.unlock() + + return try element._rethrowGet() + + case .returnNil: + self.lock.unlock() + return nil + + case let .throwError(error): + self.lock.unlock() + throw error + + case .suspendDownstreamTask: + // It is safe to hold the lock across this method + // since the closure is guaranteed to be run straight away + return try await withUnsafeThrowingContinuation { continuation in + let action = self.stateMachine.next(for: continuation) + self.lock.unlock() + + switch action { + case let .resumeUpstreamContinuations(upstreamContinuations): + // This is signalling the child tasks that are consuming the upstream + // sequences to signal demand. + upstreamContinuations.forEach { $0.resume(returning: ()) } + } + } + } + } onCancel: { + let action = self.lock.withLock { self.stateMachine.cancelled() } + + switch action { + case let .resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation, + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + + task.cancel() + + downstreamContinuation.resume(returning: nil) + + case let .cancelTaskAndUpstreamContinuations( + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + + task.cancel() + + case .none: + break + } + } + } + + private func startTask(stateMachine: inout MergeStateMachine, base1: Base1, base2: Base2, base3: Base3?) { + // This creates a new `Task` that is iterating the upstream + // sequences. We must store it to cancel it at the right times. + let task = Task { + await withThrowingTaskGroup(of: Void.self) { group in + // For each upstream sequence we are adding a child task that + // is consuming the upstream sequence + group.addTask { + var iterator1 = base1.makeAsyncIterator() + + // This is our upstream consumption loop + loop: while true { + // We are creating a continuation before requesting the next + // element from upstream. This continuation is only resumed + // if the downstream consumer called `next` to signal his demand. + try await withUnsafeThrowingContinuation { continuation in + let action = self.lock.withLock { + self.stateMachine.childTaskSuspended(continuation) + } + + switch action { + case let .resumeContinuation(continuation): + // This happens if there is outstanding demand + // and we need to demand from upstream right away + continuation.resume(returning: ()) + + case let .resumeContinuationWithError(continuation, error): + // This happens if another upstream already failed or if + // the task got cancelled. + continuation.resume(throwing: error) + + case .none: + break + } + } + + // We got signalled from the downstream that we have demand so let's + // request a new element from the upstream + if let element1 = try await iterator1.next() { + let action = self.lock.withLock { + self.stateMachine.elementProduced(element1) + } + + switch action { + case let .resumeContinuation(continuation, element): + // We had an outstanding demand and where the first + // upstream to produce an element so we can forward it to + // the downstream + continuation.resume(returning: element) + + case .none: + break + } + + } else { + // The upstream returned `nil` which indicates that it finished + let action = self.lock.withLock { + self.stateMachine.upstreamFinished() + } + + // All of this is mostly cleanup around the Task and the outstanding + // continuations used for signalling. + switch action { + case let .resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation, + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + downstreamContinuation.resume(returning: nil) + + break loop + + case let .cancelTaskAndUpstreamContinuations( + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + break loop + case .none: + + break loop + } + } + } + } + + // Copy from the above just using the base2 sequence + group.addTask { + var iterator2 = base2.makeAsyncIterator() + + // This is our upstream consumption loop + loop: while true { + // We are creating a continuation before requesting the next + // element from upstream. This continuation is only resumed + // if the downstream consumer called `next` to signal his demand. + try await withUnsafeThrowingContinuation { continuation in + let action = self.lock.withLock { + self.stateMachine.childTaskSuspended(continuation) + } + + switch action { + case let .resumeContinuation(continuation): + // This happens if there is outstanding demand + // and we need to demand from upstream right away + continuation.resume(returning: ()) + + case let .resumeContinuationWithError(continuation, error): + // This happens if another upstream already failed or if + // the task got cancelled. + continuation.resume(throwing: error) + + case .none: + break + } + } + + // We got signalled from the downstream that we have demand so let's + // request a new element from the upstream + if let element2 = try await iterator2.next() { + let action = self.lock.withLock { + self.stateMachine.elementProduced(element2) + } + + switch action { + case let .resumeContinuation(continuation, element): + // We had an outstanding demand and where the first + // upstream to produce an element so we can forward it to + // the downstream + continuation.resume(returning: element) + + case .none: + break + } + + } else { + // The upstream returned `nil` which indicates that it finished + let action = self.lock.withLock { + self.stateMachine.upstreamFinished() + } + + // All of this is mostly cleanup around the Task and the outstanding + // continuations used for signalling. + switch action { + case let .resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation, + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + downstreamContinuation.resume(returning: nil) + + break loop + + case let .cancelTaskAndUpstreamContinuations( + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + break loop + case .none: + + break loop + } + } + } + } + + // Copy from the above just using the base3 sequence + if let base3 = base3 { + group.addTask { + var iterator3 = base3.makeAsyncIterator() + + // This is our upstream consumption loop + loop: while true { + // We are creating a continuation before requesting the next + // element from upstream. This continuation is only resumed + // if the downstream consumer called `next` to signal his demand. + try await withUnsafeThrowingContinuation { continuation in + let action = self.lock.withLock { + self.stateMachine.childTaskSuspended(continuation) + } + + switch action { + case let .resumeContinuation(continuation): + // This happens if there is outstanding demand + // and we need to demand from upstream right away + continuation.resume(returning: ()) + + case let .resumeContinuationWithError(continuation, error): + // This happens if another upstream already failed or if + // the task got cancelled. + continuation.resume(throwing: error) + + case .none: + break + } + } + + // We got signalled from the downstream that we have demand so let's + // request a new element from the upstream + if let element3 = try await iterator3.next() { + let action = self.lock.withLock { + self.stateMachine.elementProduced(element3) + } + + switch action { + case let .resumeContinuation(continuation, element): + // We had an outstanding demand and where the first + // upstream to produce an element so we can forward it to + // the downstream + continuation.resume(returning: element) + + case .none: + break + } + + } else { + // The upstream returned `nil` which indicates that it finished + let action = self.lock.withLock { + self.stateMachine.upstreamFinished() + } + + // All of this is mostly cleanup around the Task and the outstanding + // continuations used for signalling. + switch action { + case let .resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation, + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + downstreamContinuation.resume(returning: nil) + + break loop + + case let .cancelTaskAndUpstreamContinuations( + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + break loop + case .none: + + break loop + } + } + } + } + } + + do { + try await group.waitForAll() + } catch { + // One of the upstream sequences threw an error + let action = self.lock.withLock { + self.stateMachine.upstreamThrew(error) + } + + switch action { + case let .resumeContinuationWithErrorAndCancelTaskAndUpstreamContinuations( + downstreamContinuation, + error, + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + + task.cancel() + + downstreamContinuation.resume(throwing: error) + case let .cancelTaskAndUpstreamContinuations( + task, + upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + + task.cancel() + + case .none: + break + } + + group.cancelAll() + } + } + } + + // We need to inform our state machine that we started the Task + stateMachine.taskStarted(task) + } +} + diff --git a/Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift b/Sources/AsyncAlgorithms/Merge2StateMachine.swift similarity index 74% rename from Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift rename to Sources/AsyncAlgorithms/Merge2StateMachine.swift index eeaf0246..b0f15ab6 100644 --- a/Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift +++ b/Sources/AsyncAlgorithms/Merge2StateMachine.swift @@ -9,16 +9,6 @@ // //===----------------------------------------------------------------------===// -/// Creates an asynchronous sequence of elements from two underlying asynchronous sequences -public func merge(_ base1: Base1, _ base2: Base2) -> AsyncMerge2Sequence -where - Base1.Element == Base2.Element, - Base1: Sendable, Base2: Sendable, - Base1.Element: Sendable, - Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable { - return AsyncMerge2Sequence(base1, base2) -} - struct Merge2StateMachine: Sendable where Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable, Base1.Element: Sendable, Base2.Element: Sendable { typealias Element1 = Base1.Element typealias Element2 = Base2.Element @@ -170,40 +160,3 @@ extension Merge2StateMachine.Either where Base1.Element == Base2.Element { } } } - -/// An asynchronous sequence of elements from two underlying asynchronous sequences -/// -/// In a `AsyncMerge2Sequence` instance, the *i*th element is the *i*th element -/// resolved in sequential order out of the two underlying asynchronous sequences. -/// Use the `merge(_:_:)` function to create an `AsyncMerge2Sequence`. -public struct AsyncMerge2Sequence: AsyncSequence, Sendable -where - Base1.Element == Base2.Element, - Base1: Sendable, Base2: Sendable, - Base1.Element: Sendable, - Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable { - public typealias Element = Base1.Element - /// An iterator for `AsyncMerge2Sequence` - public struct Iterator: AsyncIteratorProtocol, Sendable { - var state: Merge2StateMachine - init(_ base1: Base1.AsyncIterator, _ base2: Base2.AsyncIterator) { - state = Merge2StateMachine(base1, base2) - } - - public mutating func next() async rethrows -> Element? { - return try await state.next()?.value - } - } - - let base1: Base1 - let base2: Base2 - - init(_ base1: Base1, _ base2: Base2) { - self.base1 = base1 - self.base2 = base2 - } - - public func makeAsyncIterator() -> Iterator { - return Iterator(base1.makeAsyncIterator(), base2.makeAsyncIterator()) - } -} diff --git a/Tests/AsyncAlgorithmsTests/Support/Asserts.swift b/Tests/AsyncAlgorithmsTests/Support/Asserts.swift index a0f84dbb..c9cdb968 100644 --- a/Tests/AsyncAlgorithmsTests/Support/Asserts.swift +++ b/Tests/AsyncAlgorithmsTests/Support/Asserts.swift @@ -150,3 +150,18 @@ fileprivate func ==(_ lhs: [(A, B, C)] public func XCTAssertEqual(_ expression1: @autoclosure () throws -> [(A, B, C)], _ expression2: @autoclosure () throws -> [(A, B, C)], _ message: @autoclosure () -> String = "", file: StaticString = #filePath, line: UInt = #line) { _XCTAssertEqual(expression1, expression2, { $0 == $1 }, message, file: file, line: line) } + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func XCTAssertThrowsError( + _ expression: @autoclosure () async throws -> T, + file: StaticString = #file, + line: UInt = #line, + verify: (Error) -> Void = { _ in } +) async { + do { + _ = try await expression() + XCTFail("Expression did not throw error", file: file, line: line) + } catch { + verify(error) + } +} diff --git a/Tests/AsyncAlgorithmsTests/TestMerge.swift b/Tests/AsyncAlgorithmsTests/TestMerge.swift index 3cf7c577..76ce5344 100644 --- a/Tests/AsyncAlgorithmsTests/TestMerge.swift +++ b/Tests/AsyncAlgorithmsTests/TestMerge.swift @@ -505,4 +505,44 @@ final class TestMerge3: XCTestCase { task.cancel() wait(for: [finished], timeout: 1.0) } + + // MARK: - IteratorInitialized + + func testIteratorInitialized_whenInitial() async throws { + let reportingSequence1 = ReportingAsyncSequence([1]) + let reportingSequence2 = ReportingAsyncSequence([2]) + let merge = merge(reportingSequence1, reportingSequence2) + + _ = merge.makeAsyncIterator() + + // We need to give the task that consumes the upstream + // a bit of time to make the iterators + try await Task.sleep(nanoseconds: 1000000) + + XCTAssertEqual(reportingSequence1.events, []) + XCTAssertEqual(reportingSequence2.events, []) + } + + // MARK: - IteratorDeinitialized + + func testIteratorDeinitialized_whenMerging() async throws { + let merge = merge([1].async, [2].async) + + var iterator: _! = merge.makeAsyncIterator() + + let nextValue = await iterator.next() + XCTAssertNotNil(nextValue) + + iterator = nil + } + + func testIteratorDeinitialized_whenFinished() async throws { + let merge = merge(Array().async, [].async) + + var iterator: _? = merge.makeAsyncIterator() + let firstValue = await iterator?.next() + XCTAssertNil(firstValue) + + iterator = nil + } }