diff --git a/Sources/AsyncAlgorithms/AsyncChannel.swift b/Sources/AsyncAlgorithms/AsyncChannel.swift deleted file mode 100644 index 1b41c779..00000000 --- a/Sources/AsyncAlgorithms/AsyncChannel.swift +++ /dev/null @@ -1,275 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Async Algorithms open source project -// -// Copyright (c) 2022 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// -//===----------------------------------------------------------------------===// - -@preconcurrency @_implementationOnly import OrderedCollections - -/// A channel for sending elements from one task to another with back pressure. -/// -/// The `AsyncChannel` class is intended to be used as a communication type between tasks, -/// particularly when one task produces values and another task consumes those values. The back -/// pressure applied by `send(_:)` via the suspension/resume ensures that -/// the production of values does not exceed the consumption of values from iteration. This method -/// suspends after enqueuing the event and is resumed when the next call to `next()` -/// on the `Iterator` is made, or when `finish()` is called from another Task. -/// As `finish()` induces a terminal state, there is no need for a back pressure management. -/// This function does not suspend and will finish all the pending iterations. -public final class AsyncChannel: AsyncSequence, Sendable { - /// The iterator for a `AsyncChannel` instance. - public struct Iterator: AsyncIteratorProtocol, Sendable { - let channel: AsyncChannel - var active: Bool = true - - init(_ channel: AsyncChannel) { - self.channel = channel - } - - /// Await the next sent element or finish. - public mutating func next() async -> Element? { - guard active else { - return nil - } - - let generation = channel.establish() - let nextTokenStatus = ManagedCriticalState(.new) - - let value = await withTaskCancellationHandler { - await channel.next(nextTokenStatus, generation) - } onCancel: { [channel] in - channel.cancelNext(nextTokenStatus, generation) - } - - if let value { - return value - } else { - active = false - return nil - } - } - } - - typealias Pending = ChannelToken?, Never>> - typealias Awaiting = ChannelToken> - - struct ChannelToken: Hashable, Sendable { - var generation: Int - var continuation: Continuation? - - init(generation: Int, continuation: Continuation) { - self.generation = generation - self.continuation = continuation - } - - init(placeholder generation: Int) { - self.generation = generation - self.continuation = nil - } - - func hash(into hasher: inout Hasher) { - hasher.combine(generation) - } - - static func == (_ lhs: ChannelToken, _ rhs: ChannelToken) -> Bool { - return lhs.generation == rhs.generation - } - } - - enum ChannelTokenStatus: Equatable { - case new - case cancelled - } - - enum Emission : Sendable { - case idle - case pending(OrderedSet) - case awaiting(OrderedSet) - case finished - } - - struct State : Sendable { - var emission: Emission = .idle - var generation = 0 - } - - let state = ManagedCriticalState(State()) - - /// Create a new `AsyncChannel` given an element type. - public init(element elementType: Element.Type = Element.self) { } - - func establish() -> Int { - state.withCriticalRegion { state in - defer { state.generation &+= 1 } - return state.generation - } - } - - func cancelNext(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) { - state.withCriticalRegion { state in - let continuation: UnsafeContinuation? - - switch state.emission { - case .awaiting(var nexts): - continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation - if nexts.isEmpty { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } - default: - continuation = nil - } - - nextTokenStatus.withCriticalRegion { status in - if status == .new { - status = .cancelled - } - } - - continuation?.resume(returning: nil) - } - } - - func next(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) async -> Element? { - return await withUnsafeContinuation { (continuation: UnsafeContinuation) in - var cancelled = false - var terminal = false - state.withCriticalRegion { state in - - if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled { - cancelled = true - } - - switch state.emission { - case .idle: - state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)]) - case .pending(var sends): - let send = sends.removeFirst() - if sends.count == 0 { - state.emission = .idle - } else { - state.emission = .pending(sends) - } - send.continuation?.resume(returning: continuation) - case .awaiting(var nexts): - nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation)) - state.emission = .awaiting(nexts) - case .finished: - terminal = true - } - } - - if cancelled || terminal { - continuation.resume(returning: nil) - } - } - } - - func cancelSend(_ sendTokenStatus: ManagedCriticalState, _ generation: Int) { - state.withCriticalRegion { state in - let continuation: UnsafeContinuation?, Never>? - - switch state.emission { - case .pending(var sends): - let send = sends.remove(Pending(placeholder: generation)) - if sends.isEmpty { - state.emission = .idle - } else { - state.emission = .pending(sends) - } - continuation = send?.continuation - default: - continuation = nil - } - - sendTokenStatus.withCriticalRegion { status in - if status == .new { - status = .cancelled - } - } - - continuation?.resume(returning: nil) - } - } - - func send(_ sendTokenStatus: ManagedCriticalState, _ generation: Int, _ element: Element) async { - let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in - state.withCriticalRegion { state in - - if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled { - continuation.resume(returning: nil) - return - } - - switch state.emission { - case .idle: - state.emission = .pending([Pending(generation: generation, continuation: continuation)]) - case .pending(var sends): - sends.updateOrAppend(Pending(generation: generation, continuation: continuation)) - state.emission = .pending(sends) - case .awaiting(var nexts): - let next = nexts.removeFirst().continuation - if nexts.count == 0 { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } - continuation.resume(returning: next) - case .finished: - continuation.resume(returning: nil) - } - } - } - continuation?.resume(returning: element) - } - - /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made - /// or when a call to `finish()` is made from another Task. - /// If the channel is already finished then this returns immediately - /// If the task is cancelled, this function will resume. Other sending operations from other tasks will remain active. - public func send(_ element: Element) async { - let generation = establish() - let sendTokenStatus = ManagedCriticalState(.new) - - await withTaskCancellationHandler { - await send(sendTokenStatus, generation, element) - } onCancel: { [weak self] in - self?.cancelSend(sendTokenStatus, generation) - } - } - - /// Send a finish to all awaiting iterations. - /// All subsequent calls to `next(_:)` will resume immediately. - public func finish() { - state.withCriticalRegion { state in - - defer { state.emission = .finished } - - switch state.emission { - case .pending(let sends): - for send in sends { - send.continuation?.resume(returning: nil) - } - case .awaiting(let nexts): - for next in nexts { - next.continuation?.resume(returning: nil) - } - default: - break - } - } - - - } - - /// Create an `Iterator` for iteration of an `AsyncChannel` - public func makeAsyncIterator() -> Iterator { - return Iterator(self) - } -} diff --git a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift deleted file mode 100644 index 8359287e..00000000 --- a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift +++ /dev/null @@ -1,320 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Async Algorithms open source project -// -// Copyright (c) 2022 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// -//===----------------------------------------------------------------------===// - -@preconcurrency @_implementationOnly import OrderedCollections - -/// An error-throwing channel for sending elements from on task to another with back pressure. -/// -/// The `AsyncThrowingChannel` class is intended to be used as a communication types between tasks, -/// particularly when one task produces values and another task consumes those values. The back -/// pressure applied by `send(_:)` via suspension/resume ensures that the production of values does -/// not exceed the consumption of values from iteration. This method suspends after enqueuing the event -/// and is resumed when the next call to `next()` on the `Iterator` is made, or when `finish()`/`fail(_:)` is called -/// from another Task. As `finish()` and `fail(_:)` induce a terminal state, there is no need for a back pressure management. -/// Those functions do not suspend and will finish all the pending iterations. -public final class AsyncThrowingChannel: AsyncSequence, Sendable { - /// The iterator for an `AsyncThrowingChannel` instance. - public struct Iterator: AsyncIteratorProtocol, Sendable { - let channel: AsyncThrowingChannel - var active: Bool = true - - init(_ channel: AsyncThrowingChannel) { - self.channel = channel - } - - public mutating func next() async throws -> Element? { - guard active else { - return nil - } - - let generation = channel.establish() - let nextTokenStatus = ManagedCriticalState(.new) - - do { - let value = try await withTaskCancellationHandler { - try await channel.next(nextTokenStatus, generation) - } onCancel: { [channel] in - channel.cancelNext(nextTokenStatus, generation) - } - - if let value = value { - return value - } else { - active = false - return nil - } - } catch { - active = false - throw error - } - } - } - - typealias Pending = ChannelToken?, Never>> - typealias Awaiting = ChannelToken> - - struct ChannelToken: Hashable, Sendable { - var generation: Int - var continuation: Continuation? - - init(generation: Int, continuation: Continuation) { - self.generation = generation - self.continuation = continuation - } - - init(placeholder generation: Int) { - self.generation = generation - self.continuation = nil - } - - func hash(into hasher: inout Hasher) { - hasher.combine(generation) - } - - static func == (_ lhs: ChannelToken, _ rhs: ChannelToken) -> Bool { - return lhs.generation == rhs.generation - } - } - - - enum ChannelTokenStatus: Equatable { - case new - case cancelled - } - - enum Termination { - case finished - case failed(Error) - } - - enum Emission: Sendable { - case idle - case pending(OrderedSet) - case awaiting(OrderedSet) - case terminated(Termination) - } - - struct State : Sendable { - var emission: Emission = .idle - var generation = 0 - } - - let state = ManagedCriticalState(State()) - - public init(_ elementType: Element.Type = Element.self) { } - - func establish() -> Int { - state.withCriticalRegion { state in - defer { state.generation &+= 1 } - return state.generation - } - } - - func cancelNext(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) { - state.withCriticalRegion { state in - let continuation: UnsafeContinuation? - - switch state.emission { - case .awaiting(var nexts): - continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation - if nexts.isEmpty { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } - default: - continuation = nil - } - - nextTokenStatus.withCriticalRegion { status in - if status == .new { - status = .cancelled - } - } - - continuation?.resume(returning: nil) - } - } - - func next(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) async throws -> Element? { - return try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in - var cancelled = false - var potentialTermination: Termination? - - state.withCriticalRegion { state in - - if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled { - cancelled = true - return - } - - switch state.emission { - case .idle: - state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)]) - case .pending(var sends): - let send = sends.removeFirst() - if sends.count == 0 { - state.emission = .idle - } else { - state.emission = .pending(sends) - } - send.continuation?.resume(returning: continuation) - case .awaiting(var nexts): - nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation)) - state.emission = .awaiting(nexts) - case .terminated(let termination): - potentialTermination = termination - state.emission = .terminated(.finished) - } - } - - if cancelled { - continuation.resume(returning: nil) - return - } - - switch potentialTermination { - case .none: - return - case .failed(let error): - continuation.resume(throwing: error) - return - case .finished: - continuation.resume(returning: nil) - return - } - } - } - - func cancelSend(_ sendTokenStatus: ManagedCriticalState, _ generation: Int) { - state.withCriticalRegion { state in - let continuation: UnsafeContinuation?, Never>? - - switch state.emission { - case .pending(var sends): - let send = sends.remove(Pending(placeholder: generation)) - if sends.isEmpty { - state.emission = .idle - } else { - state.emission = .pending(sends) - } - continuation = send?.continuation - default: - continuation = nil - } - - sendTokenStatus.withCriticalRegion { status in - if status == .new { - status = .cancelled - } - } - - continuation?.resume(returning: nil) - } - } - - func send(_ sendTokenStatus: ManagedCriticalState, _ generation: Int, _ element: Element) async { - let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in - state.withCriticalRegion { state in - - if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled { - continuation.resume(returning: nil) - return - } - - switch state.emission { - case .idle: - state.emission = .pending([Pending(generation: generation, continuation: continuation)]) - case .pending(var sends): - sends.updateOrAppend(Pending(generation: generation, continuation: continuation)) - state.emission = .pending(sends) - case .awaiting(var nexts): - let next = nexts.removeFirst().continuation - if nexts.count == 0 { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } - continuation.resume(returning: next) - case .terminated: - continuation.resume(returning: nil) - } - } - } - continuation?.resume(returning: element) - } - - func terminateAll(error: Failure? = nil) { - state.withCriticalRegion { state in - - let nextState: Emission - if let error = error { - nextState = .terminated(.failed(error)) - } else { - nextState = .terminated(.finished) - } - - switch state.emission { - case .idle: - state.emission = nextState - case .pending(let sends): - state.emission = nextState - for send in sends { - send.continuation?.resume(returning: nil) - } - case .awaiting(let nexts): - state.emission = nextState - if let error = error { - for next in nexts { - next.continuation?.resume(throwing: error) - } - } else { - for next in nexts { - next.continuation?.resume(returning: nil) - } - } - case .terminated: - break - } - } - } - - /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made - /// or when a call to `finish()`/`fail(_:)` is made from another Task. - /// If the channel is already finished then this returns immediately - /// If the task is cancelled, this function will resume. Other sending operations from other tasks will remain active. - public func send(_ element: Element) async { - let generation = establish() - let sendTokenStatus = ManagedCriticalState(.new) - - await withTaskCancellationHandler { - await send(sendTokenStatus, generation, element) - } onCancel: { [weak self] in - self?.cancelSend(sendTokenStatus, generation) - } - } - - /// Send an error to all awaiting iterations. - /// All subsequent calls to `next(_:)` will resume immediately. - public func fail(_ error: Error) where Failure == Error { - terminateAll(error: error) - } - - /// Send a finish to all awaiting iterations. - /// All subsequent calls to `next(_:)` will resume immediately. - public func finish() { - terminateAll() - } - - public func makeAsyncIterator() -> Iterator { - return Iterator(self) - } -} diff --git a/Sources/AsyncAlgorithms/Channels/AsyncChannel.swift b/Sources/AsyncAlgorithms/Channels/AsyncChannel.swift new file mode 100644 index 00000000..75becf2d --- /dev/null +++ b/Sources/AsyncAlgorithms/Channels/AsyncChannel.swift @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +/// A channel for sending elements from one task to another with back pressure. +/// +/// The `AsyncChannel` class is intended to be used as a communication type between tasks, +/// particularly when one task produces values and another task consumes those values. The back +/// pressure applied by `send(_:)` via the suspension/resume ensures that +/// the production of values does not exceed the consumption of values from iteration. This method +/// suspends after enqueuing the event and is resumed when the next call to `next()` +/// on the `Iterator` is made, or when `finish()` is called from another Task. +/// As `finish()` induces a terminal state, there is no more need for a back pressure management. +/// This function does not suspend and will finish all the pending iterations. +public final class AsyncChannel: AsyncSequence, @unchecked Sendable { + public typealias Element = Element + public typealias AsyncIterator = Iterator + + let storage: ChannelStorage + + public init() { + self.storage = ChannelStorage() + } + + /// Sends an element to an awaiting iteration. This function will resume when the next call to `next()` is made + /// or when a call to `finish()` is made from another task. + /// If the channel is already finished then this returns immediately. + /// If the task is cancelled, this function will resume without sending the element. + /// Other sending operations from other tasks will remain active. + public func send(_ element: Element) async { + await self.storage.send(element: element) + } + + /// Immediately resumes all the suspended operations. + /// All subsequent calls to `next(_:)` will resume immediately. + public func finish() { + self.storage.finish() + } + + public func makeAsyncIterator() -> Iterator { + Iterator(storage: self.storage) + } + + public struct Iterator: AsyncIteratorProtocol { + let storage: ChannelStorage + + public mutating func next() async -> Element? { + // Although the storage can throw, its usage in the context of an `AsyncChannel` guarantees it cannot. + // There is no public way of sending a failure to it. + try! await self.storage.next() + } + } +} diff --git a/Sources/AsyncAlgorithms/Channels/AsyncThrowingChannel.swift b/Sources/AsyncAlgorithms/Channels/AsyncThrowingChannel.swift new file mode 100644 index 00000000..28de36ae --- /dev/null +++ b/Sources/AsyncAlgorithms/Channels/AsyncThrowingChannel.swift @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +/// An error-throwing channel for sending elements from on task to another with back pressure. +/// +/// The `AsyncThrowingChannel` class is intended to be used as a communication types between tasks, +/// particularly when one task produces values and another task consumes those values. The back +/// pressure applied by `send(_:)` via suspension/resume ensures that the production of values does +/// not exceed the consumption of values from iteration. This method suspends after enqueuing the event +/// and is resumed when the next call to `next()` on the `Iterator` is made, or when `finish()`/`fail(_:)` is called +/// from another Task. As `finish()` and `fail(_:)` induce a terminal state, there is no more need for a back pressure management. +/// Those functions do not suspend and will finish all the pending iterations. +public final class AsyncThrowingChannel: AsyncSequence, @unchecked Sendable { + public typealias Element = Element + public typealias AsyncIterator = Iterator + + let storage: ChannelStorage + + public init() { + self.storage = ChannelStorage() + } + + /// Sends an element to an awaiting iteration. This function will resume when the next call to `next()` is made + /// or when a call to `finish()` or `fail` is made from another task. + /// If the channel is already finished then this returns immediately. + /// If the task is cancelled, this function will resume without sending the element. + /// Other sending operations from other tasks will remain active. + public func send(_ element: Element) async { + await self.storage.send(element: element) + } + + /// Sends an error to all awaiting iterations. + /// All subsequent calls to `next(_:)` will resume immediately. + public func fail(_ error: Error) where Failure == Error { + self.storage.finish(error: error) + } + + /// Immediately resumes all the suspended operations. + /// All subsequent calls to `next(_:)` will resume immediately. + public func finish() { + self.storage.finish() + } + + public func makeAsyncIterator() -> Iterator { + Iterator(storage: self.storage) + } + + public struct Iterator: AsyncIteratorProtocol { + let storage: ChannelStorage + + public mutating func next() async throws -> Element? { + try await self.storage.next() + } + } +} diff --git a/Sources/AsyncAlgorithms/Channels/ChannelStateMachine.swift b/Sources/AsyncAlgorithms/Channels/ChannelStateMachine.swift new file mode 100644 index 00000000..2972c754 --- /dev/null +++ b/Sources/AsyncAlgorithms/Channels/ChannelStateMachine.swift @@ -0,0 +1,344 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// +@_implementationOnly import OrderedCollections + +struct ChannelStateMachine: Sendable { + private struct SuspendedProducer: Hashable { + let id: UInt64 + let continuation: UnsafeContinuation? + let element: Element? + + func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + static func == (_ lhs: SuspendedProducer, _ rhs: SuspendedProducer) -> Bool { + return lhs.id == rhs.id + } + + static func placeHolder(id: UInt64) -> SuspendedProducer { + SuspendedProducer(id: id, continuation: nil, element: nil) + } + } + + private struct SuspendedConsumer: Hashable { + let id: UInt64 + let continuation: UnsafeContinuation? + + func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + static func == (_ lhs: SuspendedConsumer, _ rhs: SuspendedConsumer) -> Bool { + return lhs.id == rhs.id + } + + static func placeHolder(id: UInt64) -> SuspendedConsumer { + SuspendedConsumer(id: id, continuation: nil) + } + } + + private enum Termination { + case finished + case failed(Error) + } + + private enum State { + case channeling( + suspendedProducers: OrderedSet, + cancelledProducers: Set, + suspendedConsumers: OrderedSet, + cancelledConsumers: Set + ) + case terminated(Termination) + } + + private var state: State = .channeling(suspendedProducers: [], cancelledProducers: [], suspendedConsumers: [], cancelledConsumers: []) + + enum SendAction { + case resumeConsumer(continuation: UnsafeContinuation?) + case suspend + } + + mutating func send() -> SendAction { + switch self.state { + case .channeling(_, _, let suspendedConsumers, _) where suspendedConsumers.isEmpty: + // we are idle or waiting for consumers, we have to suspend the producer + return .suspend + + case .channeling(let suspendedProducers, let cancelledProducers, var suspendedConsumers, let cancelledConsumers): + // we are waiting for producers, we can resume the first available consumer + let suspendedConsumer = suspendedConsumers.removeFirst() + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .resumeConsumer(continuation: suspendedConsumer.continuation) + + case .terminated: + return .resumeConsumer(continuation: nil) + } + } + + enum SendSuspendedAction { + case resumeProducer + case resumeProducerAndConsumer(continuation: UnsafeContinuation?) + } + + mutating func sendSuspended( + continuation: UnsafeContinuation, + element: Element, + producerID: UInt64 + ) -> SendSuspendedAction? { + switch self.state { + case .channeling(var suspendedProducers, var cancelledProducers, var suspendedConsumers, let cancelledConsumers): + let suspendedProducer = SuspendedProducer(id: producerID, continuation: continuation, element: element) + if let _ = cancelledProducers.remove(suspendedProducer) { + // the producer was already cancelled, we resume it + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .resumeProducer + } + + if suspendedConsumers.isEmpty { + // we are idle or waiting for consumers + // we stack the incoming producer in a suspended state + suspendedProducers.append(suspendedProducer) + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .none + } else { + // we are waiting for producers + // we resume the first consumer + let suspendedConsumer = suspendedConsumers.removeFirst() + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .resumeProducerAndConsumer(continuation: suspendedConsumer.continuation) + } + + case .terminated: + return .resumeProducer + } + } + + enum SendCancelledAction { + case none + case resumeProducer(continuation: UnsafeContinuation?) + } + + mutating func sendCancelled(producerID: UInt64) -> SendCancelledAction { + switch self.state { + case .channeling(var suspendedProducers, var cancelledProducers, let suspendedConsumers, let cancelledConsumers): + // the cancelled producer might be part of the waiting list + let placeHolder = SuspendedProducer.placeHolder(id: producerID) + + if let removed = suspendedProducers.remove(placeHolder) { + // the producer was cancelled after being added to the suspended ones, we resume it + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .resumeProducer(continuation: removed.continuation) + } + + // the producer was cancelled before being added to the suspended ones + cancelledProducers.update(with: placeHolder) + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .none + + case .terminated: + return .none + } + } + + enum FinishAction { + case none + case resumeProducersAndConsumers( + producerSontinuations: [UnsafeContinuation?], + consumerContinuations: [UnsafeContinuation?] + ) + } + + mutating func finish(error: Failure?) -> FinishAction { + switch self.state { + case .channeling(let suspendedProducers, _, let suspendedConsumers, _): + // no matter if we are idle, waiting for producers or waiting for consumers, we resume every thing that is suspended + if let error { + if suspendedConsumers.isEmpty { + self.state = .terminated(.failed(error)) + } else { + self.state = .terminated(.finished) + } + } else { + self.state = .terminated(.finished) + } + return .resumeProducersAndConsumers( + producerSontinuations: suspendedProducers.map { $0.continuation }, + consumerContinuations: suspendedConsumers.map { $0.continuation } + ) + + case .terminated: + return .none + } + } + + enum NextAction { + case resumeProducer(continuation: UnsafeContinuation?, result: Result) + case suspend + } + + mutating func next() -> NextAction { + switch self.state { + case .channeling(let suspendedProducers, _, _, _) where suspendedProducers.isEmpty: + // we are idle or waiting for producers, we must suspend + return .suspend + + case .channeling(var suspendedProducers, let cancelledProducers, let suspendedConsumers, let cancelledConsumers): + // we are waiting for consumers, we can resume the first awaiting producer + let suspendedProducer = suspendedProducers.removeFirst() + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .resumeProducer( + continuation: suspendedProducer.continuation, + result: .success(suspendedProducer.element) + ) + + case .terminated(.failed(let error)): + self.state = .terminated(.finished) + return .resumeProducer(continuation: nil, result: .failure(error)) + + case .terminated: + return .resumeProducer(continuation: nil, result: .success(nil)) + } + } + + enum NextSuspendedAction { + case resumeConsumer(element: Element?) + case resumeConsumerWithError(error: Error) + case resumeProducerAndConsumer(continuation: UnsafeContinuation?, element: Element?) + } + + mutating func nextSuspended( + continuation: UnsafeContinuation, + consumerID: UInt64 + ) -> NextSuspendedAction? { + switch self.state { + case .channeling(var suspendedProducers, let cancelledProducers, var suspendedConsumers, var cancelledConsumers): + let suspendedConsumer = SuspendedConsumer(id: consumerID, continuation: continuation) + if let _ = cancelledConsumers.remove(suspendedConsumer) { + // the consumer was already cancelled, we resume it + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .resumeConsumer(element: nil) + } + + if suspendedProducers.isEmpty { + // we are idle or waiting for producers + // we stack the incoming consumer in a suspended state + suspendedConsumers.append(suspendedConsumer) + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .none + } else { + // we are waiting for consumers + // we resume the first producer + let suspendedProducer = suspendedProducers.removeFirst() + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .resumeProducerAndConsumer( + continuation: suspendedProducer.continuation, + element: suspendedProducer.element + ) + } + + case .terminated(.finished): + return .resumeConsumer(element: nil) + + case .terminated(.failed(let error)): + self.state = .terminated(.finished) + return .resumeConsumerWithError(error: error) + } + } + + enum NextCancelledAction { + case none + case resumeConsumer(continuation: UnsafeContinuation?) + } + + mutating func nextCancelled(consumerID: UInt64) -> NextCancelledAction { + switch self.state { + case .channeling(let suspendedProducers, let cancelledProducers, var suspendedConsumers, var cancelledConsumers): + // the cancelled consumer might be part of the suspended ones + let placeHolder = SuspendedConsumer.placeHolder(id: consumerID) + + if let removed = suspendedConsumers.remove(placeHolder) { + // the consumer was cancelled after being added to the suspended ones, we resume it + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .resumeConsumer(continuation: removed.continuation) + } + + // the consumer was cancelled before being added to the suspended ones + cancelledConsumers.update(with: placeHolder) + self.state = .channeling( + suspendedProducers: suspendedProducers, + cancelledProducers: cancelledProducers, + suspendedConsumers: suspendedConsumers, + cancelledConsumers: cancelledConsumers + ) + return .none + + case .terminated: + return .none + } + } +} diff --git a/Sources/AsyncAlgorithms/Channels/ChannelStorage.swift b/Sources/AsyncAlgorithms/Channels/ChannelStorage.swift new file mode 100644 index 00000000..da398dbc --- /dev/null +++ b/Sources/AsyncAlgorithms/Channels/ChannelStorage.swift @@ -0,0 +1,149 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// +struct ChannelStorage: Sendable { + private let stateMachine: ManagedCriticalState> + private let ids = ManagedCriticalState(0) + + init() { + self.stateMachine = ManagedCriticalState(ChannelStateMachine()) + } + + func generateId() -> UInt64 { + self.ids.withCriticalRegion { ids in + defer { ids &+= 1 } + return ids + } + } + + func send(element: Element) async { + // check if a suspension is needed + let shouldExit = self.stateMachine.withCriticalRegion { stateMachine -> Bool in + let action = stateMachine.send() + + switch action { + case .suspend: + // the element has not been delivered because no consumer available, we must suspend + return false + case .resumeConsumer(let continuation): + continuation?.resume(returning: element) + return true + } + } + + if shouldExit { + return + } + + let producerID = self.generateId() + + await withTaskCancellationHandler { + // a suspension is needed + await withUnsafeContinuation { (continuation: UnsafeContinuation) in + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.sendSuspended(continuation: continuation, element: element, producerID: producerID) + + switch action { + case .none: + break + case .resumeProducer: + continuation.resume() + case .resumeProducerAndConsumer(let consumerContinuation): + continuation.resume() + consumerContinuation?.resume(returning: element) + } + } + } + } onCancel: { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.sendCancelled(producerID: producerID) + + switch action { + case .none: + break + case .resumeProducer(let continuation): + continuation?.resume() + } + } + } + } + + func finish(error: Failure? = nil) { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.finish(error: error) + + switch action { + case .none: + break + case .resumeProducersAndConsumers(let producerContinuations, let consumerContinuations): + producerContinuations.forEach { $0?.resume() } + if let error { + consumerContinuations.forEach { $0?.resume(throwing: error) } + } else { + consumerContinuations.forEach { $0?.resume(returning: nil) } + } + } + } + } + + func next() async throws -> Element? { + let (shouldExit, result) = self.stateMachine.withCriticalRegion { stateMachine -> (Bool, Result?) in + let action = stateMachine.next() + + switch action { + case .suspend: + return (false, nil) + case .resumeProducer(let producerContinuation, let result): + producerContinuation?.resume() + return (true, result) + } + } + + if shouldExit { + return try result?._rethrowGet() + } + + let consumerID = self.generateId() + + return try await withTaskCancellationHandler { + try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.nextSuspended( + continuation: continuation, + consumerID: consumerID + ) + + switch action { + case .none: + break + case .resumeConsumer(let element): + continuation.resume(returning: element) + case .resumeConsumerWithError(let error): + continuation.resume(throwing: error) + case .resumeProducerAndConsumer(let producerContinuation, let element): + producerContinuation?.resume() + continuation.resume(returning: element) + } + } + } + } onCancel: { + self.stateMachine.withCriticalRegion { stateMachine in + let action = stateMachine.nextCancelled(consumerID: consumerID) + + switch action { + case .none: + break + case .resumeConsumer(let continuation): + continuation?.resume(returning: nil) + } + } + } + } +} diff --git a/Tests/AsyncAlgorithmsTests/Performance/TestThroughput.swift b/Tests/AsyncAlgorithmsTests/Performance/TestThroughput.swift index 9038500f..834b7303 100644 --- a/Tests/AsyncAlgorithmsTests/Performance/TestThroughput.swift +++ b/Tests/AsyncAlgorithmsTests/Performance/TestThroughput.swift @@ -14,6 +14,12 @@ import AsyncAlgorithms #if canImport(Darwin) final class TestThroughput: XCTestCase { + func test_channel() async { + await measureChannelThroughput(output: 1) + } + func test_throwingChannel() async { + await measureThrowingChannelThroughput(output: 1) + } func test_chain2() async { await measureSequenceThroughput(firstOutput: 1, secondOutput: 2) { chain($0, $1) diff --git a/Tests/AsyncAlgorithmsTests/Performance/ThroughputMeasurement.swift b/Tests/AsyncAlgorithmsTests/Performance/ThroughputMeasurement.swift index 16531d9b..c8991bdd 100644 --- a/Tests/AsyncAlgorithmsTests/Performance/ThroughputMeasurement.swift +++ b/Tests/AsyncAlgorithmsTests/Performance/ThroughputMeasurement.swift @@ -56,6 +56,64 @@ final class _ThroughputMetric: NSObject, XCTMetric, @unchecked Sendable { } extension XCTestCase { + public func measureChannelThroughput(output: @escaping @autoclosure () -> Output) async { + let metric = _ThroughputMetric() + let sampleTime: Double = 0.1 + + measure(metrics: [metric]) { + let channel = AsyncChannel() + + let exp = self.expectation(description: "Finished") + let iterTask = Task { + var eventCount = 0 + for try await _ in channel { + eventCount += 1 + } + metric.eventCount = eventCount + exp.fulfill() + return eventCount + } + let sendTask = Task { + while !Task.isCancelled { + await channel.send(output()) + } + } + usleep(UInt32(sampleTime * Double(USEC_PER_SEC))) + iterTask.cancel() + sendTask.cancel() + self.wait(for: [exp], timeout: sampleTime * 2) + } + } + + public func measureThrowingChannelThroughput(output: @escaping @autoclosure () -> Output) async { + let metric = _ThroughputMetric() + let sampleTime: Double = 0.1 + + measure(metrics: [metric]) { + let channel = AsyncThrowingChannel() + + let exp = self.expectation(description: "Finished") + let iterTask = Task { + var eventCount = 0 + for try await _ in channel { + eventCount += 1 + } + metric.eventCount = eventCount + exp.fulfill() + return eventCount + } + let sendTask = Task { + while !Task.isCancelled { + await channel.send(output()) + } + } + usleep(UInt32(sampleTime * Double(USEC_PER_SEC))) + iterTask.cancel() + sendTask.cancel() + self.wait(for: [exp], timeout: sampleTime * 2) + } + } + public func measureSequenceThroughput( output: @autoclosure () -> Output, _ sequenceBuilder: (InfiniteAsyncSequence) -> S) async where S: Sendable { let metric = _ThroughputMetric() let sampleTime: Double = 0.1 diff --git a/Tests/AsyncAlgorithmsTests/TestChannel.swift b/Tests/AsyncAlgorithmsTests/TestChannel.swift index 3d53fe3a..b181bddd 100644 --- a/Tests/AsyncAlgorithmsTests/TestChannel.swift +++ b/Tests/AsyncAlgorithmsTests/TestChannel.swift @@ -13,265 +13,164 @@ import XCTest import AsyncAlgorithms final class TestChannel: XCTestCase { - func test_asyncChannel_delivers_values_when_two_producers_and_two_consumers() async { - let (sentFromProducer1, sentFromProducer2) = ("test1", "test2") - let expected = Set([sentFromProducer1, sentFromProducer2]) + func test_asyncChannel_delivers_elements_when_several_producers_and_several_consumers() async { + let sents = (1...10) + let expected = Set(sents) - let channel = AsyncChannel() - Task { - await channel.send(sentFromProducer1) - } - Task { - await channel.send(sentFromProducer2) - } - - let t: Task = Task { - var iterator = channel.makeAsyncIterator() - let value = await iterator.next() - return value - } - var iterator = channel.makeAsyncIterator() - - let (collectedFromConsumer1, collectedFromConsumer2) = (await t.value, await iterator.next()) - let collected = Set([collectedFromConsumer1, collectedFromConsumer2]) + // Given: an AsyncChannel + let sut = AsyncChannel() + // When: sending elements from tasks in a group + Task { + await withTaskGroup(of: Void.self) { group in + for sent in sents { + group.addTask { + await sut.send(sent) + } + } + } + } + + // When: receiving those elements from tasks in a group + let collected = await withTaskGroup(of: Int.self, returning: Set.self) { group in + for _ in sents { + group.addTask { + var iterator = sut.makeAsyncIterator() + let received = await iterator.next() + return received! + } + } + + var collected = Set() + for await element in group { + collected.update(with: element) + } + return collected + } + + // Then: all elements are sent and received XCTAssertEqual(collected, expected) } - - func test_asyncThrowingChannel_delivers_values_when_two_producers_and_two_consumers() async throws { - let (sentFromProducer1, sentFromProducer2) = ("test1", "test2") - let expected = Set([sentFromProducer1, sentFromProducer2]) - let channel = AsyncThrowingChannel() - Task { - await channel.send("test1") - } - Task { - await channel.send("test2") - } - - let t: Task = Task { - var iterator = channel.makeAsyncIterator() - let value = try await iterator.next() - return value - } - var iterator = channel.makeAsyncIterator() + func test_asyncChannel_resumes_producers_and_discards_additional_elements_when_finish_is_called() async { + // Given: an AsyncChannel + let sut = AsyncChannel() - let (collectedFromConsumer1, collectedFromConsumer2) = (try await t.value, try await iterator.next()) - let collected = Set([collectedFromConsumer1, collectedFromConsumer2]) - - XCTAssertEqual(collected, expected) - } - - func test_asyncThrowingChannel_throws_and_discards_additional_sent_values_when_fail_is_called() async { - let sendImmediatelyResumes = expectation(description: "Send immediately resumes after fail") - - let channel = AsyncThrowingChannel() - channel.fail(Failure()) - - var iterator = channel.makeAsyncIterator() - do { - let _ = try await iterator.next() - XCTFail("The AsyncThrowingChannel should have thrown") - } catch { - XCTAssertEqual(error as? Failure, Failure()) + // Given: 2 suspended send operations + let task1 = Task { + await sut.send(1) } - do { - let pastFailure = try await iterator.next() - XCTAssertNil(pastFailure) - } catch { - XCTFail("The AsyncThrowingChannel should not fail when failure has already been fired") + let task2 = Task { + await sut.send(2) } - await channel.send("send") - sendImmediatelyResumes.fulfill() - wait(for: [sendImmediatelyResumes], timeout: 1.0) - } - - func test_asyncChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called() async { - let channel = AsyncChannel() - let complete = ManagedCriticalState(false) - let finished = expectation(description: "finished") + // When: finishing the channel + sut.finish() - Task { - channel.finish() - complete.withCriticalRegion { $0 = true } - finished.fulfill() - } + // Then: the send operations are resumed + _ = await (task1.value, task2.value) - let valueFromConsumer1 = ManagedCriticalState(nil) - let valueFromConsumer2 = ManagedCriticalState(nil) + // When: sending an extra value + await sut.send(3) - let received = expectation(description: "received") - received.expectedFulfillmentCount = 2 + // Then: the operation and the iteration are immediately resumed + var collected = [Int]() + for await element in sut { + collected.append(element) + } + XCTAssertTrue(collected.isEmpty) + } - let pastEnd = expectation(description: "pastEnd") - pastEnd.expectedFulfillmentCount = 2 + func test_asyncChannel_resumes_consumers_when_finish_is_called() async { + // Given: an AsyncChannel + let sut = AsyncChannel() - Task { - var iterator = channel.makeAsyncIterator() - let ending = await iterator.next() - valueFromConsumer1.withCriticalRegion { $0 = ending } - received.fulfill() - let item = await iterator.next() - XCTAssertNil(item) - pastEnd.fulfill() + // Given: 2 suspended iterations + let task1 = Task { + var iterator = sut.makeAsyncIterator() + return await iterator.next() } - Task { - var iterator = channel.makeAsyncIterator() - let ending = await iterator.next() - valueFromConsumer2.withCriticalRegion { $0 = ending } - received.fulfill() - let item = await iterator.next() - XCTAssertNil(item) - pastEnd.fulfill() + let task2 = Task { + var iterator = sut.makeAsyncIterator() + return await iterator.next() } - - wait(for: [finished, received], timeout: 1.0) - XCTAssertTrue(complete.withCriticalRegion { $0 }) - XCTAssertEqual(valueFromConsumer1.withCriticalRegion { $0 }, nil) - XCTAssertEqual(valueFromConsumer2.withCriticalRegion { $0 }, nil) + // When: finishing the channel + sut.finish() - wait(for: [pastEnd], timeout: 1.0) - let additionalSend = expectation(description: "additional send") - Task { - await channel.send("test") - additionalSend.fulfill() - } - wait(for: [additionalSend], timeout: 1.0) - } + // Then: the iterations are resumed with nil values + let (collected1, collected2) = await (task1.value, task2.value) + XCTAssertNil(collected1) + XCTAssertNil(collected2) - func test_asyncThrowingChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called() async { - let channel = AsyncThrowingChannel() - let complete = ManagedCriticalState(false) - let finished = expectation(description: "finished") - - Task { - channel.finish() - complete.withCriticalRegion { $0 = true } - finished.fulfill() - } + // When: requesting a next value + var iterator = sut.makeAsyncIterator() + let pastEnd = await iterator.next() - let valueFromConsumer1 = ManagedCriticalState(nil) - let valueFromConsumer2 = ManagedCriticalState(nil) + // Then: the past end is nil + XCTAssertNil(pastEnd) + } - let received = expectation(description: "received") - received.expectedFulfillmentCount = 2 + func test_asyncChannel_resumes_producer_when_task_is_cancelled() async { + let send1IsResumed = expectation(description: "The first send operation is resumed") - let pastEnd = expectation(description: "pastEnd") - pastEnd.expectedFulfillmentCount = 2 + // Given: an AsyncChannel + let sut = AsyncChannel() - Task { - var iterator = channel.makeAsyncIterator() - let ending = try await iterator.next() - valueFromConsumer1.withCriticalRegion { $0 = ending } - received.fulfill() - let item = try await iterator.next() - XCTAssertNil(item) - pastEnd.fulfill() + // Given: 2 suspended send operations + let task1 = Task { + await sut.send(1) + send1IsResumed.fulfill() } - Task { - var iterator = channel.makeAsyncIterator() - let ending = try await iterator.next() - valueFromConsumer2.withCriticalRegion { $0 = ending } - received.fulfill() - let item = try await iterator.next() - XCTAssertNil(item) - pastEnd.fulfill() + let task2 = Task { + await sut.send(2) } - wait(for: [finished, received], timeout: 1.0) + // When: cancelling the first task + task1.cancel() - XCTAssertTrue(complete.withCriticalRegion { $0 }) - XCTAssertEqual(valueFromConsumer1.withCriticalRegion { $0 }, nil) - XCTAssertEqual(valueFromConsumer2.withCriticalRegion { $0 }, nil) + // Then: the first sending operation is resumed + wait(for: [send1IsResumed], timeout: 1.0) - wait(for: [pastEnd], timeout: 1.0) - let additionalSend = expectation(description: "additional send") - Task { - await channel.send("test") - additionalSend.fulfill() - } - wait(for: [additionalSend], timeout: 1.0) - } - - func test_asyncChannel_ends_iterator_when_task_is_cancelled() async { - let channel = AsyncChannel() - let ready = expectation(description: "ready") - let task: Task = Task { - var iterator = channel.makeAsyncIterator() - ready.fulfill() - return await iterator.next() - } - wait(for: [ready], timeout: 1.0) - task.cancel() - let value = await task.value - XCTAssertNil(value) - } + // When: collecting elements + var iterator = sut.makeAsyncIterator() + let collected = await iterator.next() - func test_asyncThrowingChannel_ends_iterator_when_task_is_cancelled() async throws { - let channel = AsyncThrowingChannel() - let ready = expectation(description: "ready") - let task: Task = Task { - var iterator = channel.makeAsyncIterator() - ready.fulfill() - return try await iterator.next() - } - wait(for: [ready], timeout: 1.0) - task.cancel() - let value = try await task.value - XCTAssertNil(value) + // Then: the second operation resumes and the iteration receives the element + _ = await task2.value + XCTAssertEqual(collected, 2) } - - func test_asyncChannel_resumes_send_when_task_is_cancelled_and_continue_remaining_send_tasks() async { - let channel = AsyncChannel() - let notYetDone = expectation(description: "not yet done") - notYetDone.isInverted = true - let done = expectation(description: "done") - let task = Task { - await channel.send(1) - notYetDone.fulfill() - done.fulfill() - } - - Task { - await channel.send(2) - } - wait(for: [notYetDone], timeout: 0.1) - task.cancel() - wait(for: [done], timeout: 1.0) + func test_asyncChannel_resumes_consumer_when_task_is_cancelled() async { + // Given: an AsyncChannel + let sut = AsyncChannel() - var iterator = channel.makeAsyncIterator() - let received = await iterator.next() - XCTAssertEqual(received, 2) - } - - func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled_and_continue_remaining_send_tasks() async throws { - let channel = AsyncThrowingChannel() - let notYetDone = expectation(description: "not yet done") - notYetDone.isInverted = true - let done = expectation(description: "done") - let task = Task { - await channel.send(1) - notYetDone.fulfill() - done.fulfill() + // Given: 2 suspended iterations + let task1 = Task { + var iterator = sut.makeAsyncIterator() + return await iterator.next() } - Task { - await channel.send(2) + let task2 = Task { + var iterator = sut.makeAsyncIterator() + return await iterator.next() } - wait(for: [notYetDone], timeout: 0.1) - task.cancel() - wait(for: [done], timeout: 1.0) + // When: cancelling the first task + task1.cancel() + + // Then: the iteration is resumed with a nil element + let collected1 = await task1.value + XCTAssertNil(collected1) + + // When: sending an element + await sut.send(1) - var iterator = channel.makeAsyncIterator() - let received = try await iterator.next() - XCTAssertEqual(received, 2) + // Then: the second iteration is resumed with the element + let collected2 = await task2.value + XCTAssertEqual(collected2, 1) } } diff --git a/Tests/AsyncAlgorithmsTests/TestThrowingChannel.swift b/Tests/AsyncAlgorithmsTests/TestThrowingChannel.swift new file mode 100644 index 00000000..7dd60f18 --- /dev/null +++ b/Tests/AsyncAlgorithmsTests/TestThrowingChannel.swift @@ -0,0 +1,313 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +import XCTest +import AsyncAlgorithms + +final class TestThrowingChannel: XCTestCase { + func test_asyncThrowingChannel_delivers_elements_when_several_producers_and_several_consumers() async throws { + let sents = (1...10) + let expected = Set(sents) + + // Given: an AsyncThrowingChannel + let sut = AsyncThrowingChannel() + + // When: sending elements from tasks in a group + Task { + await withTaskGroup(of: Void.self) { group in + for sent in sents { + group.addTask { + await sut.send(sent) + } + } + } + } + + // When: receiving those elements from tasks in a group + let collected = try await withThrowingTaskGroup(of: Int.self, returning: Set.self) { group in + for _ in sents { + group.addTask { + var iterator = sut.makeAsyncIterator() + let received = try await iterator.next() + return received! + } + } + + var collected = Set() + for try await element in group { + collected.update(with: element) + } + return collected + } + + // Then: all elements are sent and received + XCTAssertEqual(collected, expected) + } + + func test_asyncThrowingChannel_resumes_producers_and_discards_additional_elements_when_finish_is_called() async throws { + // Given: an AsyncThrowingChannel + let sut = AsyncThrowingChannel() + + // Given: 2 suspended send operations + let task1 = Task { + await sut.send(1) + } + + let task2 = Task { + await sut.send(2) + } + + // When: finishing the channel + sut.finish() + + // Then: the send operations are resumed + _ = await (task1.value, task2.value) + + // When: sending an extra value + await sut.send(3) + + // Then: the operation and the iteration are immediately resumed + var collected = [Int]() + for try await element in sut { + collected.append(element) + } + XCTAssertTrue(collected.isEmpty) + } + + func test_asyncThrowingChannel_resumes_producers_and_discards_additional_elements_when_fail_is_called() async throws { + // Given: an AsyncThrowingChannel + let sut = AsyncThrowingChannel() + + // Given: 2 suspended send operations + let task1 = Task { + await sut.send(1) + } + + let task2 = Task { + await sut.send(2) + } + + // When: failing the channel + sut.fail(Failure()) + + // Then: the send operations are resumed + _ = await (task1.value, task2.value) + + // When: sending an extra value + await sut.send(3) + + // Then: the send operation is resumed + // Then: the iteration is resumed with a failure + var collected = [Int]() + do { + for try await element in sut { + collected.append(element) + } + } catch { + XCTAssertTrue(collected.isEmpty) + XCTAssertEqual(error as? Failure, Failure()) + } + + // When: requesting a next value + var iterator = sut.makeAsyncIterator() + let pastFailure = try await iterator.next() + + // Then: the past failure is nil + XCTAssertNil(pastFailure) + } + + func test_asyncThrowingChannel_resumes_consumers_when_finish_is_called() async throws { + // Given: an AsyncThrowingChannel + let sut = AsyncThrowingChannel() + + // Given: 2 suspended iterations + let task1 = Task { + var iterator = sut.makeAsyncIterator() + return try await iterator.next() + } + + let task2 = Task { + var iterator = sut.makeAsyncIterator() + return try await iterator.next() + } + + + // When: finishing the channel + sut.finish() + + // Then: the iterations are resumed with nil values + let (collected1, collected2) = try await (task1.value, task2.value) + XCTAssertNil(collected1) + XCTAssertNil(collected2) + + // When: requesting a next value + var iterator = sut.makeAsyncIterator() + let pastEnd = try await iterator.next() + + // Then: the past end is nil + XCTAssertNil(pastEnd) + } + + func test_asyncThrowingChannel_resumes_consumer_when_fail_is_called() async throws { + // Given: an AsyncThrowingChannel + let sut = AsyncThrowingChannel() + + // Given: suspended iteration + let task = Task { + var iterator = sut.makeAsyncIterator() + + do { + _ = try await iterator.next() + XCTFail("We expect the above call to throw") + } catch { + XCTAssertEqual(error as? Failure, Failure()) + } + + return try await iterator.next() + } + + // When: failing the channel + sut.fail(Failure()) + + // Then: the iterations are resumed with the error and the next element is nil + do { + let collected = try await task.value + XCTAssertNil(collected) + } catch { + XCTFail("The task should not fail, the past failure element should be nil, not a failure.") + } + } + + func test_asyncThrowingChannel_resumes_consumers_when_fail_is_called() async throws { + // Given: an AsyncThrowingChannel + let sut = AsyncThrowingChannel() + + // Given: 2 suspended iterations + let task1 = Task { + var iterator = sut.makeAsyncIterator() + return try await iterator.next() + } + + let task2 = Task { + var iterator = sut.makeAsyncIterator() + return try await iterator.next() + } + + // When: failing the channel + sut.fail(Failure()) + + // Then: the iterations are resumed with the error + do { + _ = try await task1.value + } catch { + XCTAssertEqual(error as? Failure, Failure()) + } + + do { + _ = try await task2.value + } catch { + XCTAssertEqual(error as? Failure, Failure()) + } + + // When: requesting a next value + var iterator = sut.makeAsyncIterator() + let pastFailure = try await iterator.next() + + // Then: the past failure is nil + XCTAssertNil(pastFailure) + } + + func test_asyncThrowingChannel_resumes_consumer_with_error_when_already_failed() async throws { + // Given: an AsyncThrowingChannel that is failed + let sut = AsyncThrowingChannel() + sut.fail(Failure()) + + var iterator = sut.makeAsyncIterator() + + // When: requesting the next element + do { + _ = try await iterator.next() + } catch { + // Then: the iteration is resumed with the error + XCTAssertEqual(error as? Failure, Failure()) + } + + // When: requesting the next element past failure + do { + let pastFailure = try await iterator.next() + // Then: the past failure is nil + XCTAssertNil(pastFailure) + } catch { + XCTFail("The past failure should not throw") + } + } + + func test_asyncThrowingChannel_resumes_producer_when_task_is_cancelled() async throws { + let send1IsResumed = expectation(description: "The first send operation is resumed") + + // Given: an AsyncThrowingChannel + let sut = AsyncThrowingChannel() + + // Given: 2 suspended send operations + let task1 = Task { + await sut.send(1) + send1IsResumed.fulfill() + } + + let task2 = Task { + await sut.send(2) + } + + // When: cancelling the first task + task1.cancel() + + // Then: the first sending operation is resumed + wait(for: [send1IsResumed], timeout: 1.0) + + // When: collecting elements + var iterator = sut.makeAsyncIterator() + let collected = try await iterator.next() + + // Then: the second operation resumes and the iteration receives the element + _ = await task2.value + XCTAssertEqual(collected, 2) + } + + func test_asyncThrowingChannel_resumes_consumer_when_task_is_cancelled() async throws { + // Given: an AsyncThrowingChannel + let sut = AsyncThrowingChannel() + + // Given: 2 suspended iterations + let task1 = Task { + var iterator = sut.makeAsyncIterator() + return try await iterator.next() + } + + let task2 = Task { + var iterator = sut.makeAsyncIterator() + return try await iterator.next() + } + + // When: cancelling the first task + task1.cancel() + + // Then: the first iteration is resumed with a nil element + let collected1 = try await task1.value + XCTAssertNil(collected1) + + // When: sending an element + await sut.send(1) + + // Then: the second iteration is resumed with the element + let collected2 = try await task2.value + XCTAssertEqual(collected2, 1) + } +}