Skip to content

Commit dd47f34

Browse files
committed
EventStreams: add ability to customise the terminating byte sequence of
a stream
1 parent f95974c commit dd47f34

File tree

1 file changed

+66
-8
lines changed

1 file changed

+66
-8
lines changed

Sources/OpenAPIRuntime/EventStreams/ServerSentEventsDecoding.swift

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,19 @@ where Upstream.Element == ArraySlice<UInt8> {
2828
/// The upstream sequence.
2929
private let upstream: Upstream
3030

31+
/// An optional closure that determines whether the given byte sequence is the terminating byte sequence defined by the API.
32+
/// - Parameter: A byte chunk.
33+
/// - Returns: `True` if the given byte sequence is the terminating byte sequence defined by the API.
34+
private let terminate: (@Sendable (ArraySlice<UInt8>) -> Bool)?
35+
3136
/// Creates a new sequence.
32-
/// - Parameter upstream: The upstream sequence of arbitrary byte chunks.
33-
public init(upstream: Upstream) { self.upstream = upstream }
37+
/// - Parameters:
38+
/// - upstream: The upstream sequence of arbitrary byte chunks.
39+
/// - terminate: An optional closure that determines whether the given byte sequence is the terminating byte sequence defined by the API.
40+
public init(upstream: Upstream, terminate: (@Sendable (ArraySlice<UInt8>) -> Bool)?) {
41+
self.upstream = upstream
42+
self.terminate = terminate
43+
}
3444
}
3545

3646
extension ServerSentEventsDeserializationSequence: AsyncSequence {
@@ -48,6 +58,17 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence {
4858
/// The state machine of the iterator.
4959
var stateMachine: StateMachine = .init()
5060

61+
/// An optional closure that determines whether the given byte sequence is the terminating byte sequence defined by the API.
62+
/// - Parameter: A byte chunk.
63+
/// - Returns: `True` if the given byte sequence is the terminating byte sequence defined by the API.
64+
let terminate: ((ArraySlice<UInt8>) -> Bool)?
65+
66+
init(upstream: any AsyncIteratorProtocol, terminate: ((ArraySlice<UInt8>) -> Bool)?) {
67+
self.upstream = upstream as! UpstreamIterator
68+
self.stateMachine = .init(terminate: terminate)
69+
self.terminate = terminate
70+
}
71+
5172
/// Asynchronously advances to the next element and returns it, or ends the
5273
/// sequence if there is no next element.
5374
public mutating func next() async throws -> ServerSentEvent? {
@@ -70,7 +91,7 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence {
7091
/// Creates the asynchronous iterator that produces elements of this
7192
/// asynchronous sequence.
7293
public func makeAsyncIterator() -> Iterator<Upstream.AsyncIterator> {
73-
Iterator(upstream: upstream.makeAsyncIterator())
94+
Iterator(upstream: upstream.makeAsyncIterator(), terminate: terminate)
7495
}
7596
}
7697

@@ -79,26 +100,34 @@ extension AsyncSequence where Element == ArraySlice<UInt8>, Self: Sendable {
79100
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
80101
///
81102
/// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`.
103+
/// - Parameter: An optional closure that determines whether the given byte sequence is the terminating byte sequence defined by the API.
82104
/// - Returns: A sequence that provides the events.
83-
public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence<
105+
public func asDecodedServerSentEvents(terminate: (@Sendable (ArraySlice<UInt8>) -> Bool)? = nil) -> ServerSentEventsDeserializationSequence<
84106
ServerSentEventsLineDeserializationSequence<Self>
85-
> { .init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self)) }
107+
> { .init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self), terminate: terminate) }
108+
109+
/// Convenience function for `asDecodedServerSentEvents` that directly receives the terminating byte sequence.
110+
public func asDecodedServerSentEvents(terminatingSequence: ArraySlice<UInt8>) -> ServerSentEventsDeserializationSequence<
111+
ServerSentEventsLineDeserializationSequence<Self>
112+
> { asDecodedServerSentEvents(terminate: { incomingSequence in return incomingSequence == terminatingSequence }) }
86113

87114
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
88115
///
89116
/// Use this method if the event's `data` field is JSON.
90117
/// - Parameters:
91118
/// - dataType: The type to decode the JSON data into.
92119
/// - decoder: The JSON decoder to use.
120+
/// - terminate: An optional closure that determines whether the given byte sequence is the terminating byte sequence defined by the API.
93121
/// - Returns: A sequence that provides the events with the decoded JSON data.
94122
public func asDecodedServerSentEventsWithJSONData<JSONDataType: Decodable>(
95123
of dataType: JSONDataType.Type = JSONDataType.self,
96-
decoder: JSONDecoder = .init()
124+
decoder: JSONDecoder = .init(),
125+
terminate: (@Sendable (ArraySlice<UInt8>) -> Bool)? = nil
97126
) -> AsyncThrowingMapSequence<
98127
ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>>,
99128
ServerSentEventWithJSONData<JSONDataType>
100129
> {
101-
asDecodedServerSentEvents()
130+
asDecodedServerSentEvents(terminate: terminate)
102131
.map { event in
103132
ServerSentEventWithJSONData(
104133
event: event.event,
@@ -110,6 +139,19 @@ extension AsyncSequence where Element == ArraySlice<UInt8>, Self: Sendable {
110139
)
111140
}
112141
}
142+
143+
public func asDecodedServerSentEventsWithJSONData<JSONDataType: Decodable>(
144+
of dataType: JSONDataType.Type = JSONDataType.self,
145+
decoder: JSONDecoder = .init(),
146+
terminatingData: ArraySlice<UInt8>
147+
) -> AsyncThrowingMapSequence<
148+
ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>>,
149+
ServerSentEventWithJSONData<JSONDataType>
150+
> {
151+
asDecodedServerSentEventsWithJSONData(of: dataType, decoder: decoder) { incomingData in
152+
terminatingData == incomingData
153+
}
154+
}
113155
}
114156

115157
extension ServerSentEventsDeserializationSequence.Iterator {
@@ -133,8 +175,16 @@ extension ServerSentEventsDeserializationSequence.Iterator {
133175
/// The current state of the state machine.
134176
private(set) var state: State
135177

178+
179+
/// An optional closure that determines whether the given byte sequence is the terminating byte sequence defined by the API.
180+
/// - Parameter: A sequence of byte chunks.
181+
/// - Returns: `True` if the given byte sequence is the terminating byte sequence defined by the API.
182+
let terminate: ((ArraySlice<UInt8>) -> Bool)?
183+
136184
/// Creates a new state machine.
137-
init() { self.state = .accumulatingEvent(.init(), buffer: []) }
185+
init(terminate: ((ArraySlice<UInt8>) -> Bool)? = nil) {
186+
self.state = .accumulatingEvent(.init(), buffer: [])
187+
self.terminate = terminate}
138188

139189
/// An action returned by the `next` method.
140190
enum NextAction {
@@ -165,6 +215,14 @@ extension ServerSentEventsDeserializationSequence.Iterator {
165215
state = .accumulatingEvent(.init(), buffer: buffer)
166216
// If the last character of data is a newline, strip it.
167217
if event.data?.hasSuffix("\n") ?? false { event.data?.removeLast() }
218+
219+
if let terminate = terminate {
220+
if let data = event.data {
221+
if terminate(ArraySlice(Data(data.utf8))) {
222+
return .returnNil
223+
}
224+
}
225+
}
168226
return .emitEvent(event)
169227
}
170228
if line.first! == ASCII.colon {

0 commit comments

Comments
 (0)