Skip to content

refactor #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions Sources/MockServer/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,41 +64,42 @@ internal final class HTTPHandler: ChannelInboundHandler {
private let mode: Mode
private let keepAlive: Bool

private var requestHead: HTTPRequestHead!
private var requestBody: ByteBuffer?
private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>()

public init(logger: Logger, keepAlive: Bool, mode: Mode) {
self.logger = logger
self.mode = mode
self.keepAlive = keepAlive
self.mode = mode
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let requestPart = unwrapInboundIn(data)

switch requestPart {
case .head(let head):
self.requestHead = head
self.requestBody?.clear()
self.pending.append((head: head, body: nil))
case .body(var buffer):
if self.requestBody == nil {
self.requestBody = buffer
var request = self.pending.removeFirst()
if request.body == nil {
request.body = buffer
} else {
self.requestBody!.writeBuffer(&buffer)
request.body!.writeBuffer(&buffer)
}
self.pending.prepend(request)
case .end:
self.processRequest(context: context)
let request = self.pending.removeFirst()
self.processRequest(context: context, request: request)
}
}

func processRequest(context: ChannelHandlerContext) {
self.logger.debug("\(self) processing \(self.requestHead.uri)")
func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) {
self.logger.debug("\(self) processing \(request.head.uri)")

var responseStatus: HTTPResponseStatus
var responseBody: String?
var responseHeaders: [(String, String)]?

if self.requestHead.uri.hasSuffix("/next") {
if request.head.uri.hasSuffix("/next") {
let requestId = UUID().uuidString
responseStatus = .ok
switch self.mode {
Expand All @@ -108,7 +109,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
responseBody = "{ \"body\": \"\(requestId)\" }"
}
responseHeaders = [(AmazonHeaders.requestID, requestId)]
} else if self.requestHead.uri.hasSuffix("/response") {
} else if request.head.uri.hasSuffix("/response") {
responseStatus = .accepted
} else {
responseStatus = .notFound
Expand Down
102 changes: 54 additions & 48 deletions Sources/SwiftAwsLambda/HttpClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ import NIOConcurrencyHelpers
import NIOHTTP1

/// A barebone HTTP client to interact with AWS Runtime Engine which is an HTTP server.
internal class HTTPClient {
/// Note that Lambda Runtime API dictate that only one requests runs at a time.
/// This means we can avoid locks and other concurrency concern we would otherwise need to build into the client
internal final class HTTPClient {
private let eventLoop: EventLoop
private let configuration: Lambda.Configuration.RuntimeEngine
private let targetHost: String

private var state = State.disconnected
private let lock = Lock()
private let executing = NIOAtomic.makeAtomic(value: false)

init(eventLoop: EventLoop, configuration: Lambda.Configuration.RuntimeEngine) {
self.eventLoop = eventLoop
Expand All @@ -46,38 +48,34 @@ internal class HTTPClient {
timeout: timeout ?? self.configuration.requestTimeout))
}

private func execute(_ request: Request) -> EventLoopFuture<Response> {
self.lock.lock()
// TODO: cap reconnect attempt
private func execute(_ request: Request, validate: Bool = true) -> EventLoopFuture<Response> {
precondition(!validate || self.executing.compareAndExchange(expected: false, desired: true), "expecting single request at a time")

switch self.state {
case .disconnected:
return self.connect().flatMap { channel -> EventLoopFuture<Response> in
self.state = .connected(channel)
return self.execute(request, validate: false)
}
case .connected(let channel):
guard channel.isActive else {
// attempt to reconnect
self.state = .disconnected
self.lock.unlock()
return self.execute(request)
return self.execute(request, validate: false)
}
self.lock.unlock()

let promise = channel.eventLoop.makePromise(of: Response.self)
let wrapper = HTTPRequestWrapper(request: request, promise: promise)
return channel.writeAndFlush(wrapper).flatMap {
promise.futureResult
}
case .disconnected:
return self.connect().flatMap {
self.lock.unlock()
return self.execute(request)
promise.futureResult.whenComplete { _ in
precondition(self.executing.compareAndExchange(expected: true, desired: false), "invalid execution state")
}
default:
preconditionFailure("invalid state \(self.state)")
let wrapper = HTTPRequestWrapper(request: request, promise: promise)
channel.writeAndFlush(wrapper).cascadeFailure(to: promise)
return promise.futureResult
}
}

private func connect() -> EventLoopFuture<Void> {
guard case .disconnected = self.state else {
preconditionFailure("invalid state \(self.state)")
}
self.state = .connecting
let bootstrap = ClientBootstrap(group: eventLoop)
private func connect() -> EventLoopFuture<Channel> {
let bootstrap = ClientBootstrap(group: self.eventLoop)
.channelInitializer { channel in
channel.pipeline.addHTTPClientHandlers().flatMap {
channel.pipeline.addHandlers([HTTPHandler(keepAlive: self.configuration.keepAlive),
Expand All @@ -88,9 +86,7 @@ internal class HTTPClient {
do {
// connect directly via socket address to avoid happy eyeballs (perf)
let address = try SocketAddress(ipAddress: self.configuration.ip, port: self.configuration.port)
return bootstrap.connect(to: address).flatMapThrowing { channel in
self.state = .connected(channel)
}
return bootstrap.connect(to: address)
} catch {
return self.eventLoop.makeFailedFuture(error)
}
Expand Down Expand Up @@ -126,13 +122,12 @@ internal class HTTPClient {
}

private enum State {
case connecting
case connected(Channel)
case disconnected
case connected(Channel)
}
}

private class HTTPHandler: ChannelDuplexHandler {
private final class HTTPHandler: ChannelDuplexHandler {
typealias OutboundIn = HTTPClient.Request
typealias InboundOut = HTTPClient.Response
typealias InboundIn = HTTPClientResponsePart
Expand Down Expand Up @@ -207,63 +202,74 @@ private class HTTPHandler: ChannelDuplexHandler {
}
}

private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
// no need in locks since we validate only one request can run at a time
private final class UnaryHandler: ChannelDuplexHandler {
typealias OutboundIn = HTTPRequestWrapper
typealias InboundIn = HTTPClient.Response
typealias OutboundOut = HTTPClient.Request

private let keepAlive: Bool

private let lock = Lock()
private var pendingResponses = CircularBuffer<(EventLoopPromise<HTTPClient.Response>, Scheduled<Void>?)>()
private var pending: (promise: EventLoopPromise<HTTPClient.Response>, timeout: Scheduled<Void>?)?
private var lastError: Error?

init(keepAlive: Bool) {
self.keepAlive = keepAlive
}

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
guard self.pending == nil else {
preconditionFailure("invalid state, outstanding request")
}
let wrapper = unwrapOutboundIn(data)
let timeoutTask = wrapper.request.timeout.map {
context.eventLoop.scheduleTask(in: $0) {
if (self.lock.withLock { !self.pendingResponses.isEmpty }) {
self.errorCaught(context: context, error: HTTPClient.Errors.timeout)
if self.pending != nil {
context.pipeline.fireErrorCaught(HTTPClient.Errors.timeout)
}
}
}
self.lock.withLockVoid { pendingResponses.append((wrapper.promise, timeoutTask)) }
self.pending = (promise: wrapper.promise, timeout: timeoutTask)
context.writeAndFlush(wrapOutboundOut(wrapper.request), promise: promise)
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = unwrapInboundIn(data)
if let pending = (self.lock.withLock { self.pendingResponses.popFirst() }) {
let serverKeepAlive = response.headers["connection"].first?.lowercased() == "keep-alive"
let future = self.keepAlive && serverKeepAlive ? context.eventLoop.makeSucceededFuture(()) : context.channel.close()
future.whenComplete { _ in
pending.1?.cancel()
pending.0.succeed(response)
guard let pending = self.pending else {
preconditionFailure("invalid state, no pending request")
}
let serverKeepAlive = response.headers.first(name: "connection")?.lowercased() == "keep-alive"
if !self.keepAlive || !serverKeepAlive {
pending.promise.futureResult.whenComplete { _ in
_ = context.channel.close()
}
}
self.completeWith(.success(response))
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
// pending responses will fail with lastError in channelInactive since we are calling context.close
self.lock.withLockVoid { self.lastError = error }
self.lastError = error
context.channel.close(promise: nil)
}

func channelInactive(context: ChannelHandlerContext) {
// fail any pending responses with last error or assume peer disconnected
self.failPendingResponses(self.lock.withLock { self.lastError } ?? HTTPClient.Errors.connectionResetByPeer)
if self.pending != nil {
let error = self.lastError ?? HTTPClient.Errors.connectionResetByPeer
self.completeWith(.failure(error))
}
context.fireChannelInactive()
}

private func failPendingResponses(_ error: Error) {
while let pending = (self.lock.withLock { pendingResponses.popFirst() }) {
pending.1?.cancel()
pending.0.fail(error)
private func completeWith(_ result: Result<HTTPClient.Response, Error>) {
guard let pending = self.pending else {
preconditionFailure("invalid state, no pending request")
}
self.pending = nil
self.lastError = nil
pending.timeout?.cancel()
pending.promise.completeWith(result)
}
}

Expand Down
6 changes: 3 additions & 3 deletions Sources/SwiftAwsLambda/Lambda.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public enum Lambda {
}
set {
self.stateLock.withLockVoid {
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(_state)")
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(self._state)")
self._state = newValue
}
}
Expand All @@ -124,12 +124,12 @@ public enum Lambda {
}

func stop() {
self.logger.info("lambda lifecycle stopping")
self.logger.debug("lambda lifecycle stopping")
self.state = .stopping
}

func shutdown() {
self.logger.info("lambda lifecycle shutdown")
self.logger.debug("lambda lifecycle shutdown")
self.state = .shutdown
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/SwiftAwsLambda/LambdaRunner.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ internal struct LambdaRunner {
///
/// - Returns: An `EventLoopFuture<Void>` fulfilled with the outcome of the initialization.
func initialize(logger: Logger) -> EventLoopFuture<Void> {
logger.info("initializing lambda")
logger.debug("initializing lambda")
// We need to use `flatMap` instead of `whenFailure` to ensure we complete reporting the result before stopping.
return self.lambdaHandler.initialize(eventLoop: self.eventLoop,
lifecycleId: self.lifecycleId,
Expand Down Expand Up @@ -69,7 +69,7 @@ internal struct LambdaRunner {
}
}.always { result in
// we are done!
logger.log(level: result.successful ? .info : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
logger.log(level: result.successful ? .debug : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/SwiftAwsLambda/LambdaRuntimeClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ internal struct JsonCodecError: Error, Equatable {
}

static func == (lhs: JsonCodecError, rhs: JsonCodecError) -> Bool {
return lhs.cause.localizedDescription == rhs.cause.localizedDescription
return String(describing: lhs.cause) == String(describing: rhs.cause)
}
}

Expand Down
2 changes: 1 addition & 1 deletion Tests/SwiftAwsLambdaTests/Lambda+CodeableTest+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import XCTest
extension CodableLambdaTest {
static var allTests: [(String, (CodableLambdaTest) -> () throws -> Void)] {
return [
("testSuceess", testSuceess),
("testSuccess", testSuccess),
("testFailure", testFailure),
("testClosureSuccess", testClosureSuccess),
("testClosureFailure", testClosureFailure),
Expand Down
32 changes: 20 additions & 12 deletions Tests/SwiftAwsLambdaTests/Lambda+CodeableTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,47 @@
import XCTest

class CodableLambdaTest: XCTestCase {
func testSuceess() throws {
func testSuccess() {
let server = MockLambdaServer(behavior: GoodBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }

let maxTimes = Int.random(in: 1 ... 10)
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
let result = Lambda.run(handler: CodableEchoHandler(), configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}

func testFailure() throws {
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
func testFailure() {
let server = MockLambdaServer(behavior: BadBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }

let result = Lambda.run(handler: CodableEchoHandler())
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
}

func testClosureSuccess() throws {
func testClosureSuccess() {
let server = MockLambdaServer(behavior: GoodBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }

let maxTimes = Int.random(in: 1 ... 10)
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
let result = Lambda.run(configuration: configuration) { (_, payload: Request, callback) in
callback(.success(Response(requestId: payload.requestId)))
}
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}

func testClosureFailure() throws {
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
func testClosureFailure() {
let server = MockLambdaServer(behavior: BadBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }

let result: LambdaLifecycleResult = Lambda.run { (_, payload: Request, callback) in
callback(.success(Response(requestId: payload.requestId)))
}
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
}
}
Expand Down
2 changes: 1 addition & 1 deletion Tests/SwiftAwsLambdaTests/Lambda+StringTest+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import XCTest
extension StringLambdaTest {
static var allTests: [(String, (StringLambdaTest) -> () throws -> Void)] {
return [
("testSuceess", testSuceess),
("testSuccess", testSuccess),
("testFailure", testFailure),
("testClosureSuccess", testClosureSuccess),
("testClosureFailure", testClosureFailure),
Expand Down
Loading