Skip to content

Commit 4809c6d

Browse files
committed
refactor
motivation: safer concurrency, better tests changes: * make http client thread safe * add concurrency tests * refactor test * make mock server more ribust
1 parent 95cf94e commit 4809c6d

12 files changed

+240
-137
lines changed

Sources/MockServer/main.swift

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,41 +64,42 @@ internal final class HTTPHandler: ChannelInboundHandler {
6464
private let mode: Mode
6565
private let keepAlive: Bool
6666

67-
private var requestHead: HTTPRequestHead!
68-
private var requestBody: ByteBuffer?
67+
private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>()
6968

7069
public init(logger: Logger, keepAlive: Bool, mode: Mode) {
7170
self.logger = logger
72-
self.mode = mode
7371
self.keepAlive = keepAlive
72+
self.mode = mode
7473
}
7574

7675
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
7776
let requestPart = unwrapInboundIn(data)
7877

7978
switch requestPart {
8079
case .head(let head):
81-
self.requestHead = head
82-
self.requestBody?.clear()
80+
self.pending.append((head: head, body: nil))
8381
case .body(var buffer):
84-
if self.requestBody == nil {
85-
self.requestBody = buffer
82+
var request = self.pending.removeFirst()
83+
if request.body == nil {
84+
request.body = buffer
8685
} else {
87-
self.requestBody!.writeBuffer(&buffer)
86+
request.body!.writeBuffer(&buffer)
8887
}
88+
self.pending.prepend(request)
8989
case .end:
90-
self.processRequest(context: context)
90+
let request = self.pending.removeFirst()
91+
self.processRequest(context: context, request: request)
9192
}
9293
}
9394

94-
func processRequest(context: ChannelHandlerContext) {
95-
self.logger.debug("\(self) processing \(self.requestHead.uri)")
95+
func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) {
96+
self.logger.debug("\(self) processing \(request.head.uri)")
9697

9798
var responseStatus: HTTPResponseStatus
9899
var responseBody: String?
99100
var responseHeaders: [(String, String)]?
100101

101-
if self.requestHead.uri.hasSuffix("/next") {
102+
if request.head.uri.hasSuffix("/next") {
102103
let requestId = UUID().uuidString
103104
responseStatus = .ok
104105
switch self.mode {
@@ -108,7 +109,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
108109
responseBody = "{ \"body\": \"\(requestId)\" }"
109110
}
110111
responseHeaders = [(AmazonHeaders.requestID, requestId)]
111-
} else if self.requestHead.uri.hasSuffix("/response") {
112+
} else if request.head.uri.hasSuffix("/response") {
112113
responseStatus = .accepted
113114
} else {
114115
responseStatus = .notFound

Sources/SwiftAwsLambda/HttpClient.swift

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ import NIOConcurrencyHelpers
1717
import NIOHTTP1
1818

1919
/// A barebone HTTP client to interact with AWS Runtime Engine which is an HTTP server.
20-
internal class HTTPClient {
20+
internal final class HTTPClient {
2121
private let eventLoop: EventLoop
2222
private let configuration: Lambda.Configuration.RuntimeEngine
2323
private let targetHost: String
2424

2525
private var state = State.disconnected
26-
private let lock = Lock()
26+
private let stateLock = Lock()
2727

2828
init(eventLoop: EventLoop, configuration: Lambda.Configuration.RuntimeEngine) {
2929
self.eventLoop = eventLoop
@@ -46,54 +46,60 @@ internal class HTTPClient {
4646
timeout: timeout ?? self.configuration.requestTimeout))
4747
}
4848

49-
49+
// TODO: cap reconnect attempt
5050
private func execute(_ request: Request) -> EventLoopFuture<Response> {
51-
self.lock.lock()
51+
self.stateLock.lock()
5252
switch self.state {
53+
case .disconnected:
54+
let promise = self.eventLoop.makePromise(of: Response.self)
55+
self.state = .connecting(promise.futureResult)
56+
self.stateLock.unlock()
57+
self.connect().flatMap { channel -> EventLoopFuture<Response> in
58+
self.stateLock.withLock {
59+
guard case .connecting = self.state else {
60+
preconditionFailure("invalid state \(self.state)")
61+
}
62+
self.state = .connected(channel)
63+
}
64+
return self.execute(request)
65+
}.cascade(to: promise)
66+
return promise.futureResult
67+
case .connecting(let future):
68+
let future = future.flatMap { _ in
69+
self.execute(request)
70+
}
71+
self.state = .connecting(future)
72+
self.stateLock.unlock()
73+
return future
5374
case .connected(let channel):
5475
guard channel.isActive else {
55-
// attempt to reconnect
5676
self.state = .disconnected
57-
self.lock.unlock()
77+
self.stateLock.unlock()
5878
return self.execute(request)
5979
}
60-
self.lock.unlock()
80+
self.stateLock.unlock()
6181
let promise = channel.eventLoop.makePromise(of: Response.self)
6282
let wrapper = HTTPRequestWrapper(request: request, promise: promise)
63-
return channel.writeAndFlush(wrapper).flatMap {
64-
promise.futureResult
65-
}
66-
case .disconnected:
67-
return self.connect().flatMap {
68-
self.lock.unlock()
69-
return self.execute(request)
70-
}
71-
default:
72-
preconditionFailure("invalid state \(self.state)")
83+
channel.writeAndFlush(wrapper).cascadeFailure(to: promise)
84+
return promise.futureResult
7385
}
7486
}
7587

76-
private func connect() -> EventLoopFuture<Void> {
77-
guard case .disconnected = self.state else {
78-
preconditionFailure("invalid state \(self.state)")
79-
}
80-
self.state = .connecting
81-
let bootstrap = ClientBootstrap(group: eventLoop)
88+
private func connect() -> EventLoopFuture<Channel> {
89+
let bootstrap = ClientBootstrap(group: self.eventLoop)
8290
.channelInitializer { channel in
8391
channel.pipeline.addHTTPClientHandlers().flatMap {
8492
channel.pipeline.addHandlers([HTTPHandler(keepAlive: self.configuration.keepAlive),
8593
UnaryHandler(keepAlive: self.configuration.keepAlive)])
8694
}
8795
}
88-
96+
8997
do {
9098
// connect directly via socket address to avoid happy eyeballs (perf)
9199
let address = try SocketAddress(ipAddress: self.configuration.ip, port: self.configuration.port)
92-
return bootstrap.connect(to: address).flatMapThrowing { channel in
93-
self.state = .connected(channel)
94-
}
100+
return bootstrap.connect(to: address)
95101
} catch {
96-
return eventLoop.makeFailedFuture(error)
102+
return self.eventLoop.makeFailedFuture(error)
97103
}
98104
}
99105

@@ -127,13 +133,13 @@ internal class HTTPClient {
127133
}
128134

129135
private enum State {
130-
case connecting
131-
case connected(Channel)
132136
case disconnected
137+
case connecting(EventLoopFuture<Response>)
138+
case connected(Channel)
133139
}
134140
}
135141

136-
private class HTTPHandler: ChannelDuplexHandler {
142+
private final class HTTPHandler: ChannelDuplexHandler {
137143
typealias OutboundIn = HTTPClient.Request
138144
typealias InboundOut = HTTPClient.Response
139145
typealias InboundIn = HTTPClientResponsePart
@@ -208,15 +214,15 @@ private class HTTPHandler: ChannelDuplexHandler {
208214
}
209215
}
210216

211-
private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
217+
private final class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
212218
typealias OutboundIn = HTTPRequestWrapper
213219
typealias InboundIn = HTTPClient.Response
214220
typealias OutboundOut = HTTPClient.Request
215221

216222
private let keepAlive: Bool
217223

218224
private let lock = Lock()
219-
private var pendingResponses = CircularBuffer<(EventLoopPromise<HTTPClient.Response>, Scheduled<Void>?)>()
225+
private var pendingResponses = CircularBuffer<(promise: EventLoopPromise<HTTPClient.Response>, timeout: Scheduled<Void>?)>()
220226
private var lastError: Error?
221227

222228
init(keepAlive: Bool) {
@@ -232,19 +238,21 @@ private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
232238
}
233239
}
234240
}
235-
self.lock.withLockVoid { pendingResponses.append((wrapper.promise, timeoutTask)) }
241+
self.lock.withLockVoid { pendingResponses.append((promise: wrapper.promise, timeout: timeoutTask)) }
236242
context.writeAndFlush(wrapOutboundOut(wrapper.request), promise: promise)
237243
}
238244

239245
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
240246
let response = unwrapInboundIn(data)
241247
if let pending = (self.lock.withLock { self.pendingResponses.popFirst() }) {
242248
let serverKeepAlive = response.headers["connection"].first?.lowercased() == "keep-alive"
243-
let future = self.keepAlive && serverKeepAlive ? context.eventLoop.makeSucceededFuture(()) : context.channel.close()
244-
future.whenComplete { _ in
245-
pending.1?.cancel()
246-
pending.0.succeed(response)
249+
if !self.keepAlive || !serverKeepAlive {
250+
pending.promise.futureResult.whenComplete { _ in
251+
_ = context.channel.close()
252+
}
247253
}
254+
pending.timeout?.cancel()
255+
pending.promise.succeed(response)
248256
}
249257
}
250258

Sources/SwiftAwsLambda/Lambda.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ import Darwin.C
1919
#endif
2020

2121
import Backtrace
22+
import Dispatch
2223
import Logging
2324
import NIO
2425
import NIOConcurrencyHelpers
25-
import Dispatch
2626

2727
public enum Lambda {
2828
/// Run a Lambda defined by implementing the `LambdaClosure` closure.
@@ -105,7 +105,7 @@ public enum Lambda {
105105
}
106106
set {
107107
self.stateLock.withLockVoid {
108-
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(_state)")
108+
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(self._state)")
109109
self._state = newValue
110110
}
111111
}
@@ -124,12 +124,12 @@ public enum Lambda {
124124
}
125125

126126
func stop() {
127-
self.logger.info("lambda lifecycle stopping")
127+
self.logger.debug("lambda lifecycle stopping")
128128
self.state = .stopping
129129
}
130130

131131
func shutdown() {
132-
self.logger.info("lambda lifecycle shutdown")
132+
self.logger.debug("lambda lifecycle shutdown")
133133
self.state = .shutdown
134134
}
135135

@@ -202,8 +202,8 @@ public enum Lambda {
202202
let stopSignal: Signal
203203

204204
init(id: String? = nil, maxTimes: Int? = nil, stopSignal: Signal? = nil) {
205-
self.id = id ?? "\(DispatchTime.now().uptimeNanoseconds)"
206-
self.maxTimes = maxTimes ?? env("MAX_REQUESTS").flatMap(Int.init) ?? 0
205+
self.id = id ?? "lambda" // "\(DispatchTime.now().uptimeNanoseconds)"
206+
self.maxTimes = maxTimes ?? env("MAX_REQUESTS").flatMap(Int.init) ?? 0
207207
self.stopSignal = stopSignal ?? env("STOP_SIGNAL").flatMap(Int32.init).flatMap(Signal.init) ?? Signal.TERM
208208
precondition(self.maxTimes >= 0, "maxTimes must be equal or larger than 0")
209209
}

Sources/SwiftAwsLambda/LambdaRunner.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ internal struct LambdaRunner {
3636
///
3737
/// - Returns: An `EventLoopFuture<Void>` fulfilled with the outcome of the initialization.
3838
func initialize(logger: Logger) -> EventLoopFuture<Void> {
39-
logger.info("initializing lambda")
39+
logger.debug("initializing lambda")
4040
// We need to use `flatMap` instead of `whenFailure` to ensure we complete reporting the result before stopping.
4141
return self.lambdaHandler.initialize(eventLoop: self.eventLoop,
4242
lifecycleId: self.lifecycleId,
@@ -69,7 +69,7 @@ internal struct LambdaRunner {
6969
}
7070
}.always { result in
7171
// we are done!
72-
logger.log(level: result.successful ? .info : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
72+
logger.log(level: result.successful ? .debug : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
7373
}
7474
}
7575
}

Sources/SwiftAwsLambda/LambdaRuntimeClient.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ internal struct JsonCodecError: Error, Equatable {
145145
}
146146

147147
static func == (lhs: JsonCodecError, rhs: JsonCodecError) -> Bool {
148-
return lhs.cause.localizedDescription == rhs.cause.localizedDescription
148+
return String(describing: lhs.cause) == String(describing: rhs.cause)
149149
}
150150
}
151151

Tests/SwiftAwsLambdaTests/Lambda+CodeableTest.swift

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,47 @@
1616
import XCTest
1717

1818
class CodableLambdaTest: XCTestCase {
19-
func testSuceess() throws {
19+
func testSuceess() {
20+
let server = MockLambdaServer(behavior: GoodBehavior())
21+
XCTAssertNoThrow(try server.start().wait())
22+
defer { XCTAssertNoThrow(try server.stop().wait()) }
23+
2024
let maxTimes = Int.random(in: 1 ... 10)
2125
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
22-
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
2326
let result = Lambda.run(handler: CodableEchoHandler(), configuration: configuration)
24-
try server.stop().wait()
2527
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
2628
}
2729

28-
func testFailure() throws {
29-
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
30+
func testFailure() {
31+
let server = MockLambdaServer(behavior: BadBehavior())
32+
XCTAssertNoThrow(try server.start().wait())
33+
defer { XCTAssertNoThrow(try server.stop().wait()) }
34+
3035
let result = Lambda.run(handler: CodableEchoHandler())
31-
try server.stop().wait()
3236
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
3337
}
3438

35-
func testClosureSuccess() throws {
39+
func testClosureSuccess() {
40+
let server = MockLambdaServer(behavior: GoodBehavior())
41+
XCTAssertNoThrow(try server.start().wait())
42+
defer { XCTAssertNoThrow(try server.stop().wait()) }
43+
3644
let maxTimes = Int.random(in: 1 ... 10)
3745
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
38-
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
3946
let result = Lambda.run(configuration: configuration) { (_, payload: Request, callback) in
4047
callback(.success(Response(requestId: payload.requestId)))
4148
}
42-
try server.stop().wait()
4349
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
4450
}
4551

46-
func testClosureFailure() throws {
47-
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
52+
func testClosureFailure() {
53+
let server = MockLambdaServer(behavior: BadBehavior())
54+
XCTAssertNoThrow(try server.start().wait())
55+
defer { XCTAssertNoThrow(try server.stop().wait()) }
56+
4857
let result: LambdaLifecycleResult = Lambda.run { (_, payload: Request, callback) in
4958
callback(.success(Response(requestId: payload.requestId)))
5059
}
51-
try server.stop().wait()
5260
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
5361
}
5462
}

0 commit comments

Comments
 (0)