diff --git a/Sources/OpenAPIRuntime/Deprecated/Deprecated.swift b/Sources/OpenAPIRuntime/Deprecated/Deprecated.swift index c9f538ef..2ce41750 100644 --- a/Sources/OpenAPIRuntime/Deprecated/Deprecated.swift +++ b/Sources/OpenAPIRuntime/Deprecated/Deprecated.swift @@ -59,3 +59,37 @@ extension Configuration { ) } } + +extension AsyncSequence where Element == ArraySlice, Self: Sendable { + /// Returns another sequence that decodes each event's data as the provided type using the provided decoder. + /// + /// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`. + /// - Returns: A sequence that provides the events. + @available(*, deprecated, renamed: "asDecodedServerSentEvents(while:)") @_disfavoredOverload + public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence< + ServerSentEventsLineDeserializationSequence + > { asDecodedServerSentEvents(while: { _ in true }) } + /// Returns another sequence that decodes each event's data as the provided type using the provided decoder. + /// + /// Use this method if the event's `data` field is JSON. + /// - Parameters: + /// - dataType: The type to decode the JSON data into. + /// - decoder: The JSON decoder to use. + /// - Returns: A sequence that provides the events with the decoded JSON data. + @available(*, deprecated, renamed: "asDecodedServerSentEventsWithJSONData(of:decoder:while:)") @_disfavoredOverload + public func asDecodedServerSentEventsWithJSONData( + of dataType: JSONDataType.Type = JSONDataType.self, + decoder: JSONDecoder = .init() + ) -> AsyncThrowingMapSequence< + ServerSentEventsDeserializationSequence>, + ServerSentEventWithJSONData + > { asDecodedServerSentEventsWithJSONData(of: dataType, decoder: decoder, while: { _ in true }) } +} + +extension ServerSentEventsDeserializationSequence { + /// Creates a new sequence. + /// - Parameter upstream: The upstream sequence of arbitrary byte chunks. + @available(*, deprecated, renamed: "init(upstream:while:)") @_disfavoredOverload public init(upstream: Upstream) { + self.init(upstream: upstream, while: { _ in true }) + } +} diff --git a/Sources/OpenAPIRuntime/EventStreams/ServerSentEventsDecoding.swift b/Sources/OpenAPIRuntime/EventStreams/ServerSentEventsDecoding.swift index 421e5319..ff374b39 100644 --- a/Sources/OpenAPIRuntime/EventStreams/ServerSentEventsDecoding.swift +++ b/Sources/OpenAPIRuntime/EventStreams/ServerSentEventsDecoding.swift @@ -28,9 +28,19 @@ where Upstream.Element == ArraySlice { /// The upstream sequence. private let upstream: Upstream + /// A closure that determines whether the given byte chunk should be forwarded to the consumer. + /// - Parameter: A byte chunk. + /// - Returns: `true` if the byte chunk should be forwarded, `false` if this byte chunk is the terminating sequence. + private let predicate: @Sendable (ArraySlice) -> Bool + /// Creates a new sequence. - /// - Parameter upstream: The upstream sequence of arbitrary byte chunks. - public init(upstream: Upstream) { self.upstream = upstream } + /// - Parameters: + /// - upstream: The upstream sequence of arbitrary byte chunks. + /// - predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer. + public init(upstream: Upstream, while predicate: @escaping @Sendable (ArraySlice) -> Bool) { + self.upstream = upstream + self.predicate = predicate + } } extension ServerSentEventsDeserializationSequence: AsyncSequence { @@ -46,7 +56,16 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence { var upstream: UpstreamIterator /// The state machine of the iterator. - var stateMachine: StateMachine = .init() + var stateMachine: StateMachine + + /// Creates a new sequence. + /// - Parameters: + /// - upstream: The upstream sequence of arbitrary byte chunks. + /// - predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer. + init(upstream: UpstreamIterator, while predicate: @escaping ((ArraySlice) -> Bool)) { + self.upstream = upstream + self.stateMachine = .init(while: predicate) + } /// Asynchronously advances to the next element and returns it, or ends the /// sequence if there is no next element. @@ -70,7 +89,7 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence { /// Creates the asynchronous iterator that produces elements of this /// asynchronous sequence. public func makeAsyncIterator() -> Iterator { - Iterator(upstream: upstream.makeAsyncIterator()) + Iterator(upstream: upstream.makeAsyncIterator(), while: predicate) } } @@ -79,26 +98,30 @@ extension AsyncSequence where Element == ArraySlice, Self: Sendable { /// Returns another sequence that decodes each event's data as the provided type using the provided decoder. /// /// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`. + /// - Parameter: A closure that determines whether the given byte chunk should be forwarded to the consumer. /// - Returns: A sequence that provides the events. - public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence< - ServerSentEventsLineDeserializationSequence - > { .init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self)) } - + public func asDecodedServerSentEvents( + while predicate: @escaping @Sendable (ArraySlice) -> Bool = { _ in true } + ) -> ServerSentEventsDeserializationSequence> { + .init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self), while: predicate) + } /// Returns another sequence that decodes each event's data as the provided type using the provided decoder. /// /// Use this method if the event's `data` field is JSON. /// - Parameters: /// - dataType: The type to decode the JSON data into. /// - decoder: The JSON decoder to use. + /// - predicate: A closure that determines whether the given byte sequence is the terminating byte sequence defined by the API. /// - Returns: A sequence that provides the events with the decoded JSON data. public func asDecodedServerSentEventsWithJSONData( of dataType: JSONDataType.Type = JSONDataType.self, - decoder: JSONDecoder = .init() + decoder: JSONDecoder = .init(), + while predicate: @escaping @Sendable (ArraySlice) -> Bool = { _ in true } ) -> AsyncThrowingMapSequence< ServerSentEventsDeserializationSequence>, ServerSentEventWithJSONData > { - asDecodedServerSentEvents() + asDecodedServerSentEvents(while: predicate) .map { event in ServerSentEventWithJSONData( event: event.event, @@ -118,10 +141,10 @@ extension ServerSentEventsDeserializationSequence.Iterator { struct StateMachine { /// The possible states of the state machine. - enum State: Hashable { + enum State { /// Accumulating an event, which hasn't been emitted yet. - case accumulatingEvent(ServerSentEvent, buffer: [ArraySlice]) + case accumulatingEvent(ServerSentEvent, buffer: [ArraySlice], predicate: (ArraySlice) -> Bool) /// Finished, the terminal state. case finished @@ -134,7 +157,9 @@ extension ServerSentEventsDeserializationSequence.Iterator { private(set) var state: State /// Creates a new state machine. - init() { self.state = .accumulatingEvent(.init(), buffer: []) } + init(while predicate: @escaping (ArraySlice) -> Bool) { + self.state = .accumulatingEvent(.init(), buffer: [], predicate: predicate) + } /// An action returned by the `next` method. enum NextAction { @@ -156,20 +181,24 @@ extension ServerSentEventsDeserializationSequence.Iterator { /// - Returns: An action to perform. mutating func next() -> NextAction { switch state { - case .accumulatingEvent(var event, var buffer): + case .accumulatingEvent(var event, var buffer, let predicate): guard let line = buffer.first else { return .needsMore } state = .mutating buffer.removeFirst() if line.isEmpty { // Dispatch the accumulated event. - state = .accumulatingEvent(.init(), buffer: buffer) // If the last character of data is a newline, strip it. if event.data?.hasSuffix("\n") ?? false { event.data?.removeLast() } + if let data = event.data, !predicate(ArraySlice(data.utf8)) { + state = .finished + return .returnNil + } + state = .accumulatingEvent(.init(), buffer: buffer, predicate: predicate) return .emitEvent(event) } if line.first! == ASCII.colon { // A comment, skip this line. - state = .accumulatingEvent(event, buffer: buffer) + state = .accumulatingEvent(event, buffer: buffer, predicate: predicate) return .noop } // Parse the field name and value. @@ -193,7 +222,7 @@ extension ServerSentEventsDeserializationSequence.Iterator { } guard let value else { // An unknown type of event, skip. - state = .accumulatingEvent(event, buffer: buffer) + state = .accumulatingEvent(event, buffer: buffer, predicate: predicate) return .noop } // Process the field. @@ -214,11 +243,11 @@ extension ServerSentEventsDeserializationSequence.Iterator { } default: // An unknown or invalid field, skip. - state = .accumulatingEvent(event, buffer: buffer) + state = .accumulatingEvent(event, buffer: buffer, predicate: predicate) return .noop } // Processed the field, continue. - state = .accumulatingEvent(event, buffer: buffer) + state = .accumulatingEvent(event, buffer: buffer, predicate: predicate) return .noop case .finished: return .returnNil case .mutating: preconditionFailure("Invalid state") @@ -240,11 +269,11 @@ extension ServerSentEventsDeserializationSequence.Iterator { /// - Returns: An action to perform. mutating func receivedValue(_ value: ArraySlice?) -> ReceivedValueAction { switch state { - case .accumulatingEvent(let event, var buffer): + case .accumulatingEvent(let event, var buffer, let predicate): if let value { state = .mutating buffer.append(value) - state = .accumulatingEvent(event, buffer: buffer) + state = .accumulatingEvent(event, buffer: buffer, predicate: predicate) return .noop } else { // If no value is received, drop the existing event on the floor. diff --git a/Tests/OpenAPIRuntimeTests/EventStreams/Test_ServerSentEventsDecoding.swift b/Tests/OpenAPIRuntimeTests/EventStreams/Test_ServerSentEventsDecoding.swift index 79d645a5..2a15b932 100644 --- a/Tests/OpenAPIRuntimeTests/EventStreams/Test_ServerSentEventsDecoding.swift +++ b/Tests/OpenAPIRuntimeTests/EventStreams/Test_ServerSentEventsDecoding.swift @@ -16,10 +16,14 @@ import XCTest import Foundation final class Test_ServerSentEventsDecoding: Test_Runtime { - func _test(input: String, output: [ServerSentEvent], file: StaticString = #filePath, line: UInt = #line) - async throws - { - let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)).asDecodedServerSentEvents() + func _test( + input: String, + output: [ServerSentEvent], + file: StaticString = #filePath, + line: UInt = #line, + while predicate: @escaping @Sendable (ArraySlice) -> Bool = { _ in true } + ) async throws { + let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)).asDecodedServerSentEvents(while: predicate) let events = try await [ServerSentEvent](collecting: sequence) XCTAssertEqual(events.count, output.count, file: file, line: line) for (index, linePair) in zip(events, output).enumerated() { @@ -27,6 +31,7 @@ final class Test_ServerSentEventsDecoding: Test_Runtime { XCTAssertEqual(actualEvent, expectedEvent, "Event: \(index)", file: file, line: line) } } + func test() async throws { // Simple event. try await _test( @@ -83,15 +88,32 @@ final class Test_ServerSentEventsDecoding: Test_Runtime { .init(id: "123", data: "This is a message with an ID."), ] ) + + try await _test( + input: #""" + data: hello + data: world + + data: [DONE] + + data: hello2 + data: world2 + + + """#, + output: [.init(data: "hello\nworld")], + while: { incomingData in incomingData != ArraySlice(Data("[DONE]".utf8)) } + ) } func _testJSONData( input: String, output: [ServerSentEventWithJSONData], file: StaticString = #filePath, - line: UInt = #line + line: UInt = #line, + while predicate: @escaping @Sendable (ArraySlice) -> Bool = { _ in true } ) async throws { let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)) - .asDecodedServerSentEventsWithJSONData(of: JSONType.self) + .asDecodedServerSentEventsWithJSONData(of: JSONType.self, while: predicate) let events = try await [ServerSentEventWithJSONData](collecting: sequence) XCTAssertEqual(events.count, output.count, file: file, line: line) for (index, linePair) in zip(events, output).enumerated() { @@ -99,6 +121,7 @@ final class Test_ServerSentEventsDecoding: Test_Runtime { XCTAssertEqual(actualEvent, expectedEvent, "Event: \(index)", file: file, line: line) } } + struct TestEvent: Decodable, Hashable, Sendable { var index: Int } func testJSONData() async throws { // Simple event. @@ -121,6 +144,33 @@ final class Test_ServerSentEventsDecoding: Test_Runtime { .init(event: "event2", data: TestEvent(index: 2), id: "2"), ] ) + + try await _testJSONData( + input: #""" + event: event1 + id: 1 + data: {"index":1} + + event: event2 + id: 2 + data: { + data: "index": 2 + data: } + + data: [DONE] + + event: event3 + id: 1 + data: {"index":3} + + + """#, + output: [ + .init(event: "event1", data: TestEvent(index: 1), id: "1"), + .init(event: "event2", data: TestEvent(index: 2), id: "2"), + ], + while: { incomingData in incomingData != ArraySlice(Data("[DONE]".utf8)) } + ) } }