Skip to content

Commit 30b973c

Browse files
committed
refactor
motivation: safer concurrency, better tests changes: * make http client thread safe * add concurrency tests * refactor test * make mock server more ribust
1 parent b8a6466 commit 30b973c

12 files changed

+234
-131
lines changed

Sources/MockServer/main.swift

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,41 +64,42 @@ internal final class HTTPHandler: ChannelInboundHandler {
6464
private let mode: Mode
6565
private let keepAlive: Bool
6666

67-
private var requestHead: HTTPRequestHead!
68-
private var requestBody: ByteBuffer?
67+
private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>()
6968

7069
public init(logger: Logger, keepAlive: Bool, mode: Mode) {
7170
self.logger = logger
72-
self.mode = mode
7371
self.keepAlive = keepAlive
72+
self.mode = mode
7473
}
7574

7675
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
7776
let requestPart = unwrapInboundIn(data)
7877

7978
switch requestPart {
8079
case .head(let head):
81-
self.requestHead = head
82-
self.requestBody?.clear()
80+
self.pending.append((head: head, body: nil))
8381
case .body(var buffer):
84-
if self.requestBody == nil {
85-
self.requestBody = buffer
82+
var request = self.pending.removeFirst()
83+
if request.body == nil {
84+
request.body = buffer
8685
} else {
87-
self.requestBody!.writeBuffer(&buffer)
86+
request.body!.writeBuffer(&buffer)
8887
}
88+
self.pending.prepend(request)
8989
case .end:
90-
self.processRequest(context: context)
90+
let request = self.pending.removeFirst()
91+
self.processRequest(context: context, request: request)
9192
}
9293
}
9394

94-
func processRequest(context: ChannelHandlerContext) {
95-
self.logger.debug("\(self) processing \(self.requestHead.uri)")
95+
func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) {
96+
self.logger.debug("\(self) processing \(request.head.uri)")
9697

9798
var responseStatus: HTTPResponseStatus
9899
var responseBody: String?
99100
var responseHeaders: [(String, String)]?
100101

101-
if self.requestHead.uri.hasSuffix("/next") {
102+
if request.head.uri.hasSuffix("/next") {
102103
let requestId = UUID().uuidString
103104
responseStatus = .ok
104105
switch self.mode {
@@ -108,7 +109,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
108109
responseBody = "{ \"body\": \"\(requestId)\" }"
109110
}
110111
responseHeaders = [(AmazonHeaders.requestID, requestId)]
111-
} else if self.requestHead.uri.hasSuffix("/response") {
112+
} else if request.head.uri.hasSuffix("/response") {
112113
responseStatus = .accepted
113114
} else {
114115
responseStatus = .notFound

Sources/SwiftAwsLambda/HttpClient.swift

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ import NIOConcurrencyHelpers
1717
import NIOHTTP1
1818

1919
/// A barebone HTTP client to interact with AWS Runtime Engine which is an HTTP server.
20-
internal class HTTPClient {
20+
internal final class HTTPClient {
2121
private let eventLoop: EventLoop
2222
private let configuration: Lambda.Configuration.RuntimeEngine
2323
private let targetHost: String
2424

2525
private var state = State.disconnected
26-
private let lock = Lock()
26+
private let stateLock = Lock()
2727

2828
init(eventLoop: EventLoop, configuration: Lambda.Configuration.RuntimeEngine) {
2929
self.eventLoop = eventLoop
@@ -47,37 +47,45 @@ internal class HTTPClient {
4747
}
4848

4949
private func execute(_ request: Request) -> EventLoopFuture<Response> {
50-
self.lock.lock()
50+
self.stateLock.lock()
5151
switch self.state {
52+
case .disconnected:
53+
let promise = self.eventLoop.makePromise(of: Response.self)
54+
self.state = .connecting(promise.futureResult)
55+
self.stateLock.unlock()
56+
self.connect().flatMap { channel -> EventLoopFuture<Response> in
57+
self.stateLock.withLock {
58+
guard case .connecting = self.state else {
59+
preconditionFailure("invalid state \(self.state)")
60+
}
61+
self.state = .connected(channel)
62+
}
63+
return self.execute(request)
64+
}.cascade(to: promise)
65+
return promise.futureResult
66+
case .connecting(let future):
67+
let future = future.flatMap { _ in
68+
self.execute(request)
69+
}
70+
self.state = .connecting(future)
71+
self.stateLock.unlock()
72+
return future
5273
case .connected(let channel):
5374
guard channel.isActive else {
54-
// attempt to reconnect
5575
self.state = .disconnected
56-
self.lock.unlock()
76+
self.stateLock.unlock()
5777
return self.execute(request)
5878
}
59-
self.lock.unlock()
79+
self.stateLock.unlock()
6080
let promise = channel.eventLoop.makePromise(of: Response.self)
6181
let wrapper = HTTPRequestWrapper(request: request, promise: promise)
62-
return channel.writeAndFlush(wrapper).flatMap {
63-
promise.futureResult
64-
}
65-
case .disconnected:
66-
return self.connect().flatMap {
67-
self.lock.unlock()
68-
return self.execute(request)
69-
}
70-
default:
71-
preconditionFailure("invalid state \(self.state)")
82+
channel.writeAndFlush(wrapper).cascadeFailure(to: promise)
83+
return promise.futureResult
7284
}
7385
}
7486

75-
private func connect() -> EventLoopFuture<Void> {
76-
guard case .disconnected = self.state else {
77-
preconditionFailure("invalid state \(self.state)")
78-
}
79-
self.state = .connecting
80-
let bootstrap = ClientBootstrap(group: eventLoop)
87+
private func connect() -> EventLoopFuture<Channel> {
88+
let bootstrap = ClientBootstrap(group: self.eventLoop)
8189
.channelInitializer { channel in
8290
channel.pipeline.addHTTPClientHandlers().flatMap {
8391
channel.pipeline.addHandlers([HTTPHandler(keepAlive: self.configuration.keepAlive),
@@ -88,9 +96,7 @@ internal class HTTPClient {
8896
do {
8997
// connect directly via socket address to avoid happy eyeballs (perf)
9098
let address = try SocketAddress(ipAddress: self.configuration.ip, port: self.configuration.port)
91-
return bootstrap.connect(to: address).flatMapThrowing { channel in
92-
self.state = .connected(channel)
93-
}
99+
return bootstrap.connect(to: address)
94100
} catch {
95101
return self.eventLoop.makeFailedFuture(error)
96102
}
@@ -126,13 +132,13 @@ internal class HTTPClient {
126132
}
127133

128134
private enum State {
129-
case connecting
130-
case connected(Channel)
131135
case disconnected
136+
case connecting(EventLoopFuture<Response>)
137+
case connected(Channel)
132138
}
133139
}
134140

135-
private class HTTPHandler: ChannelDuplexHandler {
141+
private final class HTTPHandler: ChannelDuplexHandler {
136142
typealias OutboundIn = HTTPClient.Request
137143
typealias InboundOut = HTTPClient.Response
138144
typealias InboundIn = HTTPClientResponsePart
@@ -207,15 +213,15 @@ private class HTTPHandler: ChannelDuplexHandler {
207213
}
208214
}
209215

210-
private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
216+
private final class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
211217
typealias OutboundIn = HTTPRequestWrapper
212218
typealias InboundIn = HTTPClient.Response
213219
typealias OutboundOut = HTTPClient.Request
214220

215221
private let keepAlive: Bool
216222

217223
private let lock = Lock()
218-
private var pendingResponses = CircularBuffer<(EventLoopPromise<HTTPClient.Response>, Scheduled<Void>?)>()
224+
private var pendingResponses = CircularBuffer<(promise: EventLoopPromise<HTTPClient.Response>, timeout: Scheduled<Void>?)>()
219225
private var lastError: Error?
220226

221227
init(keepAlive: Bool) {
@@ -231,19 +237,21 @@ private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
231237
}
232238
}
233239
}
234-
self.lock.withLockVoid { pendingResponses.append((wrapper.promise, timeoutTask)) }
240+
self.lock.withLockVoid { pendingResponses.append((promise: wrapper.promise, timeout: timeoutTask)) }
235241
context.writeAndFlush(wrapOutboundOut(wrapper.request), promise: promise)
236242
}
237243

238244
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
239245
let response = unwrapInboundIn(data)
240246
if let pending = (self.lock.withLock { self.pendingResponses.popFirst() }) {
241247
let serverKeepAlive = response.headers["connection"].first?.lowercased() == "keep-alive"
242-
let future = self.keepAlive && serverKeepAlive ? context.eventLoop.makeSucceededFuture(()) : context.channel.close()
243-
future.whenComplete { _ in
244-
pending.1?.cancel()
245-
pending.0.succeed(response)
248+
if !self.keepAlive || !serverKeepAlive {
249+
pending.promise.futureResult.whenComplete { _ in
250+
_ = context.channel.close()
251+
}
246252
}
253+
pending.timeout?.cancel()
254+
pending.promise.succeed(response)
247255
}
248256
}
249257

Sources/SwiftAwsLambda/Lambda.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public enum Lambda {
105105
}
106106
set {
107107
self.stateLock.withLockVoid {
108-
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(_state)")
108+
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(self._state)")
109109
self._state = newValue
110110
}
111111
}
@@ -124,12 +124,12 @@ public enum Lambda {
124124
}
125125

126126
func stop() {
127-
self.logger.info("lambda lifecycle stopping")
127+
self.logger.debug("lambda lifecycle stopping")
128128
self.state = .stopping
129129
}
130130

131131
func shutdown() {
132-
self.logger.info("lambda lifecycle shutdown")
132+
self.logger.debug("lambda lifecycle shutdown")
133133
self.state = .shutdown
134134
}
135135

Sources/SwiftAwsLambda/LambdaRunner.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ internal struct LambdaRunner {
3636
///
3737
/// - Returns: An `EventLoopFuture<Void>` fulfilled with the outcome of the initialization.
3838
func initialize(logger: Logger) -> EventLoopFuture<Void> {
39-
logger.info("initializing lambda")
39+
logger.debug("initializing lambda")
4040
// We need to use `flatMap` instead of `whenFailure` to ensure we complete reporting the result before stopping.
4141
return self.lambdaHandler.initialize(eventLoop: self.eventLoop,
4242
lifecycleId: self.lifecycleId,
@@ -69,7 +69,7 @@ internal struct LambdaRunner {
6969
}
7070
}.always { result in
7171
// we are done!
72-
logger.log(level: result.successful ? .info : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
72+
logger.log(level: result.successful ? .debug : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
7373
}
7474
}
7575
}

Sources/SwiftAwsLambda/LambdaRuntimeClient.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ internal struct JsonCodecError: Error, Equatable {
145145
}
146146

147147
static func == (lhs: JsonCodecError, rhs: JsonCodecError) -> Bool {
148-
return lhs.cause.localizedDescription == rhs.cause.localizedDescription
148+
return String(describing: lhs.cause) == String(describing: rhs.cause)
149149
}
150150
}
151151

Tests/SwiftAwsLambdaTests/Lambda+CodeableTest.swift

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,47 @@
1616
import XCTest
1717

1818
class CodableLambdaTest: XCTestCase {
19-
func testSuceess() throws {
19+
func testSuceess() {
20+
let server = MockLambdaServer(behavior: GoodBehavior())
21+
XCTAssertNoThrow(try server.start().wait())
22+
defer { XCTAssertNoThrow(try server.stop().wait()) }
23+
2024
let maxTimes = Int.random(in: 1 ... 10)
2125
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
22-
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
2326
let result = Lambda.run(handler: CodableEchoHandler(), configuration: configuration)
24-
try server.stop().wait()
2527
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
2628
}
2729

28-
func testFailure() throws {
29-
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
30+
func testFailure() {
31+
let server = MockLambdaServer(behavior: BadBehavior())
32+
XCTAssertNoThrow(try server.start().wait())
33+
defer { XCTAssertNoThrow(try server.stop().wait()) }
34+
3035
let result = Lambda.run(handler: CodableEchoHandler())
31-
try server.stop().wait()
3236
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
3337
}
3438

35-
func testClosureSuccess() throws {
39+
func testClosureSuccess() {
40+
let server = MockLambdaServer(behavior: GoodBehavior())
41+
XCTAssertNoThrow(try server.start().wait())
42+
defer { XCTAssertNoThrow(try server.stop().wait()) }
43+
3644
let maxTimes = Int.random(in: 1 ... 10)
3745
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
38-
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
3946
let result = Lambda.run(configuration: configuration) { (_, payload: Request, callback) in
4047
callback(.success(Response(requestId: payload.requestId)))
4148
}
42-
try server.stop().wait()
4349
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
4450
}
4551

46-
func testClosureFailure() throws {
47-
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
52+
func testClosureFailure() {
53+
let server = MockLambdaServer(behavior: BadBehavior())
54+
XCTAssertNoThrow(try server.start().wait())
55+
defer { XCTAssertNoThrow(try server.stop().wait()) }
56+
4857
let result: LambdaLifecycleResult = Lambda.run { (_, payload: Request, callback) in
4958
callback(.success(Response(requestId: payload.requestId)))
5059
}
51-
try server.stop().wait()
5260
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
5361
}
5462
}

Tests/SwiftAwsLambdaTests/Lambda+StringTest.swift

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,47 @@
1616
import XCTest
1717

1818
class StringLambdaTest: XCTestCase {
19-
func testSuceess() throws {
19+
func testSuceess() {
20+
let server = MockLambdaServer(behavior: GoodBehavior())
21+
XCTAssertNoThrow(try server.start().wait())
22+
defer { XCTAssertNoThrow(try server.stop().wait()) }
23+
2024
let maxTimes = Int.random(in: 1 ... 10)
2125
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
22-
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
2326
let result = Lambda.run(handler: StringEchoHandler(), configuration: configuration)
24-
try server.stop().wait()
2527
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
2628
}
2729

28-
func testFailure() throws {
29-
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
30+
func testFailure() {
31+
let server = MockLambdaServer(behavior: BadBehavior())
32+
XCTAssertNoThrow(try server.start().wait())
33+
defer { XCTAssertNoThrow(try server.stop().wait()) }
34+
3035
let result = Lambda.run(handler: StringEchoHandler())
31-
try server.stop().wait()
3236
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
3337
}
3438

35-
func testClosureSuccess() throws {
39+
func testClosureSuccess() {
40+
let server = MockLambdaServer(behavior: GoodBehavior())
41+
XCTAssertNoThrow(try server.start().wait())
42+
defer { XCTAssertNoThrow(try server.stop().wait()) }
43+
3644
let maxTimes = Int.random(in: 1 ... 10)
3745
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
38-
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
3946
let result = Lambda.run(configuration: configuration) { (_, payload: String, callback) in
4047
callback(.success(payload))
4148
}
42-
try server.stop().wait()
4349
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
4450
}
4551

46-
func testClosureFailure() throws {
47-
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
52+
func testClosureFailure() {
53+
let server = MockLambdaServer(behavior: BadBehavior())
54+
XCTAssertNoThrow(try server.start().wait())
55+
defer { XCTAssertNoThrow(try server.stop().wait()) }
56+
4857
let result: LambdaLifecycleResult = Lambda.run { (_, payload: String, callback) in
4958
callback(.success(payload))
5059
}
51-
try server.stop().wait()
5260
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
5361
}
5462
}

Tests/SwiftAwsLambdaTests/LambdaRunnerTest+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ extension LambdaRunnerTest {
2727
return [
2828
("testSuccess", testSuccess),
2929
("testFailure", testFailure),
30+
("testConcurrency", testConcurrency),
3031
]
3132
}
3233
}

0 commit comments

Comments
 (0)