@@ -17,13 +17,15 @@ import NIOConcurrencyHelpers
17
17
import NIOHTTP1
18
18
19
19
/// A barebone HTTP client to interact with AWS Runtime Engine which is an HTTP server.
20
- internal class HTTPClient {
20
+ /// Note that Lambda Runtime API dictate that only one requests runs at a time.
21
+ /// This means we can avoid locks and other concurrency concern we would otherwise need to build into the client
22
+ internal final class HTTPClient {
21
23
private let eventLoop : EventLoop
22
24
private let configuration : Lambda . Configuration . RuntimeEngine
23
25
private let targetHost : String
24
26
25
27
private var state = State . disconnected
26
- private let lock = Lock ( )
28
+ private let executing = NIOAtomic . makeAtomic ( value : false )
27
29
28
30
init ( eventLoop: EventLoop , configuration: Lambda . Configuration . RuntimeEngine ) {
29
31
self . eventLoop = eventLoop
@@ -46,38 +48,34 @@ internal class HTTPClient {
46
48
timeout: timeout ?? self . configuration. requestTimeout) )
47
49
}
48
50
49
- private func execute( _ request: Request ) -> EventLoopFuture < Response > {
50
- self . lock. lock ( )
51
+ // TODO: cap reconnect attempt
52
+ private func execute( _ request: Request , validate: Bool = true ) -> EventLoopFuture < Response > {
53
+ precondition ( !validate || self . executing. compareAndExchange ( expected: false , desired: true ) , " expecting single request at a time " )
54
+
51
55
switch self . state {
56
+ case . disconnected:
57
+ return self . connect ( ) . flatMap { channel -> EventLoopFuture < Response > in
58
+ self . state = . connected( channel)
59
+ return self . execute ( request, validate: false )
60
+ }
52
61
case . connected( let channel) :
53
62
guard channel. isActive else {
54
- // attempt to reconnect
55
63
self . state = . disconnected
56
- self . lock. unlock ( )
57
- return self . execute ( request)
64
+ return self . execute ( request, validate: false )
58
65
}
59
- self . lock . unlock ( )
66
+
60
67
let promise = channel. eventLoop. makePromise ( of: Response . self)
61
- let wrapper = HTTPRequestWrapper ( request: request, promise: promise)
62
- return channel. writeAndFlush ( wrapper) . flatMap {
63
- promise. futureResult
64
- }
65
- case . disconnected:
66
- return self . connect ( ) . flatMap {
67
- self . lock. unlock ( )
68
- return self . execute ( request)
68
+ promise. futureResult. whenComplete { _ in
69
+ precondition ( self . executing. compareAndExchange ( expected: true , desired: false ) , " invalid execution state " )
69
70
}
70
- default :
71
- preconditionFailure ( " invalid state \( self . state) " )
71
+ let wrapper = HTTPRequestWrapper ( request: request, promise: promise)
72
+ channel. writeAndFlush ( wrapper) . cascadeFailure ( to: promise)
73
+ return promise. futureResult
72
74
}
73
75
}
74
76
75
- private func connect( ) -> EventLoopFuture < Void > {
76
- guard case . disconnected = self . state else {
77
- preconditionFailure ( " invalid state \( self . state) " )
78
- }
79
- self . state = . connecting
80
- let bootstrap = ClientBootstrap ( group: eventLoop)
77
+ private func connect( ) -> EventLoopFuture < Channel > {
78
+ let bootstrap = ClientBootstrap ( group: self . eventLoop)
81
79
. channelInitializer { channel in
82
80
channel. pipeline. addHTTPClientHandlers ( ) . flatMap {
83
81
channel. pipeline. addHandlers ( [ HTTPHandler ( keepAlive: self . configuration. keepAlive) ,
@@ -88,9 +86,7 @@ internal class HTTPClient {
88
86
do {
89
87
// connect directly via socket address to avoid happy eyeballs (perf)
90
88
let address = try SocketAddress ( ipAddress: self . configuration. ip, port: self . configuration. port)
91
- return bootstrap. connect ( to: address) . flatMapThrowing { channel in
92
- self . state = . connected( channel)
93
- }
89
+ return bootstrap. connect ( to: address)
94
90
} catch {
95
91
return self . eventLoop. makeFailedFuture ( error)
96
92
}
@@ -126,13 +122,12 @@ internal class HTTPClient {
126
122
}
127
123
128
124
private enum State {
129
- case connecting
130
- case connected( Channel )
131
125
case disconnected
126
+ case connected( Channel )
132
127
}
133
128
}
134
129
135
- private class HTTPHandler : ChannelDuplexHandler {
130
+ private final class HTTPHandler : ChannelDuplexHandler {
136
131
typealias OutboundIn = HTTPClient . Request
137
132
typealias InboundOut = HTTPClient . Response
138
133
typealias InboundIn = HTTPClientResponsePart
@@ -207,63 +202,74 @@ private class HTTPHandler: ChannelDuplexHandler {
207
202
}
208
203
}
209
204
210
- private class UnaryHandler : ChannelInboundHandler , ChannelOutboundHandler {
205
+ // no need in locks since we validate only one request can run at a time
206
+ private final class UnaryHandler : ChannelDuplexHandler {
211
207
typealias OutboundIn = HTTPRequestWrapper
212
208
typealias InboundIn = HTTPClient . Response
213
209
typealias OutboundOut = HTTPClient . Request
214
210
215
211
private let keepAlive : Bool
216
212
217
- private let lock = Lock ( )
218
- private var pendingResponses = CircularBuffer < ( EventLoopPromise < HTTPClient . Response > , Scheduled < Void > ? ) > ( )
213
+ private var pending : ( promise: EventLoopPromise < HTTPClient . Response > , timeout: Scheduled < Void > ? ) ?
219
214
private var lastError : Error ?
220
215
221
216
init ( keepAlive: Bool ) {
222
217
self . keepAlive = keepAlive
223
218
}
224
219
225
220
func write( context: ChannelHandlerContext , data: NIOAny , promise: EventLoopPromise < Void > ? ) {
221
+ guard self . pending == nil else {
222
+ preconditionFailure ( " invalid state, outstanding request " )
223
+ }
226
224
let wrapper = unwrapOutboundIn ( data)
227
225
let timeoutTask = wrapper. request. timeout. map {
228
226
context. eventLoop. scheduleTask ( in: $0) {
229
- if ( self . lock . withLock { ! self . pendingResponses . isEmpty } ) {
230
- self . errorCaught ( context : context , error : HTTPClient . Errors. timeout)
227
+ if self . pending != nil {
228
+ context . pipeline . fireErrorCaught ( HTTPClient . Errors. timeout)
231
229
}
232
230
}
233
231
}
234
- self . lock . withLockVoid { pendingResponses . append ( ( wrapper. promise, timeoutTask) ) }
232
+ self . pending = ( promise : wrapper. promise, timeout : timeoutTask)
235
233
context. writeAndFlush ( wrapOutboundOut ( wrapper. request) , promise: promise)
236
234
}
237
235
238
236
func channelRead( context: ChannelHandlerContext , data: NIOAny ) {
239
237
let response = unwrapInboundIn ( data)
240
- if let pending = ( self . lock. withLock { self . pendingResponses. popFirst ( ) } ) {
241
- let serverKeepAlive = response. headers [ " connection " ] . first? . lowercased ( ) == " keep-alive "
242
- let future = self . keepAlive && serverKeepAlive ? context. eventLoop. makeSucceededFuture ( ( ) ) : context. channel. close ( )
243
- future. whenComplete { _ in
244
- pending. 1 ? . cancel ( )
245
- pending. 0 . succeed ( response)
238
+ guard let pending = self . pending else {
239
+ preconditionFailure ( " invalid state, no pending request " )
240
+ }
241
+ let serverKeepAlive = response. headers. first ( name: " connection " ) ? . lowercased ( ) == " keep-alive "
242
+ if !self . keepAlive || !serverKeepAlive {
243
+ pending. promise. futureResult. whenComplete { _ in
244
+ _ = context. channel. close ( )
246
245
}
247
246
}
247
+ self . completeWith ( . success( response) )
248
248
}
249
249
250
250
func errorCaught( context: ChannelHandlerContext , error: Error ) {
251
251
// pending responses will fail with lastError in channelInactive since we are calling context.close
252
- self . lock . withLockVoid { self . lastError = error }
252
+ self . lastError = error
253
253
context. channel. close ( promise: nil )
254
254
}
255
255
256
256
func channelInactive( context: ChannelHandlerContext ) {
257
257
// fail any pending responses with last error or assume peer disconnected
258
- self . failPendingResponses ( self . lock. withLock { self . lastError } ?? HTTPClient . Errors. connectionResetByPeer)
258
+ if self . pending != nil {
259
+ let error = self . lastError ?? HTTPClient . Errors. connectionResetByPeer
260
+ self . completeWith ( . failure( error) )
261
+ }
259
262
context. fireChannelInactive ( )
260
263
}
261
264
262
- private func failPendingResponses( _ error: Error ) {
263
- while let pending = ( self . lock. withLock { pendingResponses. popFirst ( ) } ) {
264
- pending. 1 ? . cancel ( )
265
- pending. 0 . fail ( error)
265
+ private func completeWith( _ result: Result < HTTPClient . Response , Error > ) {
266
+ guard let pending = self . pending else {
267
+ preconditionFailure ( " invalid state, no pending request " )
266
268
}
269
+ self . pending = nil
270
+ self . lastError = nil
271
+ pending. timeout? . cancel ( )
272
+ pending. promise. completeWith ( result)
267
273
}
268
274
}
269
275
0 commit comments