Skip to content

Commit 95884a5

Browse files
committed
another approach
1 parent 84fed4b commit 95884a5

File tree

3 files changed

+29
-92
lines changed

3 files changed

+29
-92
lines changed

Sources/SwiftAwsLambda/HttpClient.swift

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ import NIOConcurrencyHelpers
1717
import NIOHTTP1
1818

1919
/// A barebone HTTP client to interact with AWS Runtime Engine which is an HTTP server.
20+
/// Note that Lambda Runtime API dictate that only one requests runs at a time.
21+
/// This means we can avoid locks and other concurrency concern we would otherwise need to build into the client
2022
internal final class HTTPClient {
2123
private let eventLoop: EventLoop
2224
private let configuration: Lambda.Configuration.RuntimeEngine
2325
private let targetHost: String
2426

2527
private var state = State.disconnected
26-
private let stateLock = Lock()
28+
private let executing = NIOAtomic.makeAtomic(value: false)
2729

2830
init(eventLoop: EventLoop, configuration: Lambda.Configuration.RuntimeEngine) {
2931
self.eventLoop = eventLoop
@@ -46,38 +48,26 @@ internal final class HTTPClient {
4648
timeout: timeout ?? self.configuration.requestTimeout))
4749
}
4850

49-
private func execute(_ request: Request) -> EventLoopFuture<Response> {
50-
self.stateLock.lock()
51+
// TODO: cap reconnect attempt
52+
private func execute(_ request: Request, validate: Bool = true) -> EventLoopFuture<Response> {
53+
precondition(!validate || self.executing.compareAndExchange(expected: false, desired: true), "expecting single request at a time")
54+
5155
switch self.state {
5256
case .disconnected:
53-
let promise = self.eventLoop.makePromise(of: Response.self)
54-
self.state = .connecting(promise.futureResult)
55-
self.stateLock.unlock()
56-
self.connect().flatMap { channel -> EventLoopFuture<Response> in
57-
self.stateLock.withLock {
58-
guard case .connecting = self.state else {
59-
preconditionFailure("invalid state \(self.state)")
60-
}
61-
self.state = .connected(channel)
62-
}
63-
return self.execute(request)
64-
}.cascade(to: promise)
65-
return promise.futureResult
66-
case .connecting(let future):
67-
let future = future.flatMap { _ in
68-
self.execute(request)
57+
return self.connect().flatMap { channel -> EventLoopFuture<Response> in
58+
self.state = .connected(channel)
59+
return self.execute(request, validate: false)
6960
}
70-
self.state = .connecting(future)
71-
self.stateLock.unlock()
72-
return future
7361
case .connected(let channel):
7462
guard channel.isActive else {
7563
self.state = .disconnected
76-
self.stateLock.unlock()
77-
return self.execute(request)
64+
return self.execute(request, validate: false)
7865
}
79-
self.stateLock.unlock()
66+
8067
let promise = channel.eventLoop.makePromise(of: Response.self)
68+
promise.futureResult.whenComplete { _ in
69+
precondition(self.executing.compareAndExchange(expected: true, desired: false), "invalid execution state")
70+
}
8171
let wrapper = HTTPRequestWrapper(request: request, promise: promise)
8272
channel.writeAndFlush(wrapper).cascadeFailure(to: promise)
8373
return promise.futureResult
@@ -133,7 +123,6 @@ internal final class HTTPClient {
133123

134124
private enum State {
135125
case disconnected
136-
case connecting(EventLoopFuture<Response>)
137126
case connected(Channel)
138127
}
139128
}
@@ -213,15 +202,15 @@ private final class HTTPHandler: ChannelDuplexHandler {
213202
}
214203
}
215204

216-
private final class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
205+
// no need in locks since we validate only one request can run at a time
206+
private final class UnaryHandler: ChannelDuplexHandler {
217207
typealias OutboundIn = HTTPRequestWrapper
218208
typealias InboundIn = HTTPClient.Response
219209
typealias OutboundOut = HTTPClient.Request
220210

221211
private let keepAlive: Bool
222212

223-
private let lock = Lock()
224-
private var pendingResponses = CircularBuffer<(promise: EventLoopPromise<HTTPClient.Response>, timeout: Scheduled<Void>?)>()
213+
private var pending: (promise: EventLoopPromise<HTTPClient.Response>, timeout: Scheduled<Void>?)?
225214
private var lastError: Error?
226215

227216
init(keepAlive: Bool) {
@@ -232,19 +221,20 @@ private final class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler
232221
let wrapper = unwrapOutboundIn(data)
233222
let timeoutTask = wrapper.request.timeout.map {
234223
context.eventLoop.scheduleTask(in: $0) {
235-
if (self.lock.withLock { !self.pendingResponses.isEmpty }) {
236-
self.errorCaught(context: context, error: HTTPClient.Errors.timeout)
224+
if self.pending != nil {
225+
// TODO: need to verify this is thread safe i.e tha the timeout event wont interleave with the normal hander events
226+
context.pipeline.fireErrorCaught(HTTPClient.Errors.timeout)
237227
}
238228
}
239229
}
240-
self.lock.withLockVoid { pendingResponses.append((promise: wrapper.promise, timeout: timeoutTask)) }
230+
self.pending = (promise: wrapper.promise, timeout: timeoutTask)
241231
context.writeAndFlush(wrapOutboundOut(wrapper.request), promise: promise)
242232
}
243233

244234
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
245235
let response = unwrapInboundIn(data)
246-
if let pending = (self.lock.withLock { self.pendingResponses.popFirst() }) {
247-
let serverKeepAlive = response.headers["connection"].first?.lowercased() == "keep-alive"
236+
if let pending = self.pending {
237+
let serverKeepAlive = response.headers.first(name: "connection")?.lowercased() == "keep-alive"
248238
if !self.keepAlive || !serverKeepAlive {
249239
pending.promise.futureResult.whenComplete { _ in
250240
_ = context.channel.close()
@@ -257,20 +247,20 @@ private final class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler
257247

258248
func errorCaught(context: ChannelHandlerContext, error: Error) {
259249
// pending responses will fail with lastError in channelInactive since we are calling context.close
260-
self.lock.withLockVoid { self.lastError = error }
250+
self.lastError = error
261251
context.channel.close(promise: nil)
262252
}
263253

264254
func channelInactive(context: ChannelHandlerContext) {
265255
// fail any pending responses with last error or assume peer disconnected
266-
self.failPendingResponses(self.lock.withLock { self.lastError } ?? HTTPClient.Errors.connectionResetByPeer)
256+
self.failPendingResponses(self.lastError ?? HTTPClient.Errors.connectionResetByPeer)
267257
context.fireChannelInactive()
268258
}
269259

270260
private func failPendingResponses(_ error: Error) {
271-
while let pending = (self.lock.withLock { pendingResponses.popFirst() }) {
272-
pending.1?.cancel()
273-
pending.0.fail(error)
261+
if let pending = self.pending {
262+
pending.timeout?.cancel()
263+
pending.promise.fail(error)
274264
}
275265
}
276266
}

Tests/SwiftAwsLambdaTests/LambdaRunnerTest+XCTest.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ extension LambdaRunnerTest {
2727
return [
2828
("testSuccess", testSuccess),
2929
("testFailure", testFailure),
30-
("testConcurrency", testConcurrency),
3130
]
3231
}
3332
}

Tests/SwiftAwsLambdaTests/LambdaRunnerTest.swift

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -71,56 +71,4 @@ class LambdaRunnerTest: XCTestCase {
7171
}
7272
XCTAssertNoThrow(try runLambda(behavior: Behavior(), handler: FailedHandler(Behavior.error)))
7373
}
74-
75-
func testConcurrency() throws {
76-
struct Behavior: LambdaServerBehavior {
77-
let requestId = UUID().uuidString
78-
let payload = "hello"
79-
func getWork() -> GetWorkResult {
80-
return .success((self.requestId, self.payload))
81-
}
82-
83-
func processResponse(requestId: String, response: String) -> ProcessResponseResult {
84-
XCTAssertEqual(self.requestId, requestId, "expecting requestId to match")
85-
XCTAssertEqual(self.payload, response, "expecting response to match")
86-
return .success
87-
}
88-
89-
func processError(requestId: String, error: ErrorResponse) -> ProcessErrorResult {
90-
XCTFail("should not report error")
91-
return .failure(.internalServerError)
92-
}
93-
94-
func processInitError(error: ErrorResponse) -> ProcessInitErrorResult {
95-
XCTFail("should not report init error")
96-
return .failure(.internalServerError)
97-
}
98-
}
99-
100-
let server = try MockLambdaServer(behavior: Behavior()).start().wait()
101-
defer { XCTAssertNoThrow(try server.stop().wait()) }
102-
103-
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
104-
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
105-
let logger = Logger(label: "TestLogger")
106-
let configuration = Lambda.Configuration(runtimeEngine: .init(requestTimeout: .seconds(1)))
107-
let runner = LambdaRunner(eventLoop: eventLoopGroup.next(), configuration: configuration, lambdaHandler: EchoHandler())
108-
XCTAssertNoThrow(try runner.initialize(logger: logger).wait())
109-
110-
let total = 50
111-
let group = DispatchGroup()
112-
for _ in 0 ..< total {
113-
group.enter()
114-
DispatchQueue.global().async {
115-
runner.run(logger: logger).whenComplete { result in
116-
if case .failure(let error) = result {
117-
XCTFail("should not fail, but failed with \(error)")
118-
}
119-
group.leave()
120-
}
121-
}
122-
}
123-
124-
group.wait()
125-
}
12674
}

0 commit comments

Comments
 (0)