From 973b9ac023d4812e6f685106d872ef30192da8f8 Mon Sep 17 00:00:00 2001 From: FranzBusch Date: Wed, 12 Oct 2022 09:14:05 +0100 Subject: [PATCH] Merge the zip2 and zip3 implementations # Motivation We recently merged https://github.com/apple/swift-async-algorithms/pull/201 which removed the `Sendable` constraint of the `AsyncIterator` from `zip`. There were some left over comments that we wanted to fix in a follow up. Furthermore, we also can merge the implementations of `zip2` and `zip3` relatively easily. The merging also aids us in understanding what we need from the variadic generics proposals and allows us to give feedback since `zip` is non-trivial to implement with variadic generics. Lastly, the merged state machine will be a good base for the overhaul of `combineLatest`. # Modification This PR merges the state machines from `zip2` and `zip3` into a single one. Furthermore, it addresses some of the open feedback from the last PR. # Result We now have a single state machine which is a good foundation for our changes to `combineLatest`. --- .../Zip/AsyncZip2Sequence.swift | 35 +- .../Zip/AsyncZip3Sequence.swift | 36 +- Sources/AsyncAlgorithms/Zip/Zip2Runtime.swift | 212 ------- .../Zip/Zip2StateMachine.swift | 367 ------------ Sources/AsyncAlgorithms/Zip/Zip3Runtime.swift | 252 -------- .../Zip/Zip3StateMachine.swift | 439 -------------- .../AsyncAlgorithms/Zip/ZipStateMachine.swift | 548 ++++++++++++++++++ Sources/AsyncAlgorithms/Zip/ZipStorage.swift | 320 ++++++++++ 8 files changed, 921 insertions(+), 1288 deletions(-) delete mode 100644 Sources/AsyncAlgorithms/Zip/Zip2Runtime.swift delete mode 100644 Sources/AsyncAlgorithms/Zip/Zip2StateMachine.swift delete mode 100644 Sources/AsyncAlgorithms/Zip/Zip3Runtime.swift delete mode 100644 Sources/AsyncAlgorithms/Zip/Zip3StateMachine.swift create mode 100644 Sources/AsyncAlgorithms/Zip/ZipStateMachine.swift create mode 100644 Sources/AsyncAlgorithms/Zip/ZipStorage.swift diff --git a/Sources/AsyncAlgorithms/Zip/AsyncZip2Sequence.swift b/Sources/AsyncAlgorithms/Zip/AsyncZip2Sequence.swift index 292a612d..0ef591ba 100644 --- a/Sources/AsyncAlgorithms/Zip/AsyncZip2Sequence.swift +++ b/Sources/AsyncAlgorithms/Zip/AsyncZip2Sequence.swift @@ -21,7 +21,7 @@ public func zip( /// An asynchronous sequence that concurrently awaits values from two `AsyncSequence` types /// and emits a tuple of the values. public struct AsyncZip2Sequence: AsyncSequence -where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: Sendable { + where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: Sendable { public typealias Element = (Base1.Element, Base2.Element) public typealias AsyncIterator = Iterator @@ -34,21 +34,38 @@ where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: } public func makeAsyncIterator() -> AsyncIterator { - Iterator( - base1, - base2 - ) + Iterator(storage: .init(self.base1, self.base2, nil)) } public struct Iterator: AsyncIteratorProtocol { - let runtime: Zip2Runtime + final class InternalClass { + private let storage: ZipStorage - init(_ base1: Base1, _ base2: Base2) { - self.runtime = Zip2Runtime(base1, base2) + fileprivate init(storage: ZipStorage) { + self.storage = storage + } + + deinit { + self.storage.iteratorDeinitialized() + } + + func next() async rethrows -> Element? { + guard let element = try await self.storage.next() else { + return nil + } + + return (element.0, element.1) + } + } + + let internalClass: InternalClass + + fileprivate init(storage: ZipStorage) { + self.internalClass = InternalClass(storage: storage) } public mutating func next() async rethrows -> Element? { - try await self.runtime.next() + try await self.internalClass.next() } } } diff --git a/Sources/AsyncAlgorithms/Zip/AsyncZip3Sequence.swift b/Sources/AsyncAlgorithms/Zip/AsyncZip3Sequence.swift index 4a52158e..43317aca 100644 --- a/Sources/AsyncAlgorithms/Zip/AsyncZip3Sequence.swift +++ b/Sources/AsyncAlgorithms/Zip/AsyncZip3Sequence.swift @@ -22,7 +22,7 @@ public func zip: AsyncSequence -where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: Sendable, Base3: Sendable, Base3.Element: Sendable { + where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: Sendable, Base3: Sendable, Base3.Element: Sendable { public typealias Element = (Base1.Element, Base2.Element, Base3.Element) public typealias AsyncIterator = Iterator @@ -37,22 +37,40 @@ where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: } public func makeAsyncIterator() -> AsyncIterator { - Iterator( - base1, - base2, - base3 + Iterator(storage: .init(self.base1, self.base2, self.base3) ) } public struct Iterator: AsyncIteratorProtocol { - let runtime: Zip3Runtime + final class InternalClass { + private let storage: ZipStorage - init(_ base1: Base1, _ base2: Base2, _ base3: Base3) { - self.runtime = Zip3Runtime(base1, base2, base3) + fileprivate init(storage: ZipStorage) { + self.storage = storage + } + + deinit { + self.storage.iteratorDeinitialized() + } + + func next() async rethrows -> Element? { + guard let element = try await self.storage.next() else { + return nil + } + + // This force unwrap is safe since there must be a third element. + return (element.0, element.1, element.2!) + } + } + + let internalClass: InternalClass + + fileprivate init(storage: ZipStorage) { + self.internalClass = InternalClass(storage: storage) } public mutating func next() async rethrows -> Element? { - try await self.runtime.next() + try await self.internalClass.next() } } } diff --git a/Sources/AsyncAlgorithms/Zip/Zip2Runtime.swift b/Sources/AsyncAlgorithms/Zip/Zip2Runtime.swift deleted file mode 100644 index 86e875f8..00000000 --- a/Sources/AsyncAlgorithms/Zip/Zip2Runtime.swift +++ /dev/null @@ -1,212 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Async Algorithms open source project -// -// Copyright (c) 2022 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// -//===----------------------------------------------------------------------===// - -final class Zip2Runtime: Sendable -where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: Sendable { - typealias ZipStateMachine = Zip2StateMachine - - private let stateMachine = ManagedCriticalState(ZipStateMachine()) - private let base1: Base1 - private let base2: Base2 - - init(_ base1: Base1, _ base2: Base2) { - self.base1 = base1 - self.base2 = base2 - } - - func next() async rethrows -> (Base1.Element, Base2.Element)? { - try await withTaskCancellationHandler { - let results = await withUnsafeContinuation { continuation in - self.stateMachine.withCriticalRegion { stateMachine in - let output = stateMachine.newDemandFromConsumer(suspendedDemand: continuation) - switch output { - case .startTask(let suspendedDemand): - // first iteration, we start one task per base to iterate over them - self.startTask(stateMachine: &stateMachine, suspendedDemand: suspendedDemand) - - case .resumeBases(let suspendedBases): - // bases can be iterated over for 1 iteration so their next value can be retrieved - suspendedBases.forEach { $0.resume() } - - case .terminate(let suspendedDemand): - // the async sequence is already finished, immediately resuming - suspendedDemand.resume(returning: nil) - } - } - } - - guard let results else { - return nil - } - - self.stateMachine.withCriticalRegion { stateMachine in - // acknowledging the consumption of the zipped values, so we can begin another iteration on the bases - stateMachine.demandIsFulfilled() - } - - return try (results.0._rethrowGet(), results.1._rethrowGet()) - } onCancel: { - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.rootTaskIsCancelled() - } - // clean the allocated resources and state - self.handle(rootTaskIsCancelledOutput: output) - } - } - - private func handle(rootTaskIsCancelledOutput: ZipStateMachine.RootTaskIsCancelledOutput) { - switch rootTaskIsCancelledOutput { - case .terminate(let task, let suspendedBases, let suspendedDemands): - suspendedBases?.forEach { $0.resume() } - suspendedDemands?.forEach { $0?.resume(returning: nil) } - task?.cancel() - } - } - - private func startTask( - stateMachine: inout ZipStateMachine, - suspendedDemand: ZipStateMachine.SuspendedDemand - ) { - let task = Task { - await withTaskGroup(of: Void.self) { group in - group.addTask { - var base1Iterator = self.base1.makeAsyncIterator() - - do { - while true { - await withUnsafeContinuation { continuation in - let output = self.stateMachine.withCriticalRegion { machine in - machine.newLoopFromBase1(suspendedBase: continuation) - } - - self.handle(newLoopFromBaseOutput: output) - } - - guard let element1 = try await base1Iterator.next() else { - break - } - - let output = self.stateMachine.withCriticalRegion { machine in - machine.base1HasProducedElement(element: element1) - } - - self.handle(baseHasProducedElementOutput: output) - } - } catch { - let output = self.stateMachine.withCriticalRegion { machine in - machine.baseHasProducedFailure(error: error) - } - - self.handle(baseHasProducedFailureOutput: output) - } - - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.baseIsFinished() - } - - self.handle(baseIsFinishedOutput: output) - } - - group.addTask { - var base2Iterator = self.base2.makeAsyncIterator() - - do { - while true { - await withUnsafeContinuation { continuation in - let output = self.stateMachine.withCriticalRegion { machine in - machine.newLoopFromBase2(suspendedBase: continuation) - } - - self.handle(newLoopFromBaseOutput: output) - } - - guard let element2 = try await base2Iterator.next() else { - break - } - - let output = self.stateMachine.withCriticalRegion { machine in - machine.base2HasProducedElement(element: element2) - } - - self.handle(baseHasProducedElementOutput: output) - } - } catch { - let output = self.stateMachine.withCriticalRegion { machine in - machine.baseHasProducedFailure(error: error) - } - - self.handle(baseHasProducedFailureOutput: output) - } - - let output = self.stateMachine.withCriticalRegion { machine in - machine.baseIsFinished() - } - - self.handle(baseIsFinishedOutput: output) - } - } - } - stateMachine.taskIsStarted(task: task, suspendedDemand: suspendedDemand) - } - - private func handle(newLoopFromBaseOutput: ZipStateMachine.NewLoopFromBaseOutput) { - switch newLoopFromBaseOutput { - case .none: - break - - case .resumeBases(let suspendedBases): - suspendedBases.forEach { $0.resume() } - - case .terminate(let suspendedBase): - suspendedBase.resume() - } - } - - private func handle(baseHasProducedElementOutput: ZipStateMachine.BaseHasProducedElementOutput) { - switch baseHasProducedElementOutput { - case .none: - break - - case .resumeDemand(let suspendedDemand, let result1, let result2): - suspendedDemand?.resume(returning: (result1, result2)) - } - } - - private func handle(baseHasProducedFailureOutput: ZipStateMachine.BaseHasProducedFailureOutput) { - switch baseHasProducedFailureOutput { - case .none: - break - - case .resumeDemandAndTerminate(let task, let suspendedDemand, let suspendedBases, let result1, let result2): - suspendedDemand?.resume(returning: (result1, result2)) - suspendedBases.forEach { $0.resume() } - task?.cancel() - } - } - - private func handle(baseIsFinishedOutput: ZipStateMachine.BaseIsFinishedOutput) { - switch baseIsFinishedOutput { - case .terminate(let task, let suspendedBases, let suspendedDemands): - suspendedBases?.forEach { $0.resume() } - suspendedDemands?.forEach { $0?.resume(returning: nil) } - task?.cancel() - } - } - - deinit { - // clean the allocated resources and state - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.rootTaskIsCancelled() - } - - self.handle(rootTaskIsCancelledOutput: output) - } -} diff --git a/Sources/AsyncAlgorithms/Zip/Zip2StateMachine.swift b/Sources/AsyncAlgorithms/Zip/Zip2StateMachine.swift deleted file mode 100644 index 19ebc11f..00000000 --- a/Sources/AsyncAlgorithms/Zip/Zip2StateMachine.swift +++ /dev/null @@ -1,367 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Async Algorithms open source project -// -// Copyright (c) 2022 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// -//===----------------------------------------------------------------------===// - -struct Zip2StateMachine: Sendable -where Element1: Sendable, Element2: Sendable { - typealias SuspendedDemand = UnsafeContinuation<(Result, Result)?, Never> - - private enum State { - case initial - case awaitingDemandFromConsumer( - task: Task?, - suspendedBases: [UnsafeContinuation] - ) - case awaitingBaseResults( - task: Task?, - result1: Result?, - result2: Result?, - suspendedBases: [UnsafeContinuation], - suspendedDemand: SuspendedDemand? - ) - case finished - } - - private var state: State = .initial - - mutating func taskIsStarted( - task: Task, - suspendedDemand: SuspendedDemand - ) { - switch self.state { - case .initial: - self.state = .awaitingBaseResults( - task: task, - result1: nil, - result2: nil, - suspendedBases: [], - suspendedDemand: suspendedDemand - ) - - default: - preconditionFailure("Inconsistent state, the task cannot start while the state is other than initial") - } - } - - enum NewDemandFromConsumerOutput { - case resumeBases(suspendedBases: [UnsafeContinuation]) - case startTask(suspendedDemand: SuspendedDemand) - case terminate(suspendedDemand: SuspendedDemand) - } - - mutating func newDemandFromConsumer( - suspendedDemand: UnsafeContinuation<(Result, Result)?, Never> - ) -> NewDemandFromConsumerOutput { - switch self.state { - case .initial: - return .startTask(suspendedDemand: suspendedDemand) - - case .awaitingDemandFromConsumer(let task, let suspendedBases): - self.state = .awaitingBaseResults(task: task, result1: nil, result2: nil, suspendedBases: [], suspendedDemand: suspendedDemand) - return .resumeBases(suspendedBases: suspendedBases) - - case .awaitingBaseResults: - preconditionFailure("Inconsistent state, a demand is already suspended") - - case .finished: - return .terminate(suspendedDemand: suspendedDemand) - } - } - - enum NewLoopFromBaseOutput { - case none - case resumeBases(suspendedBases: [UnsafeContinuation]) - case terminate(suspendedBase: UnsafeContinuation) - } - - mutating func newLoopFromBase1(suspendedBase: UnsafeContinuation) -> NewLoopFromBaseOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer(let task, var suspendedBases): - precondition(suspendedBases.count < 2, "There cannot be more than 2 suspended bases at the same time") - suspendedBases.append(suspendedBase) - self.state = .awaitingDemandFromConsumer(task: task, suspendedBases: suspendedBases) - return .none - - case .awaitingBaseResults(let task, let result1, let result2, var suspendedBases, let suspendedDemand): - precondition(suspendedBases.count < 2, "There cannot be more than 2 suspended bases at the same time") - if result1 != nil { - suspendedBases.append(suspendedBase) - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } else { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .resumeBases(suspendedBases: [suspendedBase]) - } - - case .finished: - return .terminate(suspendedBase: suspendedBase) - } - } - - mutating func newLoopFromBase2(suspendedBase: UnsafeContinuation) -> NewLoopFromBaseOutput { - switch state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer(let task, var suspendedBases): - precondition(suspendedBases.count < 2, "There cannot be more than 2 suspended bases at the same time") - suspendedBases.append(suspendedBase) - self.state = .awaitingDemandFromConsumer(task: task, suspendedBases: suspendedBases) - return .none - - case .awaitingBaseResults(let task, let result1, let result2, var suspendedBases, let suspendedDemand): - precondition(suspendedBases.count < 2, "There cannot be more than 2 suspended bases at the same time") - if result2 != nil { - suspendedBases.append(suspendedBase) - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } else { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .resumeBases(suspendedBases: [suspendedBase]) - } - - case .finished: - return .terminate(suspendedBase: suspendedBase) - } - } - - enum BaseHasProducedElementOutput { - case none - case resumeDemand( - suspendedDemand: SuspendedDemand?, - result1: Result, - result2: Result - ) - } - - mutating func base1HasProducedElement(element: Element1) -> BaseHasProducedElementOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, a base can only produce an element when the consumer is awaiting for it") - - case .awaitingBaseResults(let task, _, let result2, let suspendedBases, let suspendedDemand): - if let result2 { - self.state = .awaitingBaseResults( - task: task, - result1: .success(element), - result2: result2, - suspendedBases: suspendedBases, - suspendedDemand: nil - ) - return .resumeDemand(suspendedDemand: suspendedDemand, result1: .success(element), result2: result2) - } else { - self.state = .awaitingBaseResults( - task: task, - result1: .success(element), - result2: nil, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } - - case .finished: - return .none - } - } - - mutating func base2HasProducedElement(element: Element2) -> BaseHasProducedElementOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, a base can only produce an element when the consumer is awaiting for it") - - case .awaitingBaseResults(let task, let result1, _, let suspendedBases, let suspendedDemand): - if let result1 { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: .success(element), - suspendedBases: suspendedBases, - suspendedDemand: nil - ) - return .resumeDemand(suspendedDemand: suspendedDemand, result1: result1, result2: .success(element)) - } else { - self.state = .awaitingBaseResults( - task: task, - result1: nil, - result2: .success(element), - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } - - case .finished: - return .none - } - } - - enum BaseHasProducedFailureOutput { - case none - case resumeDemandAndTerminate( - task: Task?, - suspendedDemand: SuspendedDemand?, - suspendedBases: [UnsafeContinuation], - result1: Result, - result2: Result - ) - } - - mutating func baseHasProducedFailure(error: any Error) -> BaseHasProducedFailureOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, a base can only produce an element when the consumer is awaiting for it") - - case .awaitingBaseResults(let task, _, _, let suspendedBases, let suspendedDemand): - self.state = .finished - return .resumeDemandAndTerminate( - task: task, - suspendedDemand: suspendedDemand, - suspendedBases: suspendedBases, - result1: .failure(error), - result2: .failure(error) - ) - - case .finished: - return .none - } - } - - mutating func base2HasProducedFailure(error: Error) -> BaseHasProducedFailureOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, a base can only produce an element when the consumer is awaiting for it") - - case .awaitingBaseResults(let task, _, _, let suspendedBases, let suspendedDemand): - self.state = .finished - return .resumeDemandAndTerminate( - task: task, - suspendedDemand: suspendedDemand, - suspendedBases: suspendedBases, - result1: .failure(error), - result2: .failure(error) - ) - - case .finished: - return .none - } - } - - mutating func demandIsFulfilled() { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, results are not yet available to be acknowledged") - - case .awaitingBaseResults(let task, let result1, let result2, let suspendedBases, let suspendedDemand): - precondition(suspendedDemand == nil, "Inconsistent state, there cannot be a suspended demand when ackowledging the demand") - precondition(result1 != nil && result2 != nil, "Inconsistent state, all results are not yet available to be acknowledged") - self.state = .awaitingDemandFromConsumer(task: task, suspendedBases: suspendedBases) - - case .finished: - break - } - } - - enum RootTaskIsCancelledOutput { - case terminate( - task: Task?, - suspendedBases: [UnsafeContinuation]?, - suspendedDemands: [SuspendedDemand?]? - ) - } - - mutating func rootTaskIsCancelled() -> RootTaskIsCancelledOutput { - switch self.state { - case .initial: - assertionFailure("Inconsistent state, the task is not started") - self.state = .finished - return .terminate(task: nil, suspendedBases: nil, suspendedDemands: nil) - - case .awaitingDemandFromConsumer(let task, let suspendedBases): - self.state = .finished - return .terminate(task: task, suspendedBases: suspendedBases, suspendedDemands: nil) - - case .awaitingBaseResults(let task, _, _, let suspendedBases, let suspendedDemand): - self.state = .finished - return .terminate(task: task, suspendedBases: suspendedBases, suspendedDemands: [suspendedDemand]) - - case .finished: - return .terminate(task: nil, suspendedBases: nil, suspendedDemands: nil) - } - } - - enum BaseIsFinishedOutput { - case terminate( - task: Task?, - suspendedBases: [UnsafeContinuation]?, - suspendedDemands: [SuspendedDemand?]? - ) - } - - mutating func baseIsFinished() -> BaseIsFinishedOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer(let task, let suspendedBases): - self.state = .finished - return .terminate(task: task, suspendedBases: suspendedBases, suspendedDemands: nil) - - case .awaitingBaseResults(let task, _, _, let suspendedBases, let suspendedDemand): - self.state = .finished - return .terminate(task: task, suspendedBases: suspendedBases, suspendedDemands: [suspendedDemand]) - - case .finished: - return .terminate(task: nil, suspendedBases: nil, suspendedDemands: nil) - } - } -} diff --git a/Sources/AsyncAlgorithms/Zip/Zip3Runtime.swift b/Sources/AsyncAlgorithms/Zip/Zip3Runtime.swift deleted file mode 100644 index 63e83fda..00000000 --- a/Sources/AsyncAlgorithms/Zip/Zip3Runtime.swift +++ /dev/null @@ -1,252 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Async Algorithms open source project -// -// Copyright (c) 2022 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// -//===----------------------------------------------------------------------===// - -final class Zip3Runtime: Sendable -where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: Sendable, Base3: Sendable, Base3.Element: Sendable { - typealias ZipStateMachine = Zip3StateMachine - - private let stateMachine = ManagedCriticalState(ZipStateMachine()) - private let base1: Base1 - private let base2: Base2 - private let base3: Base3 - - init(_ base1: Base1, _ base2: Base2, _ base3: Base3) { - self.base1 = base1 - self.base2 = base2 - self.base3 = base3 - } - - func next() async rethrows -> (Base1.Element, Base2.Element, Base3.Element)? { - try await withTaskCancellationHandler { - let results = await withUnsafeContinuation { continuation in - self.stateMachine.withCriticalRegion { stateMachine in - let output = stateMachine.newDemandFromConsumer(suspendedDemand: continuation) - switch output { - case .startTask(let suspendedDemand): - // first iteration, we start one task per base to iterate over them - self.startTask(stateMachine: &stateMachine, suspendedDemand: suspendedDemand) - - case .resumeBases(let suspendedBases): - // bases can be iterated over for 1 iteration so their next value can be retrieved - suspendedBases.forEach { $0.resume() } - - case .terminate(let suspendedDemand): - // the async sequence is already finished, immediately resuming - suspendedDemand.resume(returning: nil) - } - } - } - - guard let results else { - return nil - } - - self.stateMachine.withCriticalRegion { stateMachine in - // acknowledging the consumption of the zipped values, so we can begin another iteration on the bases - stateMachine.demandIsFulfilled() - } - - return try (results.0._rethrowGet(), results.1._rethrowGet(), results.2._rethrowGet()) - } onCancel: { - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.rootTaskIsCancelled() - } - // clean the allocated resources and state - self.handle(rootTaskIsCancelledOutput: output) - } - } - - private func handle(rootTaskIsCancelledOutput: ZipStateMachine.RootTaskIsCancelledOutput) { - switch rootTaskIsCancelledOutput { - case .terminate(let task, let suspendedBases, let suspendedDemands): - suspendedBases?.forEach { $0.resume() } - suspendedDemands?.forEach { $0?.resume(returning: nil) } - task?.cancel() - } - } - - private func startTask( - stateMachine: inout ZipStateMachine, - suspendedDemand: ZipStateMachine.SuspendedDemand - ) { - let task = Task { - await withTaskGroup(of: Void.self) { group in - group.addTask { - var base1Iterator = self.base1.makeAsyncIterator() - - do { - while true { - await withUnsafeContinuation { continuation in - let output = self.stateMachine.withCriticalRegion { machine in - machine.newLoopFromBase1(suspendedBase: continuation) - } - - self.handle(newLoopFromBaseOutput: output) - } - - guard let element1 = try await base1Iterator.next() else { - break - } - - let output = self.stateMachine.withCriticalRegion { machine in - machine.base1HasProducedElement(element: element1) - } - - self.handle(baseHasProducedElementOutput: output) - } - } catch { - let output = self.stateMachine.withCriticalRegion { machine in - machine.baseHasProducedFailure(error: error) - } - - self.handle(baseHasProducedFailureOutput: output) - } - - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.baseIsFinished() - } - - self.handle(baseIsFinishedOutput: output) - } - - group.addTask { - var base2Iterator = self.base2.makeAsyncIterator() - - do { - while true { - await withUnsafeContinuation { continuation in - let output = self.stateMachine.withCriticalRegion { machine in - machine.newLoopFromBase2(suspendedBase: continuation) - } - - self.handle(newLoopFromBaseOutput: output) - } - - guard let element2 = try await base2Iterator.next() else { - break - } - - let output = self.stateMachine.withCriticalRegion { machine in - machine.base2HasProducedElement(element: element2) - } - - self.handle(baseHasProducedElementOutput: output) - } - } catch { - let output = self.stateMachine.withCriticalRegion { machine in - machine.baseHasProducedFailure(error: error) - } - - self.handle(baseHasProducedFailureOutput: output) - } - - let output = self.stateMachine.withCriticalRegion { machine in - machine.baseIsFinished() - } - - self.handle(baseIsFinishedOutput: output) - } - - group.addTask { - var base3Iterator = self.base3.makeAsyncIterator() - - do { - while true { - await withUnsafeContinuation { continuation in - let output = self.stateMachine.withCriticalRegion { machine in - machine.newLoopFromBase3(suspendedBase: continuation) - } - - self.handle(newLoopFromBaseOutput: output) - } - - guard let element3 = try await base3Iterator.next() else { - break - } - - let output = self.stateMachine.withCriticalRegion { machine in - machine.base3HasProducedElement(element: element3) - } - - self.handle(baseHasProducedElementOutput: output) - } - } catch { - let output = self.stateMachine.withCriticalRegion { machine in - machine.baseHasProducedFailure(error: error) - } - - self.handle(baseHasProducedFailureOutput: output) - } - - let output = self.stateMachine.withCriticalRegion { machine in - machine.baseIsFinished() - } - - self.handle(baseIsFinishedOutput: output) - } - } - } - stateMachine.taskIsStarted(task: task, suspendedDemand: suspendedDemand) - } - - private func handle(newLoopFromBaseOutput: ZipStateMachine.NewLoopFromBaseOutput) { - switch newLoopFromBaseOutput { - case .none: - break - - case .resumeBases(let suspendedBases): - suspendedBases.forEach { $0.resume() } - - case .terminate(let suspendedBase): - suspendedBase.resume() - } - } - - private func handle(baseHasProducedElementOutput: ZipStateMachine.BaseHasProducedElementOutput) { - switch baseHasProducedElementOutput { - case .none: - break - - case .resumeDemand(let suspendedDemand, let result1, let result2, let result3): - suspendedDemand?.resume(returning: (result1, result2, result3)) - } - } - - private func handle(baseHasProducedFailureOutput: ZipStateMachine.BaseHasProducedFailureOutput) { - switch baseHasProducedFailureOutput { - case .none: - break - - case .resumeDemandAndTerminate(let task, let suspendedDemand, let suspendedBases, let result1, let result2, let result3): - suspendedDemand?.resume(returning: (result1, result2, result3)) - suspendedBases.forEach { $0.resume() } - task?.cancel() - } - } - - private func handle(baseIsFinishedOutput: ZipStateMachine.BaseIsFinishedOutput) { - switch baseIsFinishedOutput { - case .terminate(let task, let suspendedBases, let suspendedDemands): - suspendedBases?.forEach { $0.resume() } - suspendedDemands?.forEach { $0?.resume(returning: nil) } - task?.cancel() - } - } - - deinit { - // clean the allocated resources and state - let output = self.stateMachine.withCriticalRegion { stateMachine in - stateMachine.rootTaskIsCancelled() - } - - self.handle(rootTaskIsCancelledOutput: output) - } -} diff --git a/Sources/AsyncAlgorithms/Zip/Zip3StateMachine.swift b/Sources/AsyncAlgorithms/Zip/Zip3StateMachine.swift deleted file mode 100644 index 21faf292..00000000 --- a/Sources/AsyncAlgorithms/Zip/Zip3StateMachine.swift +++ /dev/null @@ -1,439 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Async Algorithms open source project -// -// Copyright (c) 2022 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// -//===----------------------------------------------------------------------===// - -struct Zip3StateMachine: Sendable -where Element1: Sendable, Element2: Sendable, Element3: Sendable { - typealias SuspendedDemand = UnsafeContinuation<(Result, Result, Result)?, Never> - - private enum State { - case initial - case awaitingDemandFromConsumer( - task: Task?, - suspendedBases: [UnsafeContinuation] - ) - case awaitingBaseResults( - task: Task?, - result1: Result?, - result2: Result?, - result3: Result?, - suspendedBases: [UnsafeContinuation], - suspendedDemand: SuspendedDemand? - ) - case finished - } - - private var state: State = .initial - - mutating func taskIsStarted( - task: Task, - suspendedDemand: SuspendedDemand - ) { - switch self.state { - case .initial: - self.state = .awaitingBaseResults( - task: task, - result1: nil, - result2: nil, - result3: nil, - suspendedBases: [], - suspendedDemand: suspendedDemand - ) - - default: - preconditionFailure("Inconsistent state, the task cannot start while the state is other than initial") - } - } - - enum NewDemandFromConsumerOutput { - case resumeBases(suspendedBases: [UnsafeContinuation]) - case startTask(suspendedDemand: SuspendedDemand) - case terminate(suspendedDemand: SuspendedDemand) - } - - mutating func newDemandFromConsumer(suspendedDemand: SuspendedDemand) -> NewDemandFromConsumerOutput { - switch self.state { - case .initial: - return .startTask(suspendedDemand: suspendedDemand) - - case .awaitingDemandFromConsumer(let task, let suspendedBases): - self.state = .awaitingBaseResults( - task: task, - result1: nil, - result2: nil, - result3: nil, - suspendedBases: [], - suspendedDemand: suspendedDemand - ) - return .resumeBases(suspendedBases: suspendedBases) - - case .awaitingBaseResults: - preconditionFailure("Inconsistent state, a demand is already suspended") - - case .finished: - return .terminate(suspendedDemand: suspendedDemand) - } - } - - enum NewLoopFromBaseOutput { - case none - case resumeBases(suspendedBases: [UnsafeContinuation]) - case terminate(suspendedBase: UnsafeContinuation) - } - - mutating func newLoopFromBase1(suspendedBase: UnsafeContinuation) -> NewLoopFromBaseOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer(let task, var suspendedBases): - precondition(suspendedBases.count < 3, "There cannot be more than 3 suspended bases at the same time") - suspendedBases.append(suspendedBase) - self.state = .awaitingDemandFromConsumer(task: task, suspendedBases: suspendedBases) - return .none - - case .awaitingBaseResults(let task, let result1, let result2, let result3, var suspendedBases, let suspendedDemand): - precondition(suspendedBases.count < 3, "There cannot be more than 3 suspended bases at the same time") - if result1 != nil { - suspendedBases.append(suspendedBase) - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } else { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .resumeBases(suspendedBases: [suspendedBase]) - } - - case .finished: - return .terminate(suspendedBase: suspendedBase) - } - } - - mutating func newLoopFromBase2(suspendedBase: UnsafeContinuation) -> NewLoopFromBaseOutput { - switch state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer(let task, var suspendedBases): - precondition(suspendedBases.count < 3, "There cannot be more than 3 suspended bases at the same time") - suspendedBases.append(suspendedBase) - self.state = .awaitingDemandFromConsumer(task: task, suspendedBases: suspendedBases) - return .none - - case .awaitingBaseResults(let task, let result1, let result2, let result3, var suspendedBases, let suspendedDemand): - precondition(suspendedBases.count < 3, "There cannot be more than 3 suspended bases at the same time") - if result2 != nil { - suspendedBases.append(suspendedBase) - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } else { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .resumeBases(suspendedBases: [suspendedBase]) - } - - case .finished: - return .terminate(suspendedBase: suspendedBase) - } - } - - mutating func newLoopFromBase3(suspendedBase: UnsafeContinuation) -> NewLoopFromBaseOutput { - switch state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer(let task, var suspendedBases): - precondition(suspendedBases.count < 3, "There cannot be more than 3 suspended bases at the same time") - suspendedBases.append(suspendedBase) - self.state = .awaitingDemandFromConsumer(task: task, suspendedBases: suspendedBases) - return .none - - case .awaitingBaseResults(let task, let result1, let result2, let result3, var suspendedBases, let suspendedDemand): - precondition(suspendedBases.count < 3, "There cannot be more than 3 suspended bases at the same time") - if result3 != nil { - suspendedBases.append(suspendedBase) - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } else { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .resumeBases(suspendedBases: [suspendedBase]) - } - - case .finished: - return .terminate(suspendedBase: suspendedBase) - } - } - - enum BaseHasProducedElementOutput { - case none - case resumeDemand( - suspendedDemand: SuspendedDemand?, - result1: Result, - result2: Result, - result3: Result - ) - } - - mutating func base1HasProducedElement(element: Element1) -> BaseHasProducedElementOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, a base can only produce an element when the consumer is awaiting for it") - - case .awaitingBaseResults(let task, _, let result2, let result3, let suspendedBases, let suspendedDemand): - if let result2, let result3 { - self.state = .awaitingBaseResults( - task: task, - result1: .success(element), - result2: result2, - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: nil - ) - return .resumeDemand(suspendedDemand: suspendedDemand, result1: .success(element), result2: result2, result3: result3) - } else { - self.state = .awaitingBaseResults( - task: task, - result1: .success(element), - result2: result2, - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } - - case .finished: - return .none - } - } - - mutating func base2HasProducedElement(element: Element2) -> BaseHasProducedElementOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, a base can only produce an element when the consumer is awaiting for it") - - case .awaitingBaseResults(let task, let result1, _, let result3, let suspendedBases, let suspendedDemand): - if let result1, let result3 { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: .success(element), - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: nil - ) - return .resumeDemand(suspendedDemand: suspendedDemand, result1: result1, result2: .success(element), result3: result3) - } else { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: .success(element), - result3: result3, - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } - - case .finished: - return .none - } - } - - mutating func base3HasProducedElement(element: Element3) -> BaseHasProducedElementOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, a base can only produce an element when the consumer is awaiting for it") - - case .awaitingBaseResults(let task, let result1, let result2, _, let suspendedBases, let suspendedDemand): - if let result1, let result2 { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - result3: .success(element), - suspendedBases: suspendedBases, - suspendedDemand: nil - ) - return .resumeDemand(suspendedDemand: suspendedDemand, result1: result1, result2: result2, result3: .success(element)) - } else { - self.state = .awaitingBaseResults( - task: task, - result1: result1, - result2: result2, - result3: .success(element), - suspendedBases: suspendedBases, - suspendedDemand: suspendedDemand - ) - return .none - } - - case .finished: - return .none - } - } - - enum BaseHasProducedFailureOutput { - case none - case resumeDemandAndTerminate( - task: Task?, - suspendedDemand: SuspendedDemand?, - suspendedBases: [UnsafeContinuation], - result1: Result, - result2: Result, - result3: Result - ) - } - - mutating func baseHasProducedFailure(error: any Error) -> BaseHasProducedFailureOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, a base can only produce an element when the consumer is awaiting for it") - - case .awaitingBaseResults(let task, _, _, _, let suspendedBases, let suspendedDemand): - self.state = .finished - return .resumeDemandAndTerminate( - task: task, - suspendedDemand: suspendedDemand, - suspendedBases: suspendedBases, - result1: .failure(error), - result2: .failure(error), - result3: .failure(error) - ) - - case .finished: - return .none - } - } - - mutating func demandIsFulfilled() { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer: - preconditionFailure("Inconsistent state, results are not yet available to be acknowledged") - - case .awaitingBaseResults(let task, let result1, let result2, let result3, let suspendedBases, let suspendedDemand): - precondition(suspendedDemand == nil, "Inconsistent state, there cannot be a suspended demand when ackowledging the demand") - precondition(result1 != nil && result2 != nil && result3 != nil, "Inconsistent state, all results are not yet available to be acknowledged") - self.state = .awaitingDemandFromConsumer(task: task, suspendedBases: suspendedBases) - - case .finished: - break - } - } - - enum RootTaskIsCancelledOutput { - case terminate( - task: Task?, - suspendedBases: [UnsafeContinuation]?, - suspendedDemands: [SuspendedDemand?]? - ) - } - - mutating func rootTaskIsCancelled() -> RootTaskIsCancelledOutput { - switch self.state { - case .initial: - assertionFailure("Inconsistent state, the task is not started") - self.state = .finished - return .terminate(task: nil, suspendedBases: nil, suspendedDemands: nil) - - case .awaitingDemandFromConsumer(let task, let suspendedBases): - self.state = .finished - return .terminate(task: task, suspendedBases: suspendedBases, suspendedDemands: nil) - - case .awaitingBaseResults(let task, _, _, _, let suspendedBases, let suspendedDemand): - self.state = .finished - return .terminate(task: task, suspendedBases: suspendedBases, suspendedDemands: [suspendedDemand]) - - case .finished: - return .terminate(task: nil, suspendedBases: nil, suspendedDemands: nil) - } - } - - enum BaseIsFinishedOutput { - case terminate( - task: Task?, - suspendedBases: [UnsafeContinuation]?, - suspendedDemands: [SuspendedDemand?]? - ) - } - - mutating func baseIsFinished() -> BaseIsFinishedOutput { - switch self.state { - case .initial: - preconditionFailure("Inconsistent state, the task is not started") - - case .awaitingDemandFromConsumer(let task, let suspendedBases): - self.state = .finished - return .terminate(task: task, suspendedBases: suspendedBases, suspendedDemands: nil) - - case .awaitingBaseResults(let task, _, _, _, let suspendedBases, let suspendedDemand): - self.state = .finished - return .terminate(task: task, suspendedBases: suspendedBases, suspendedDemands: [suspendedDemand]) - - case .finished: - return .terminate(task: nil, suspendedBases: nil, suspendedDemands: nil) - } - } -} diff --git a/Sources/AsyncAlgorithms/Zip/ZipStateMachine.swift b/Sources/AsyncAlgorithms/Zip/ZipStateMachine.swift new file mode 100644 index 00000000..d6b0c9ce --- /dev/null +++ b/Sources/AsyncAlgorithms/Zip/ZipStateMachine.swift @@ -0,0 +1,548 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +/// State machine for zip +struct ZipStateMachine< + Base1: AsyncSequence, + Base2: AsyncSequence, + Base3: AsyncSequence +>: Sendable where + Base1: Sendable, + Base2: Sendable, + Base3: Sendable, + Base1.Element: Sendable, + Base2.Element: Sendable, + Base3.Element: Sendable { + typealias DownstreamContinuation = UnsafeContinuation, Never> + + private enum State: Sendable { + /// Small wrapper for the state of an upstream sequence. + struct Upstream: Sendable { + /// The upstream continuation. + var continuation: UnsafeContinuation? + /// The produced upstream element. + var element: Element? + } + + /// The initial state before a call to `next` happened. + case initial(base1: Base1, base2: Base2, base3: Base3?) + + /// The state while we are waiting for downstream demand. + case waitingForDemand( + task: Task, + upstreams: (Upstream, Upstream, Upstream) + ) + + /// The state while we are consuming the upstream and waiting until we get a result from all upstreams. + case zipping( + task: Task, + upstreams: (Upstream, Upstream, Upstream), + downstreamContinuation: DownstreamContinuation + ) + + /// The state once one upstream sequences finished/threw or the downstream consumer stopped, i.e. by dropping all references + /// or by getting their `Task` cancelled. + case finished + + /// Internal state to avoid CoW. + case modifying + } + + private var state: State + + private let numberOfUpstreamSequences: Int + + /// Initializes a new `StateMachine`. + init( + base1: Base1, + base2: Base2, + base3: Base3? + ) { + self.state = .initial( + base1: base1, + base2: base2, + base3: base3 + ) + + if base3 == nil { + self.numberOfUpstreamSequences = 2 + } else { + self.numberOfUpstreamSequences = 3 + } + } + + /// Actions returned by `iteratorDeinitialized()`. + enum IteratorDeinitializedAction { + /// Indicates that the `Task` needs to be cancelled and + /// the upstream continuations need to be resumed with a `CancellationError`. + case cancelTaskAndUpstreamContinuations( + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + } + + mutating func iteratorDeinitialized() -> IteratorDeinitializedAction? { + switch self.state { + case .initial: + // Nothing to do here. No demand was signalled until now + return .none + + case .zipping: + // An iterator was deinitialized while we have a suspended continuation. + preconditionFailure("Internal inconsistency current state \(self.state) and received iteratorDeinitialized()") + + case .waitingForDemand(let task, let upstreams): + // The iterator was dropped which signals that the consumer is finished. + // We can transition to finished now and need to clean everything up. + self.state = .finished + + return .cancelTaskAndUpstreamContinuations( + task: task, + upstreamContinuations: [upstreams.0.continuation, upstreams.1.continuation, upstreams.2.continuation].compactMap { $0 } + ) + + case .finished: + // We are already finished so there is nothing left to clean up. + // This is just the references dropping afterwards. + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func taskIsStarted( + task: Task, + downstreamContinuation: DownstreamContinuation + ) { + switch self.state { + case .initial: + // The user called `next` and we are starting the `Task` + // to consume the upstream sequences + self.state = .zipping( + task: task, + upstreams: (.init(), .init(), .init()), + downstreamContinuation: downstreamContinuation + ) + + case .zipping, .waitingForDemand, .finished: + // We only allow a single task to be created so this must never happen. + preconditionFailure("Internal inconsistency current state \(self.state) and received taskStarted()") + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `childTaskSuspended()`. + enum ChildTaskSuspendedAction { + /// Indicates that the continuation should be resumed which will lead to calling `next` on the upstream. + case resumeContinuation( + upstreamContinuation: UnsafeContinuation + ) + /// Indicates that the continuation should be resumed with an Error because another upstream sequence threw. + case resumeContinuationWithError( + upstreamContinuation: UnsafeContinuation, + error: Error + ) + } + + mutating func childTaskSuspended(baseIndex: Int, continuation: UnsafeContinuation) -> ChildTaskSuspendedAction? { + switch self.state { + case .initial: + // Child tasks are only created after we transitioned to `zipping` + preconditionFailure("Internal inconsistency current state \(self.state) and received childTaskSuspended()") + + case .waitingForDemand(let task, var upstreams): + self.state = .modifying + + switch baseIndex { + case 0: + upstreams.0.continuation = continuation + + case 1: + upstreams.1.continuation = continuation + + case 2: + upstreams.2.continuation = continuation + + default: + preconditionFailure("Internal inconsistency current state \(self.state) and received childTaskSuspended() with base index \(baseIndex)") + } + + self.state = .waitingForDemand( + task: task, + upstreams: upstreams + ) + + return .none + + case .zipping(let task, var upstreams, let downstreamContinuation): + // We are currently zipping. If we have a buffered element from the base + // already then we store the continuation otherwise we just go ahead and resume it + switch baseIndex { + case 0: + if upstreams.0.element == nil { + return .resumeContinuation(upstreamContinuation: continuation) + } else { + self.state = .modifying + upstreams.0.continuation = continuation + self.state = .zipping( + task: task, + upstreams: upstreams, + downstreamContinuation: downstreamContinuation + ) + return .none + } + + case 1: + if upstreams.1.element == nil { + return .resumeContinuation(upstreamContinuation: continuation) + } else { + self.state = .modifying + upstreams.1.continuation = continuation + self.state = .zipping( + task: task, + upstreams: upstreams, + downstreamContinuation: downstreamContinuation + ) + return .none + } + + case 2: + if upstreams.2.element == nil { + return .resumeContinuation(upstreamContinuation: continuation) + } else { + self.state = .modifying + upstreams.2.continuation = continuation + self.state = .zipping( + task: task, + upstreams: upstreams, + downstreamContinuation: downstreamContinuation + ) + return .none + } + + default: + preconditionFailure("Internal inconsistency current state \(self.state) and received childTaskSuspended() with base index \(baseIndex)") + } + + case .finished: + // Since cancellation is cooperative it might be that child tasks are still getting + // suspended even though we already cancelled them. We must tolerate this and just resume + // the continuation with an error. + return .resumeContinuationWithError( + upstreamContinuation: continuation, + error: CancellationError() + ) + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `elementProduced()`. + enum ElementProducedAction { + /// Indicates that the downstream continuation should be resumed with the element. + case resumeContinuation( + downstreamContinuation: DownstreamContinuation, + result: Result<(Base1.Element, Base2.Element, Base3.Element?)?, Error> + ) + } + + mutating func elementProduced(_ result: (Base1.Element?, Base2.Element?, Base3.Element?)) -> ElementProducedAction? { + switch self.state { + case .initial: + // Child tasks that are producing elements are only created after we transitioned to `zipping` + preconditionFailure("Internal inconsistency current state \(self.state) and received elementProduced()") + + case .waitingForDemand: + // We are only issuing demand when we get signalled by the downstream. + // We should never receive an element when we are waiting for demand. + preconditionFailure("Internal inconsistency current state \(self.state) and received elementProduced()") + + case .zipping(let task, var upstreams, let downstreamContinuation): + self.state = .modifying + + switch result { + case (.some(let first), .none, .none): + precondition(upstreams.0.element == nil) + upstreams.0.element = first + + case (.none, .some(let second), .none): + precondition(upstreams.1.element == nil) + upstreams.1.element = second + + case (.none, .none, .some(let third)): + precondition(upstreams.2.element == nil) + upstreams.2.element = third + + default: + preconditionFailure("Internal inconsistency current state \(self.state) and received elementProduced()") + } + + // Implementing this for the two arities without variadic generics is a bit awkward sadly. + if let first = upstreams.0.element, + let second = upstreams.1.element, + let third = upstreams.2.element { + // We got an element from each upstream so we can resume the downstream now + self.state = .waitingForDemand( + task: task, + upstreams: ( + .init(continuation: upstreams.0.continuation), + .init(continuation: upstreams.1.continuation), + .init(continuation: upstreams.2.continuation) + ) + ) + + return .resumeContinuation( + downstreamContinuation: downstreamContinuation, + result: .success((first, second, third)) + ) + + } else if let first = upstreams.0.element, + let second = upstreams.1.element, + self.numberOfUpstreamSequences == 2 { + // We got an element from each upstream so we can resume the downstream now + self.state = .waitingForDemand( + task: task, + upstreams: ( + .init(continuation: upstreams.0.continuation), + .init(continuation: upstreams.1.continuation), + .init(continuation: upstreams.2.continuation) + ) + ) + + return .resumeContinuation( + downstreamContinuation: downstreamContinuation, + result: .success((first, second, nil)) + ) + } else { + // We are still waiting for one of the upstreams to produce an element + self.state = .zipping( + task: task, + upstreams: ( + .init(continuation: upstreams.0.continuation, element: upstreams.0.element), + .init(continuation: upstreams.1.continuation, element: upstreams.1.element), + .init(continuation: upstreams.2.continuation, element: upstreams.2.element) + ), + downstreamContinuation: downstreamContinuation + ) + + return .none + } + + case .finished: + // Since cancellation is cooperative it might be that child tasks + // are still producing elements after we finished. + // We are just going to drop them since there is nothing we can do + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `upstreamFinished()`. + enum UpstreamFinishedAction { + /// Indicates that the downstream continuation should be resumed with `nil` and + /// the task and the upstream continuations should be cancelled. + case resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: DownstreamContinuation, + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + } + + mutating func upstreamFinished() -> UpstreamFinishedAction? { + switch self.state { + case .initial: + preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamFinished()") + + case .waitingForDemand: + // This can't happen. We are only issuing demand for a single element each time. + // There must never be outstanding demand to an upstream while we have no demand ourselves. + preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamFinished()") + + case .zipping(let task, let upstreams, let downstreamContinuation): + // One of our upstreams finished. We need to transition to finished ourselves now + // and resume the downstream continuation with nil. Furthermore, we need to cancel all of + // the upstream work. + self.state = .finished + + return .resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: downstreamContinuation, + task: task, + upstreamContinuations: [upstreams.0.continuation, upstreams.1.continuation, upstreams.2.continuation].compactMap { $0 } + ) + + case .finished: + // This is just everything finishing up, nothing to do here + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `upstreamThrew()`. + enum UpstreamThrewAction { + /// Indicates that the downstream continuation should be resumed with the `error` and + /// the task and the upstream continuations should be cancelled. + case resumeContinuationWithErrorAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: DownstreamContinuation, + error: Error, + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + } + + mutating func upstreamThrew(_ error: Error) -> UpstreamThrewAction? { + switch self.state { + case .initial: + preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamThrew()") + + case .waitingForDemand: + // This can't happen. We are only issuing demand for a single element each time. + // There must never be outstanding demand to an upstream while we have no demand ourselves. + preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamThrew()") + + case .zipping(let task, let upstreams, let downstreamContinuation): + // One of our upstreams threw. We need to transition to finished ourselves now + // and resume the downstream continuation with the error. Furthermore, we need to cancel all of + // the upstream work. + self.state = .finished + + return .resumeContinuationWithErrorAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: downstreamContinuation, + error: error, + task: task, + upstreamContinuations: [upstreams.0.continuation, upstreams.1.continuation, upstreams.2.continuation].compactMap { $0 } + ) + + case .finished: + // This is just everything finishing up, nothing to do here + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `cancelled()`. + enum CancelledAction { + /// Indicates that the downstream continuation needs to be resumed and + /// task and the upstream continuations should be cancelled. + case resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: DownstreamContinuation, + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + /// Indicates that the task and the upstream continuations should be cancelled. + case cancelTaskAndUpstreamContinuations( + task: Task, + upstreamContinuations: [UnsafeContinuation] + ) + } + + mutating func cancelled() -> CancelledAction? { + switch self.state { + case .initial: + preconditionFailure("Internal inconsistency current state \(self.state) and received cancelled()") + + case .waitingForDemand(let task, let upstreams): + // The downstream task got cancelled so we need to cancel our upstream Task + // and resume all continuations. We can also transition to finished. + self.state = .finished + + return .cancelTaskAndUpstreamContinuations( + task: task, + upstreamContinuations: [upstreams.0.continuation, upstreams.1.continuation, upstreams.2.continuation].compactMap { $0 } + ) + + case .zipping(let task, let upstreams, let downstreamContinuation): + // The downstream Task got cancelled so we need to cancel our upstream Task + // and resume all continuations. We can also transition to finished. + self.state = .finished + + return .resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations( + downstreamContinuation: downstreamContinuation, + task: task, + upstreamContinuations: [upstreams.0.continuation, upstreams.1.continuation, upstreams.2.continuation].compactMap { $0 } + ) + + case .finished: + // We are already finished so nothing to do here: + self.state = .finished + + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Actions returned by `next()`. + enum NextAction { + /// Indicates that a new `Task` should be created that consumes the sequence. + case startTask(Base1, Base2, Base3?) + case resumeUpstreamContinuations( + upstreamContinuation: [UnsafeContinuation] + ) + /// Indicates that the downstream continuation should be resumed with `nil`. + case resumeDownstreamContinuationWithNil(DownstreamContinuation) + } + + mutating func next(for continuation: DownstreamContinuation) -> NextAction { + switch self.state { + case .initial(let base1, let base2, let base3): + // This is the first time we get demand singalled so we have to start the task + // The transition to the next state is done in the taskStarted method + return .startTask(base1, base2, base3) + + case .zipping: + // We already got demand signalled and have suspended the downstream task + // Getting a second next calls means the iterator was transferred across Tasks which is not allowed + preconditionFailure("Internal inconsistency current state \(self.state) and received next()") + + case .waitingForDemand(let task, var upstreams): + // We got demand signalled now and can transition to zipping. + // We also need to resume all upstream continuations now + self.state = .modifying + + let upstreamContinuations = [upstreams.0.continuation, upstreams.1.continuation, upstreams.2.continuation].compactMap { $0 } + upstreams.0.continuation = nil + upstreams.1.continuation = nil + upstreams.2.continuation = nil + + self.state = .zipping( + task: task, + upstreams: upstreams, + downstreamContinuation: continuation + ) + + return .resumeUpstreamContinuations( + upstreamContinuation: upstreamContinuations + ) + + case .finished: + // We are already finished so we are just returning `nil` + return .resumeDownstreamContinuationWithNil(continuation) + + case .modifying: + preconditionFailure("Invalid state") + } + } +} diff --git a/Sources/AsyncAlgorithms/Zip/ZipStorage.swift b/Sources/AsyncAlgorithms/Zip/ZipStorage.swift new file mode 100644 index 00000000..57ee1dc5 --- /dev/null +++ b/Sources/AsyncAlgorithms/Zip/ZipStorage.swift @@ -0,0 +1,320 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +final class ZipStorage: Sendable + where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: Sendable, Base3: Sendable, Base3.Element: Sendable { + typealias StateMachine = ZipStateMachine + + private let stateMachine: ManagedCriticalState + + init(_ base1: Base1, _ base2: Base2, _ base3: Base3?) { + self.stateMachine = .init(.init(base1: base1, base2: base2, base3: base3)) + } + + func iteratorDeinitialized() { + let action = self.stateMachine.withCriticalRegion { $0.iteratorDeinitialized() } + + switch action { + case .cancelTaskAndUpstreamContinuations( + let task, + let upstreamContinuation + ): + upstreamContinuation.forEach { $0.resume(throwing: CancellationError()) } + + task.cancel() + + case .none: + break + } + } + + func next() async rethrows -> (Base1.Element, Base2.Element, Base3.Element?)? { + try await withTaskCancellationHandler { + let result = await withUnsafeContinuation { continuation in + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.next(for: continuation) + switch action { + case .startTask(let base1, let base2, let base3): + // first iteration, we start one child task per base to iterate over them + self.startTask( + stateMachine: &stateMachine, + base1: base1, + base2: base2, + base3: base3, + downStreamContinuation: continuation + ) + + case .resumeUpstreamContinuations(let upstreamContinuations): + // bases can be iterated over for 1 iteration so their next value can be retrieved + upstreamContinuations.forEach { $0.resume() } + + case .resumeDownstreamContinuationWithNil(let continuation): + // the async sequence is already finished, immediately resuming + continuation.resume(returning: .success(nil)) + } + } + } + + return try result._rethrowGet() + + } onCancel: { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.cancelled() + + switch action { + case .resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations( + let downstreamContinuation, + let task, + let upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + downstreamContinuation.resume(returning: .success(nil)) + + case .cancelTaskAndUpstreamContinuations(let task, let upstreamContinuations): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + case .none: + break + } + } + } + } + + private func startTask( + stateMachine: inout StateMachine, + base1: Base1, + base2: Base2, + base3: Base3?, + downStreamContinuation: StateMachine.DownstreamContinuation + ) { + // This creates a new `Task` that is iterating the upstream + // sequences. We must store it to cancel it at the right times. + let task = Task { + await withThrowingTaskGroup(of: Void.self) { group in + // For each upstream sequence we are adding a child task that + // is consuming the upstream sequence + group.addTask { + var base1Iterator = base1.makeAsyncIterator() + + while true { + // We are creating a continuation before requesting the next + // element from upstream. This continuation is only resumed + // if the downstream consumer called `next` to signal his demand. + try await withUnsafeThrowingContinuation { continuation in + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.childTaskSuspended(baseIndex: 0, continuation: continuation) + + switch action { + case .resumeContinuation(let upstreamContinuation): + upstreamContinuation.resume() + + case .resumeContinuationWithError(let upstreamContinuation, let error): + upstreamContinuation.resume(throwing: error) + + case .none: + break + } + } + } + + if let element1 = try await base1Iterator.next() { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.elementProduced((element1, nil, nil)) + + switch action { + case .resumeContinuation(let downstreamContinuation, let result): + downstreamContinuation.resume(returning: result) + + case .none: + break + } + } + } else { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.upstreamFinished() + + switch action { + case .resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + let downstreamContinuation, + let task, + let upstreamContinuations + ): + + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + downstreamContinuation.resume(returning: .success(nil)) + + case .none: + break + } + } + } + } + } + + group.addTask { + var base1Iterator = base2.makeAsyncIterator() + + while true { + // We are creating a continuation before requesting the next + // element from upstream. This continuation is only resumed + // if the downstream consumer called `next` to signal his demand. + try await withUnsafeThrowingContinuation { continuation in + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.childTaskSuspended(baseIndex: 1, continuation: continuation) + + switch action { + case .resumeContinuation(let upstreamContinuation): + upstreamContinuation.resume() + + case .resumeContinuationWithError(let upstreamContinuation, let error): + upstreamContinuation.resume(throwing: error) + + case .none: + break + } + } + } + + if let element2 = try await base1Iterator.next() { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.elementProduced((nil, element2, nil)) + + switch action { + case .resumeContinuation(let downstreamContinuation, let result): + downstreamContinuation.resume(returning: result) + + case .none: + break + } + } + } else { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.upstreamFinished() + + switch action { + case .resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + let downstreamContinuation, + let task, + let upstreamContinuations + ): + + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + downstreamContinuation.resume(returning: .success(nil)) + + case .none: + break + } + } + } + } + } + + if let base3 = base3 { + group.addTask { + var base1Iterator = base3.makeAsyncIterator() + + while true { + // We are creating a continuation before requesting the next + // element from upstream. This continuation is only resumed + // if the downstream consumer called `next` to signal his demand. + try await withUnsafeThrowingContinuation { continuation in + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.childTaskSuspended(baseIndex: 2, continuation: continuation) + + switch action { + case .resumeContinuation(let upstreamContinuation): + upstreamContinuation.resume() + + case .resumeContinuationWithError(let upstreamContinuation, let error): + upstreamContinuation.resume(throwing: error) + + case .none: + break + } + } + } + + if let element3 = try await base1Iterator.next() { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.elementProduced((nil, nil, element3)) + + switch action { + case .resumeContinuation(let downstreamContinuation, let result): + downstreamContinuation.resume(returning: result) + + case .none: + break + } + } + } else { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.upstreamFinished() + + switch action { + case .resumeContinuationWithNilAndCancelTaskAndUpstreamContinuations( + let downstreamContinuation, + let task, + let upstreamContinuations + ): + + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + downstreamContinuation.resume(returning: .success(nil)) + + case .none: + break + } + } + } + } + } + } + + do { + try await group.waitForAll() + } catch { + // One of the upstream sequences threw an error + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.upstreamThrew(error) + + switch action { + case .resumeContinuationWithErrorAndCancelTaskAndUpstreamContinuations( + let downstreamContinuation, + let error, + let task, + let upstreamContinuations + ): + upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) } + task.cancel() + + downstreamContinuation.resume(returning: .failure(error)) + + case .none: + break + } + } + + group.cancelAll() + } + } + } + + stateMachine.taskIsStarted(task: task, downstreamContinuation: downStreamContinuation) + } +}