diff --git a/Sources/MockServer/main.swift b/Sources/MockServer/main.swift index 60167d4c..2d318881 100644 --- a/Sources/MockServer/main.swift +++ b/Sources/MockServer/main.swift @@ -64,13 +64,12 @@ internal final class HTTPHandler: ChannelInboundHandler { private let mode: Mode private let keepAlive: Bool - private var requestHead: HTTPRequestHead! - private var requestBody: ByteBuffer? + private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() public init(logger: Logger, keepAlive: Bool, mode: Mode) { self.logger = logger - self.mode = mode self.keepAlive = keepAlive + self.mode = mode } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -78,27 +77,29 @@ internal final class HTTPHandler: ChannelInboundHandler { switch requestPart { case .head(let head): - self.requestHead = head - self.requestBody?.clear() + self.pending.append((head: head, body: nil)) case .body(var buffer): - if self.requestBody == nil { - self.requestBody = buffer + var request = self.pending.removeFirst() + if request.body == nil { + request.body = buffer } else { - self.requestBody!.writeBuffer(&buffer) + request.body!.writeBuffer(&buffer) } + self.pending.prepend(request) case .end: - self.processRequest(context: context) + let request = self.pending.removeFirst() + self.processRequest(context: context, request: request) } } - func processRequest(context: ChannelHandlerContext) { - self.logger.debug("\(self) processing \(self.requestHead.uri)") + func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { + self.logger.debug("\(self) processing \(request.head.uri)") var responseStatus: HTTPResponseStatus var responseBody: String? var responseHeaders: [(String, String)]? - if self.requestHead.uri.hasSuffix("/next") { + if request.head.uri.hasSuffix("/next") { let requestId = UUID().uuidString responseStatus = .ok switch self.mode { @@ -108,7 +109,7 @@ internal final class HTTPHandler: ChannelInboundHandler { responseBody = "{ \"body\": \"\(requestId)\" }" } responseHeaders = [(AmazonHeaders.requestID, requestId)] - } else if self.requestHead.uri.hasSuffix("/response") { + } else if request.head.uri.hasSuffix("/response") { responseStatus = .accepted } else { responseStatus = .notFound diff --git a/Sources/SwiftAwsLambda/HttpClient.swift b/Sources/SwiftAwsLambda/HttpClient.swift index 95f43472..bfd95bcb 100644 --- a/Sources/SwiftAwsLambda/HttpClient.swift +++ b/Sources/SwiftAwsLambda/HttpClient.swift @@ -17,13 +17,15 @@ import NIOConcurrencyHelpers import NIOHTTP1 /// A barebone HTTP client to interact with AWS Runtime Engine which is an HTTP server. -internal class HTTPClient { +/// Note that Lambda Runtime API dictate that only one requests runs at a time. +/// This means we can avoid locks and other concurrency concern we would otherwise need to build into the client +internal final class HTTPClient { private let eventLoop: EventLoop private let configuration: Lambda.Configuration.RuntimeEngine private let targetHost: String private var state = State.disconnected - private let lock = Lock() + private let executing = NIOAtomic.makeAtomic(value: false) init(eventLoop: EventLoop, configuration: Lambda.Configuration.RuntimeEngine) { self.eventLoop = eventLoop @@ -46,38 +48,34 @@ internal class HTTPClient { timeout: timeout ?? self.configuration.requestTimeout)) } - private func execute(_ request: Request) -> EventLoopFuture { - self.lock.lock() + // TODO: cap reconnect attempt + private func execute(_ request: Request, validate: Bool = true) -> EventLoopFuture { + precondition(!validate || self.executing.compareAndExchange(expected: false, desired: true), "expecting single request at a time") + switch self.state { + case .disconnected: + return self.connect().flatMap { channel -> EventLoopFuture in + self.state = .connected(channel) + return self.execute(request, validate: false) + } case .connected(let channel): guard channel.isActive else { - // attempt to reconnect self.state = .disconnected - self.lock.unlock() - return self.execute(request) + return self.execute(request, validate: false) } - self.lock.unlock() + let promise = channel.eventLoop.makePromise(of: Response.self) - let wrapper = HTTPRequestWrapper(request: request, promise: promise) - return channel.writeAndFlush(wrapper).flatMap { - promise.futureResult - } - case .disconnected: - return self.connect().flatMap { - self.lock.unlock() - return self.execute(request) + promise.futureResult.whenComplete { _ in + precondition(self.executing.compareAndExchange(expected: true, desired: false), "invalid execution state") } - default: - preconditionFailure("invalid state \(self.state)") + let wrapper = HTTPRequestWrapper(request: request, promise: promise) + channel.writeAndFlush(wrapper).cascadeFailure(to: promise) + return promise.futureResult } } - private func connect() -> EventLoopFuture { - guard case .disconnected = self.state else { - preconditionFailure("invalid state \(self.state)") - } - self.state = .connecting - let bootstrap = ClientBootstrap(group: eventLoop) + private func connect() -> EventLoopFuture { + let bootstrap = ClientBootstrap(group: self.eventLoop) .channelInitializer { channel in channel.pipeline.addHTTPClientHandlers().flatMap { channel.pipeline.addHandlers([HTTPHandler(keepAlive: self.configuration.keepAlive), @@ -88,9 +86,7 @@ internal class HTTPClient { do { // connect directly via socket address to avoid happy eyeballs (perf) let address = try SocketAddress(ipAddress: self.configuration.ip, port: self.configuration.port) - return bootstrap.connect(to: address).flatMapThrowing { channel in - self.state = .connected(channel) - } + return bootstrap.connect(to: address) } catch { return self.eventLoop.makeFailedFuture(error) } @@ -126,13 +122,12 @@ internal class HTTPClient { } private enum State { - case connecting - case connected(Channel) case disconnected + case connected(Channel) } } -private class HTTPHandler: ChannelDuplexHandler { +private final class HTTPHandler: ChannelDuplexHandler { typealias OutboundIn = HTTPClient.Request typealias InboundOut = HTTPClient.Response typealias InboundIn = HTTPClientResponsePart @@ -207,15 +202,15 @@ private class HTTPHandler: ChannelDuplexHandler { } } -private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler { +// no need in locks since we validate only one request can run at a time +private final class UnaryHandler: ChannelDuplexHandler { typealias OutboundIn = HTTPRequestWrapper typealias InboundIn = HTTPClient.Response typealias OutboundOut = HTTPClient.Request private let keepAlive: Bool - private let lock = Lock() - private var pendingResponses = CircularBuffer<(EventLoopPromise, Scheduled?)>() + private var pending: (promise: EventLoopPromise, timeout: Scheduled?)? private var lastError: Error? init(keepAlive: Bool) { @@ -223,47 +218,58 @@ private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler { } func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + guard self.pending == nil else { + preconditionFailure("invalid state, outstanding request") + } let wrapper = unwrapOutboundIn(data) let timeoutTask = wrapper.request.timeout.map { context.eventLoop.scheduleTask(in: $0) { - if (self.lock.withLock { !self.pendingResponses.isEmpty }) { - self.errorCaught(context: context, error: HTTPClient.Errors.timeout) + if self.pending != nil { + context.pipeline.fireErrorCaught(HTTPClient.Errors.timeout) } } } - self.lock.withLockVoid { pendingResponses.append((wrapper.promise, timeoutTask)) } + self.pending = (promise: wrapper.promise, timeout: timeoutTask) context.writeAndFlush(wrapOutboundOut(wrapper.request), promise: promise) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { let response = unwrapInboundIn(data) - if let pending = (self.lock.withLock { self.pendingResponses.popFirst() }) { - let serverKeepAlive = response.headers["connection"].first?.lowercased() == "keep-alive" - let future = self.keepAlive && serverKeepAlive ? context.eventLoop.makeSucceededFuture(()) : context.channel.close() - future.whenComplete { _ in - pending.1?.cancel() - pending.0.succeed(response) + guard let pending = self.pending else { + preconditionFailure("invalid state, no pending request") + } + let serverKeepAlive = response.headers.first(name: "connection")?.lowercased() == "keep-alive" + if !self.keepAlive || !serverKeepAlive { + pending.promise.futureResult.whenComplete { _ in + _ = context.channel.close() } } + self.completeWith(.success(response)) } func errorCaught(context: ChannelHandlerContext, error: Error) { // pending responses will fail with lastError in channelInactive since we are calling context.close - self.lock.withLockVoid { self.lastError = error } + self.lastError = error context.channel.close(promise: nil) } func channelInactive(context: ChannelHandlerContext) { // fail any pending responses with last error or assume peer disconnected - self.failPendingResponses(self.lock.withLock { self.lastError } ?? HTTPClient.Errors.connectionResetByPeer) + if self.pending != nil { + let error = self.lastError ?? HTTPClient.Errors.connectionResetByPeer + self.completeWith(.failure(error)) + } context.fireChannelInactive() } - private func failPendingResponses(_ error: Error) { - while let pending = (self.lock.withLock { pendingResponses.popFirst() }) { - pending.1?.cancel() - pending.0.fail(error) + private func completeWith(_ result: Result) { + guard let pending = self.pending else { + preconditionFailure("invalid state, no pending request") } + self.pending = nil + self.lastError = nil + pending.timeout?.cancel() + pending.promise.completeWith(result) } } diff --git a/Sources/SwiftAwsLambda/Lambda.swift b/Sources/SwiftAwsLambda/Lambda.swift index e41fe4b1..0ac03870 100644 --- a/Sources/SwiftAwsLambda/Lambda.swift +++ b/Sources/SwiftAwsLambda/Lambda.swift @@ -105,7 +105,7 @@ public enum Lambda { } set { self.stateLock.withLockVoid { - precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(_state)") + precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(self._state)") self._state = newValue } } @@ -124,12 +124,12 @@ public enum Lambda { } func stop() { - self.logger.info("lambda lifecycle stopping") + self.logger.debug("lambda lifecycle stopping") self.state = .stopping } func shutdown() { - self.logger.info("lambda lifecycle shutdown") + self.logger.debug("lambda lifecycle shutdown") self.state = .shutdown } diff --git a/Sources/SwiftAwsLambda/LambdaRunner.swift b/Sources/SwiftAwsLambda/LambdaRunner.swift index 58664260..705c93ca 100644 --- a/Sources/SwiftAwsLambda/LambdaRunner.swift +++ b/Sources/SwiftAwsLambda/LambdaRunner.swift @@ -36,7 +36,7 @@ internal struct LambdaRunner { /// /// - Returns: An `EventLoopFuture` fulfilled with the outcome of the initialization. func initialize(logger: Logger) -> EventLoopFuture { - logger.info("initializing lambda") + logger.debug("initializing lambda") // We need to use `flatMap` instead of `whenFailure` to ensure we complete reporting the result before stopping. return self.lambdaHandler.initialize(eventLoop: self.eventLoop, lifecycleId: self.lifecycleId, @@ -69,7 +69,7 @@ internal struct LambdaRunner { } }.always { result in // we are done! - logger.log(level: result.successful ? .info : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")") + logger.log(level: result.successful ? .debug : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")") } } } diff --git a/Sources/SwiftAwsLambda/LambdaRuntimeClient.swift b/Sources/SwiftAwsLambda/LambdaRuntimeClient.swift index 05bb857b..0d546ca3 100644 --- a/Sources/SwiftAwsLambda/LambdaRuntimeClient.swift +++ b/Sources/SwiftAwsLambda/LambdaRuntimeClient.swift @@ -145,7 +145,7 @@ internal struct JsonCodecError: Error, Equatable { } static func == (lhs: JsonCodecError, rhs: JsonCodecError) -> Bool { - return lhs.cause.localizedDescription == rhs.cause.localizedDescription + return String(describing: lhs.cause) == String(describing: rhs.cause) } } diff --git a/Tests/SwiftAwsLambdaTests/Lambda+CodeableTest+XCTest.swift b/Tests/SwiftAwsLambdaTests/Lambda+CodeableTest+XCTest.swift index b35ad1fa..6940ddab 100644 --- a/Tests/SwiftAwsLambdaTests/Lambda+CodeableTest+XCTest.swift +++ b/Tests/SwiftAwsLambdaTests/Lambda+CodeableTest+XCTest.swift @@ -25,7 +25,7 @@ import XCTest extension CodableLambdaTest { static var allTests: [(String, (CodableLambdaTest) -> () throws -> Void)] { return [ - ("testSuceess", testSuceess), + ("testSuccess", testSuccess), ("testFailure", testFailure), ("testClosureSuccess", testClosureSuccess), ("testClosureFailure", testClosureFailure), diff --git a/Tests/SwiftAwsLambdaTests/Lambda+CodeableTest.swift b/Tests/SwiftAwsLambdaTests/Lambda+CodeableTest.swift index a540b1da..539fb20a 100644 --- a/Tests/SwiftAwsLambdaTests/Lambda+CodeableTest.swift +++ b/Tests/SwiftAwsLambdaTests/Lambda+CodeableTest.swift @@ -16,39 +16,47 @@ import XCTest class CodableLambdaTest: XCTestCase { - func testSuceess() throws { + func testSuccess() { + let server = MockLambdaServer(behavior: GoodBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let maxTimes = Int.random(in: 1 ... 10) let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes)) - let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait() let result = Lambda.run(handler: CodableEchoHandler(), configuration: configuration) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes) } - func testFailure() throws { - let server = try MockLambdaServer(behavior: BadBehavior()).start().wait() + func testFailure() { + let server = MockLambdaServer(behavior: BadBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let result = Lambda.run(handler: CodableEchoHandler()) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError)) } - func testClosureSuccess() throws { + func testClosureSuccess() { + let server = MockLambdaServer(behavior: GoodBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let maxTimes = Int.random(in: 1 ... 10) let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes)) - let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait() let result = Lambda.run(configuration: configuration) { (_, payload: Request, callback) in callback(.success(Response(requestId: payload.requestId))) } - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes) } - func testClosureFailure() throws { - let server = try MockLambdaServer(behavior: BadBehavior()).start().wait() + func testClosureFailure() { + let server = MockLambdaServer(behavior: BadBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let result: LambdaLifecycleResult = Lambda.run { (_, payload: Request, callback) in callback(.success(Response(requestId: payload.requestId))) } - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError)) } } diff --git a/Tests/SwiftAwsLambdaTests/Lambda+StringTest+XCTest.swift b/Tests/SwiftAwsLambdaTests/Lambda+StringTest+XCTest.swift index 3d5d7a85..637fc175 100644 --- a/Tests/SwiftAwsLambdaTests/Lambda+StringTest+XCTest.swift +++ b/Tests/SwiftAwsLambdaTests/Lambda+StringTest+XCTest.swift @@ -25,7 +25,7 @@ import XCTest extension StringLambdaTest { static var allTests: [(String, (StringLambdaTest) -> () throws -> Void)] { return [ - ("testSuceess", testSuceess), + ("testSuccess", testSuccess), ("testFailure", testFailure), ("testClosureSuccess", testClosureSuccess), ("testClosureFailure", testClosureFailure), diff --git a/Tests/SwiftAwsLambdaTests/Lambda+StringTest.swift b/Tests/SwiftAwsLambdaTests/Lambda+StringTest.swift index ec77e7a2..0d3f93cc 100644 --- a/Tests/SwiftAwsLambdaTests/Lambda+StringTest.swift +++ b/Tests/SwiftAwsLambdaTests/Lambda+StringTest.swift @@ -16,39 +16,47 @@ import XCTest class StringLambdaTest: XCTestCase { - func testSuceess() throws { + func testSuccess() { + let server = MockLambdaServer(behavior: GoodBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let maxTimes = Int.random(in: 1 ... 10) let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes)) - let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait() let result = Lambda.run(handler: StringEchoHandler(), configuration: configuration) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes) } - func testFailure() throws { - let server = try MockLambdaServer(behavior: BadBehavior()).start().wait() + func testFailure() { + let server = MockLambdaServer(behavior: BadBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let result = Lambda.run(handler: StringEchoHandler()) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError)) } - func testClosureSuccess() throws { + func testClosureSuccess() { + let server = MockLambdaServer(behavior: GoodBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let maxTimes = Int.random(in: 1 ... 10) let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes)) - let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait() let result = Lambda.run(configuration: configuration) { (_, payload: String, callback) in callback(.success(payload)) } - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes) } - func testClosureFailure() throws { - let server = try MockLambdaServer(behavior: BadBehavior()).start().wait() + func testClosureFailure() { + let server = MockLambdaServer(behavior: BadBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let result: LambdaLifecycleResult = Lambda.run { (_, payload: String, callback) in callback(.success(payload)) } - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError)) } } diff --git a/Tests/SwiftAwsLambdaTests/LambdaRunnerTest.swift b/Tests/SwiftAwsLambdaTests/LambdaRunnerTest.swift index a3e0439d..583ce27b 100644 --- a/Tests/SwiftAwsLambdaTests/LambdaRunnerTest.swift +++ b/Tests/SwiftAwsLambdaTests/LambdaRunnerTest.swift @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +import Logging +import NIO @testable import SwiftAwsLambda import XCTest diff --git a/Tests/SwiftAwsLambdaTests/LambdaTest+XCTest.swift b/Tests/SwiftAwsLambdaTests/LambdaTest+XCTest.swift index 6e31095a..98402082 100644 --- a/Tests/SwiftAwsLambdaTests/LambdaTest+XCTest.swift +++ b/Tests/SwiftAwsLambdaTests/LambdaTest+XCTest.swift @@ -25,7 +25,7 @@ import XCTest extension LambdaTest { static var allTests: [(String, (LambdaTest) -> () throws -> Void)] { return [ - ("testSuceess", testSuceess), + ("testSuccess", testSuccess), ("testFailure", testFailure), ("testInitFailure", testInitFailure), ("testInitFailureAndReportErrorFailure", testInitFailureAndReportErrorFailure), diff --git a/Tests/SwiftAwsLambdaTests/LambdaTest.swift b/Tests/SwiftAwsLambdaTests/LambdaTest.swift index 25e702ab..5f4fee02 100644 --- a/Tests/SwiftAwsLambdaTests/LambdaTest.swift +++ b/Tests/SwiftAwsLambdaTests/LambdaTest.swift @@ -17,57 +17,69 @@ import NIO import XCTest class LambdaTest: XCTestCase { - func testSuceess() throws { + func testSuccess() { + let server = MockLambdaServer(behavior: GoodBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let maxTimes = Int.random(in: 10 ... 20) let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes)) - let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait() let handler = EchoHandler() let result = Lambda.run(handler: handler, configuration: configuration) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes) XCTAssertEqual(handler.initializeCalls, 1) } - func testFailure() throws { - let server = try MockLambdaServer(behavior: BadBehavior()).start().wait() + func testFailure() { + let server = MockLambdaServer(behavior: BadBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let result = Lambda.run(handler: EchoHandler()) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError)) } - func testInitFailure() throws { - let server = try MockLambdaServer(behavior: GoodBehaviourWhenInitFails()).start().wait() + func testInitFailure() { + let server = MockLambdaServer(behavior: GoodBehaviourWhenInitFails()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let handler = FailedInitializerHandler("kaboom") let result = Lambda.run(handler: handler) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: FailedInitializerHandler.Error(description: "kaboom")) } - func testInitFailureAndReportErrorFailure() throws { - let server = try MockLambdaServer(behavior: BadBehaviourWhenInitFails()).start().wait() + func testInitFailureAndReportErrorFailure() { + let server = MockLambdaServer(behavior: BadBehaviourWhenInitFails()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let handler = FailedInitializerHandler("kaboom") let result = Lambda.run(handler: handler) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: FailedInitializerHandler.Error(description: "kaboom")) } - func testClosureSuccess() throws { + func testClosureSuccess() { + let server = MockLambdaServer(behavior: GoodBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let maxTimes = Int.random(in: 10 ... 20) let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes)) - let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait() let result = Lambda.run(configuration: configuration) { (_, payload: [UInt8], callback: LambdaCallback) in callback(.success(payload)) } - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes) } - func testClosureFailure() throws { - let server = try MockLambdaServer(behavior: BadBehavior()).start().wait() + func testClosureFailure() { + let server = MockLambdaServer(behavior: BadBehavior()) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let result: LambdaLifecycleResult = Lambda.run { (_, payload: [UInt8], callback: LambdaCallback) in callback(.success(payload)) } - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError)) } @@ -94,49 +106,58 @@ class LambdaTest: XCTestCase { try eventLoopGroup.syncShutdownGracefully() } - func testTimeout() throws { + func testTimeout() { let timeout: Int64 = 100 + let server = MockLambdaServer(behavior: GoodBehavior(requestId: "timeout", payload: "\(timeout * 2)")) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: 1), runtimeEngine: .init(requestTimeout: .milliseconds(timeout))) - - let server = try MockLambdaServer(behavior: GoodBehavior(requestId: "timeout", payload: "\(timeout * 2)")).start().wait() let result = Lambda.run(handler: EchoHandler(), configuration: configuration) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.upstreamError("timeout")) } - func testDisconnect() throws { + func testDisconnect() { + let server = MockLambdaServer(behavior: GoodBehavior(requestId: "disconnect")) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: 1)) - let server = try MockLambdaServer(behavior: GoodBehavior(requestId: "disconnect")).start().wait() let result = Lambda.run(handler: EchoHandler(), configuration: configuration) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.upstreamError("connectionResetByPeer")) } - func testBigPayload() throws { - let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: 1)) + func testBigPayload() { let payload = String(repeating: "*", count: 104_448) - let server = try MockLambdaServer(behavior: GoodBehavior(payload: payload)).start().wait() + let server = MockLambdaServer(behavior: GoodBehavior(payload: payload)) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + + let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: 1)) let result = Lambda.run(handler: EchoHandler(), configuration: configuration) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: 1) } - func testKeepAliveServer() throws { + func testKeepAliveServer() { + let server = MockLambdaServer(behavior: GoodBehavior(), keepAlive: true) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let maxTimes = 10 let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes)) - let server = try MockLambdaServer(behavior: GoodBehavior(), keepAlive: true).start().wait() let result = Lambda.run(handler: EchoHandler(), configuration: configuration) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes) } - func testNoKeepAliveServer() throws { + func testNoKeepAliveServer() { + let server = MockLambdaServer(behavior: GoodBehavior(), keepAlive: false) + XCTAssertNoThrow(try server.start().wait()) + defer { XCTAssertNoThrow(try server.stop().wait()) } + let maxTimes = 10 let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes)) - let server = try MockLambdaServer(behavior: GoodBehavior(), keepAlive: false).start().wait() let result = Lambda.run(handler: EchoHandler(), configuration: configuration) - try server.stop().wait() assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes) } } diff --git a/Tests/SwiftAwsLambdaTests/MockLambdaServer.swift b/Tests/SwiftAwsLambdaTests/MockLambdaServer.swift index c0c18b8c..4daba58e 100644 --- a/Tests/SwiftAwsLambdaTests/MockLambdaServer.swift +++ b/Tests/SwiftAwsLambdaTests/MockLambdaServer.swift @@ -79,8 +79,7 @@ internal final class HTTPHandler: ChannelInboundHandler { private let keepAlive: Bool private let behavior: LambdaServerBehavior - private var requestHead: HTTPRequestHead! - private var requestBody: ByteBuffer? + private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() public init(logger: Logger, keepAlive: Bool, behavior: LambdaServerBehavior) { self.logger = logger @@ -93,23 +92,25 @@ internal final class HTTPHandler: ChannelInboundHandler { switch requestPart { case .head(let head): - self.requestHead = head - self.requestBody?.clear() + self.pending.append((head: head, body: nil)) case .body(var buffer): - if self.requestBody == nil { - self.requestBody = buffer + var request = self.pending.removeFirst() + if request.body == nil { + request.body = buffer } else { - self.requestBody!.writeBuffer(&buffer) + request.body!.writeBuffer(&buffer) } + self.pending.prepend(request) case .end: - self.processRequest(context: context) + let request = self.pending.removeFirst() + self.processRequest(context: context, request: request) } } - func processRequest(context: ChannelHandlerContext) { - self.logger.info("\(self) processing \(self.requestHead.uri)") + func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { + self.logger.info("\(self) processing \(request.head.uri)") - let requestBody = self.requestBody.flatMap { (buffer: ByteBuffer) -> String? in + let requestBody = request.body.flatMap { (buffer: ByteBuffer) -> String? in var buffer = buffer return buffer.readString(length: buffer.readableBytes) } @@ -119,7 +120,7 @@ internal final class HTTPHandler: ChannelInboundHandler { var responseHeaders: [(String, String)]? // Handle post-init-error first to avoid matching the less specific post-error suffix. - if self.requestHead.uri.hasSuffix(Consts.postInitErrorURL) { + if request.head.uri.hasSuffix(Consts.postInitErrorURL) { guard let json = requestBody, let error = ErrorResponse.fromJson(json) else { return self.writeResponse(context: context, status: .badRequest) } @@ -129,7 +130,7 @@ internal final class HTTPHandler: ChannelInboundHandler { case .failure(let error): responseStatus = .init(statusCode: error.rawValue) } - } else if self.requestHead.uri.hasSuffix(Consts.requestWorkURLSuffix) { + } else if request.head.uri.hasSuffix(Consts.requestWorkURLSuffix) { switch self.behavior.getWork() { case .success(let (requestId, result)): if requestId == "timeout" { @@ -143,8 +144,8 @@ internal final class HTTPHandler: ChannelInboundHandler { case .failure(let error): responseStatus = .init(statusCode: error.rawValue) } - } else if self.requestHead.uri.hasSuffix(Consts.postResponseURLSuffix) { - guard let requestId = requestHead.uri.split(separator: "/").dropFirst(3).first, let response = requestBody else { + } else if request.head.uri.hasSuffix(Consts.postResponseURLSuffix) { + guard let requestId = request.head.uri.split(separator: "/").dropFirst(3).first, let response = requestBody else { return self.writeResponse(context: context, status: .badRequest) } switch self.behavior.processResponse(requestId: String(requestId), response: response) { @@ -153,8 +154,8 @@ internal final class HTTPHandler: ChannelInboundHandler { case .failure(let error): responseStatus = .init(statusCode: error.rawValue) } - } else if self.requestHead.uri.hasSuffix(Consts.postErrorURLSuffix) { - guard let requestId = requestHead.uri.split(separator: "/").dropFirst(3).first, + } else if request.head.uri.hasSuffix(Consts.postErrorURLSuffix) { + guard let requestId = request.head.uri.split(separator: "/").dropFirst(3).first, let json = requestBody, let error = ErrorResponse.fromJson(json) else { @@ -169,6 +170,7 @@ internal final class HTTPHandler: ChannelInboundHandler { } else { responseStatus = .notFound } + self.logger.info("\(self) responding to \(request.head.uri)") self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody) } diff --git a/Tests/SwiftAwsLambdaTests/Utils.swift b/Tests/SwiftAwsLambdaTests/Utils.swift index 93b09ffe..13059ecc 100644 --- a/Tests/SwiftAwsLambdaTests/Utils.swift +++ b/Tests/SwiftAwsLambdaTests/Utils.swift @@ -30,7 +30,7 @@ func runLambda(behavior: LambdaServerBehavior, handler: LambdaHandler) throws { }.wait() } -class EchoHandler: LambdaHandler { +final class EchoHandler: LambdaHandler { var initializeCalls = 0 func initialize(callback: @escaping LambdaInitCallBack) {