Skip to content

Commit 08f75f5

Browse files
authored
refactor (#13)
motivation: simpler concurrency, better tests changes: * ensure http client is called in a single-threaded manner and remove locks * refactor test * make mock server more robust
1 parent b8a6466 commit 08f75f5

14 files changed

+195
-147
lines changed

Sources/MockServer/main.swift

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,41 +64,42 @@ internal final class HTTPHandler: ChannelInboundHandler {
6464
private let mode: Mode
6565
private let keepAlive: Bool
6666

67-
private var requestHead: HTTPRequestHead!
68-
private var requestBody: ByteBuffer?
67+
private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>()
6968

7069
public init(logger: Logger, keepAlive: Bool, mode: Mode) {
7170
self.logger = logger
72-
self.mode = mode
7371
self.keepAlive = keepAlive
72+
self.mode = mode
7473
}
7574

7675
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
7776
let requestPart = unwrapInboundIn(data)
7877

7978
switch requestPart {
8079
case .head(let head):
81-
self.requestHead = head
82-
self.requestBody?.clear()
80+
self.pending.append((head: head, body: nil))
8381
case .body(var buffer):
84-
if self.requestBody == nil {
85-
self.requestBody = buffer
82+
var request = self.pending.removeFirst()
83+
if request.body == nil {
84+
request.body = buffer
8685
} else {
87-
self.requestBody!.writeBuffer(&buffer)
86+
request.body!.writeBuffer(&buffer)
8887
}
88+
self.pending.prepend(request)
8989
case .end:
90-
self.processRequest(context: context)
90+
let request = self.pending.removeFirst()
91+
self.processRequest(context: context, request: request)
9192
}
9293
}
9394

94-
func processRequest(context: ChannelHandlerContext) {
95-
self.logger.debug("\(self) processing \(self.requestHead.uri)")
95+
func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) {
96+
self.logger.debug("\(self) processing \(request.head.uri)")
9697

9798
var responseStatus: HTTPResponseStatus
9899
var responseBody: String?
99100
var responseHeaders: [(String, String)]?
100101

101-
if self.requestHead.uri.hasSuffix("/next") {
102+
if request.head.uri.hasSuffix("/next") {
102103
let requestId = UUID().uuidString
103104
responseStatus = .ok
104105
switch self.mode {
@@ -108,7 +109,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
108109
responseBody = "{ \"body\": \"\(requestId)\" }"
109110
}
110111
responseHeaders = [(AmazonHeaders.requestID, requestId)]
111-
} else if self.requestHead.uri.hasSuffix("/response") {
112+
} else if request.head.uri.hasSuffix("/response") {
112113
responseStatus = .accepted
113114
} else {
114115
responseStatus = .notFound

Sources/SwiftAwsLambda/HttpClient.swift

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ import NIOConcurrencyHelpers
1717
import NIOHTTP1
1818

1919
/// A barebone HTTP client to interact with AWS Runtime Engine which is an HTTP server.
20-
internal class HTTPClient {
20+
/// Note that Lambda Runtime API dictate that only one requests runs at a time.
21+
/// This means we can avoid locks and other concurrency concern we would otherwise need to build into the client
22+
internal final class HTTPClient {
2123
private let eventLoop: EventLoop
2224
private let configuration: Lambda.Configuration.RuntimeEngine
2325
private let targetHost: String
2426

2527
private var state = State.disconnected
26-
private let lock = Lock()
28+
private let executing = NIOAtomic.makeAtomic(value: false)
2729

2830
init(eventLoop: EventLoop, configuration: Lambda.Configuration.RuntimeEngine) {
2931
self.eventLoop = eventLoop
@@ -46,38 +48,34 @@ internal class HTTPClient {
4648
timeout: timeout ?? self.configuration.requestTimeout))
4749
}
4850

49-
private func execute(_ request: Request) -> EventLoopFuture<Response> {
50-
self.lock.lock()
51+
// TODO: cap reconnect attempt
52+
private func execute(_ request: Request, validate: Bool = true) -> EventLoopFuture<Response> {
53+
precondition(!validate || self.executing.compareAndExchange(expected: false, desired: true), "expecting single request at a time")
54+
5155
switch self.state {
56+
case .disconnected:
57+
return self.connect().flatMap { channel -> EventLoopFuture<Response> in
58+
self.state = .connected(channel)
59+
return self.execute(request, validate: false)
60+
}
5261
case .connected(let channel):
5362
guard channel.isActive else {
54-
// attempt to reconnect
5563
self.state = .disconnected
56-
self.lock.unlock()
57-
return self.execute(request)
64+
return self.execute(request, validate: false)
5865
}
59-
self.lock.unlock()
66+
6067
let promise = channel.eventLoop.makePromise(of: Response.self)
61-
let wrapper = HTTPRequestWrapper(request: request, promise: promise)
62-
return channel.writeAndFlush(wrapper).flatMap {
63-
promise.futureResult
64-
}
65-
case .disconnected:
66-
return self.connect().flatMap {
67-
self.lock.unlock()
68-
return self.execute(request)
68+
promise.futureResult.whenComplete { _ in
69+
precondition(self.executing.compareAndExchange(expected: true, desired: false), "invalid execution state")
6970
}
70-
default:
71-
preconditionFailure("invalid state \(self.state)")
71+
let wrapper = HTTPRequestWrapper(request: request, promise: promise)
72+
channel.writeAndFlush(wrapper).cascadeFailure(to: promise)
73+
return promise.futureResult
7274
}
7375
}
7476

75-
private func connect() -> EventLoopFuture<Void> {
76-
guard case .disconnected = self.state else {
77-
preconditionFailure("invalid state \(self.state)")
78-
}
79-
self.state = .connecting
80-
let bootstrap = ClientBootstrap(group: eventLoop)
77+
private func connect() -> EventLoopFuture<Channel> {
78+
let bootstrap = ClientBootstrap(group: self.eventLoop)
8179
.channelInitializer { channel in
8280
channel.pipeline.addHTTPClientHandlers().flatMap {
8381
channel.pipeline.addHandlers([HTTPHandler(keepAlive: self.configuration.keepAlive),
@@ -88,9 +86,7 @@ internal class HTTPClient {
8886
do {
8987
// connect directly via socket address to avoid happy eyeballs (perf)
9088
let address = try SocketAddress(ipAddress: self.configuration.ip, port: self.configuration.port)
91-
return bootstrap.connect(to: address).flatMapThrowing { channel in
92-
self.state = .connected(channel)
93-
}
89+
return bootstrap.connect(to: address)
9490
} catch {
9591
return self.eventLoop.makeFailedFuture(error)
9692
}
@@ -126,13 +122,12 @@ internal class HTTPClient {
126122
}
127123

128124
private enum State {
129-
case connecting
130-
case connected(Channel)
131125
case disconnected
126+
case connected(Channel)
132127
}
133128
}
134129

135-
private class HTTPHandler: ChannelDuplexHandler {
130+
private final class HTTPHandler: ChannelDuplexHandler {
136131
typealias OutboundIn = HTTPClient.Request
137132
typealias InboundOut = HTTPClient.Response
138133
typealias InboundIn = HTTPClientResponsePart
@@ -207,63 +202,74 @@ private class HTTPHandler: ChannelDuplexHandler {
207202
}
208203
}
209204

210-
private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
205+
// no need in locks since we validate only one request can run at a time
206+
private final class UnaryHandler: ChannelDuplexHandler {
211207
typealias OutboundIn = HTTPRequestWrapper
212208
typealias InboundIn = HTTPClient.Response
213209
typealias OutboundOut = HTTPClient.Request
214210

215211
private let keepAlive: Bool
216212

217-
private let lock = Lock()
218-
private var pendingResponses = CircularBuffer<(EventLoopPromise<HTTPClient.Response>, Scheduled<Void>?)>()
213+
private var pending: (promise: EventLoopPromise<HTTPClient.Response>, timeout: Scheduled<Void>?)?
219214
private var lastError: Error?
220215

221216
init(keepAlive: Bool) {
222217
self.keepAlive = keepAlive
223218
}
224219

225220
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
221+
guard self.pending == nil else {
222+
preconditionFailure("invalid state, outstanding request")
223+
}
226224
let wrapper = unwrapOutboundIn(data)
227225
let timeoutTask = wrapper.request.timeout.map {
228226
context.eventLoop.scheduleTask(in: $0) {
229-
if (self.lock.withLock { !self.pendingResponses.isEmpty }) {
230-
self.errorCaught(context: context, error: HTTPClient.Errors.timeout)
227+
if self.pending != nil {
228+
context.pipeline.fireErrorCaught(HTTPClient.Errors.timeout)
231229
}
232230
}
233231
}
234-
self.lock.withLockVoid { pendingResponses.append((wrapper.promise, timeoutTask)) }
232+
self.pending = (promise: wrapper.promise, timeout: timeoutTask)
235233
context.writeAndFlush(wrapOutboundOut(wrapper.request), promise: promise)
236234
}
237235

238236
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
239237
let response = unwrapInboundIn(data)
240-
if let pending = (self.lock.withLock { self.pendingResponses.popFirst() }) {
241-
let serverKeepAlive = response.headers["connection"].first?.lowercased() == "keep-alive"
242-
let future = self.keepAlive && serverKeepAlive ? context.eventLoop.makeSucceededFuture(()) : context.channel.close()
243-
future.whenComplete { _ in
244-
pending.1?.cancel()
245-
pending.0.succeed(response)
238+
guard let pending = self.pending else {
239+
preconditionFailure("invalid state, no pending request")
240+
}
241+
let serverKeepAlive = response.headers.first(name: "connection")?.lowercased() == "keep-alive"
242+
if !self.keepAlive || !serverKeepAlive {
243+
pending.promise.futureResult.whenComplete { _ in
244+
_ = context.channel.close()
246245
}
247246
}
247+
self.completeWith(.success(response))
248248
}
249249

250250
func errorCaught(context: ChannelHandlerContext, error: Error) {
251251
// pending responses will fail with lastError in channelInactive since we are calling context.close
252-
self.lock.withLockVoid { self.lastError = error }
252+
self.lastError = error
253253
context.channel.close(promise: nil)
254254
}
255255

256256
func channelInactive(context: ChannelHandlerContext) {
257257
// fail any pending responses with last error or assume peer disconnected
258-
self.failPendingResponses(self.lock.withLock { self.lastError } ?? HTTPClient.Errors.connectionResetByPeer)
258+
if self.pending != nil {
259+
let error = self.lastError ?? HTTPClient.Errors.connectionResetByPeer
260+
self.completeWith(.failure(error))
261+
}
259262
context.fireChannelInactive()
260263
}
261264

262-
private func failPendingResponses(_ error: Error) {
263-
while let pending = (self.lock.withLock { pendingResponses.popFirst() }) {
264-
pending.1?.cancel()
265-
pending.0.fail(error)
265+
private func completeWith(_ result: Result<HTTPClient.Response, Error>) {
266+
guard let pending = self.pending else {
267+
preconditionFailure("invalid state, no pending request")
266268
}
269+
self.pending = nil
270+
self.lastError = nil
271+
pending.timeout?.cancel()
272+
pending.promise.completeWith(result)
267273
}
268274
}
269275

Sources/SwiftAwsLambda/Lambda.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public enum Lambda {
105105
}
106106
set {
107107
self.stateLock.withLockVoid {
108-
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(_state)")
108+
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(self._state)")
109109
self._state = newValue
110110
}
111111
}
@@ -124,12 +124,12 @@ public enum Lambda {
124124
}
125125

126126
func stop() {
127-
self.logger.info("lambda lifecycle stopping")
127+
self.logger.debug("lambda lifecycle stopping")
128128
self.state = .stopping
129129
}
130130

131131
func shutdown() {
132-
self.logger.info("lambda lifecycle shutdown")
132+
self.logger.debug("lambda lifecycle shutdown")
133133
self.state = .shutdown
134134
}
135135

Sources/SwiftAwsLambda/LambdaRunner.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ internal struct LambdaRunner {
3636
///
3737
/// - Returns: An `EventLoopFuture<Void>` fulfilled with the outcome of the initialization.
3838
func initialize(logger: Logger) -> EventLoopFuture<Void> {
39-
logger.info("initializing lambda")
39+
logger.debug("initializing lambda")
4040
// We need to use `flatMap` instead of `whenFailure` to ensure we complete reporting the result before stopping.
4141
return self.lambdaHandler.initialize(eventLoop: self.eventLoop,
4242
lifecycleId: self.lifecycleId,
@@ -69,7 +69,7 @@ internal struct LambdaRunner {
6969
}
7070
}.always { result in
7171
// we are done!
72-
logger.log(level: result.successful ? .info : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
72+
logger.log(level: result.successful ? .debug : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
7373
}
7474
}
7575
}

Sources/SwiftAwsLambda/LambdaRuntimeClient.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ internal struct JsonCodecError: Error, Equatable {
145145
}
146146

147147
static func == (lhs: JsonCodecError, rhs: JsonCodecError) -> Bool {
148-
return lhs.cause.localizedDescription == rhs.cause.localizedDescription
148+
return String(describing: lhs.cause) == String(describing: rhs.cause)
149149
}
150150
}
151151

Tests/SwiftAwsLambdaTests/Lambda+CodeableTest+XCTest.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import XCTest
2525
extension CodableLambdaTest {
2626
static var allTests: [(String, (CodableLambdaTest) -> () throws -> Void)] {
2727
return [
28-
("testSuceess", testSuceess),
28+
("testSuccess", testSuccess),
2929
("testFailure", testFailure),
3030
("testClosureSuccess", testClosureSuccess),
3131
("testClosureFailure", testClosureFailure),

Tests/SwiftAwsLambdaTests/Lambda+CodeableTest.swift

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,47 @@
1616
import XCTest
1717

1818
class CodableLambdaTest: XCTestCase {
19-
func testSuceess() throws {
19+
func testSuccess() {
20+
let server = MockLambdaServer(behavior: GoodBehavior())
21+
XCTAssertNoThrow(try server.start().wait())
22+
defer { XCTAssertNoThrow(try server.stop().wait()) }
23+
2024
let maxTimes = Int.random(in: 1 ... 10)
2125
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
22-
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
2326
let result = Lambda.run(handler: CodableEchoHandler(), configuration: configuration)
24-
try server.stop().wait()
2527
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
2628
}
2729

28-
func testFailure() throws {
29-
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
30+
func testFailure() {
31+
let server = MockLambdaServer(behavior: BadBehavior())
32+
XCTAssertNoThrow(try server.start().wait())
33+
defer { XCTAssertNoThrow(try server.stop().wait()) }
34+
3035
let result = Lambda.run(handler: CodableEchoHandler())
31-
try server.stop().wait()
3236
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
3337
}
3438

35-
func testClosureSuccess() throws {
39+
func testClosureSuccess() {
40+
let server = MockLambdaServer(behavior: GoodBehavior())
41+
XCTAssertNoThrow(try server.start().wait())
42+
defer { XCTAssertNoThrow(try server.stop().wait()) }
43+
3644
let maxTimes = Int.random(in: 1 ... 10)
3745
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
38-
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
3946
let result = Lambda.run(configuration: configuration) { (_, payload: Request, callback) in
4047
callback(.success(Response(requestId: payload.requestId)))
4148
}
42-
try server.stop().wait()
4349
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
4450
}
4551

46-
func testClosureFailure() throws {
47-
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
52+
func testClosureFailure() {
53+
let server = MockLambdaServer(behavior: BadBehavior())
54+
XCTAssertNoThrow(try server.start().wait())
55+
defer { XCTAssertNoThrow(try server.stop().wait()) }
56+
4857
let result: LambdaLifecycleResult = Lambda.run { (_, payload: Request, callback) in
4958
callback(.success(Response(requestId: payload.requestId)))
5059
}
51-
try server.stop().wait()
5260
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
5361
}
5462
}

Tests/SwiftAwsLambdaTests/Lambda+StringTest+XCTest.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import XCTest
2525
extension StringLambdaTest {
2626
static var allTests: [(String, (StringLambdaTest) -> () throws -> Void)] {
2727
return [
28-
("testSuceess", testSuceess),
28+
("testSuccess", testSuccess),
2929
("testFailure", testFailure),
3030
("testClosureSuccess", testClosureSuccess),
3131
("testClosureFailure", testClosureFailure),

0 commit comments

Comments
 (0)