Skip to content

Add a connection pool #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
653 changes: 653 additions & 0 deletions Sources/AsyncHTTPClient/ConnectionPool.swift

Large diffs are not rendered by default.

325 changes: 222 additions & 103 deletions Sources/AsyncHTTPClient/HTTPClient.swift

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan
case awaitingResponse
case connecting
case connected
case failed
}

private let host: String
Expand Down Expand Up @@ -102,6 +103,7 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan
// blank line that concludes the successful response's header section
break
case 407:
self.readState = .failed
context.fireErrorCaught(HTTPClientError.proxyAuthenticationRequired)
default:
// Any response other than a successful response
Expand All @@ -119,6 +121,8 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan
self.readBuffer.append(data)
case .connected:
context.fireChannelRead(data)
case .failed:
break
}
}

Expand Down
125 changes: 98 additions & 27 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import NIO
import NIOConcurrencyHelpers
import NIOFoundationCompat
import NIOHTTP1
import NIOHTTPCompression
import NIOSSL

extension HTTPClient {
Expand Down Expand Up @@ -486,22 +487,31 @@ extension URL {
extension HTTPClient {
/// Response execution context. Will be created by the library and could be used for obtaining
/// `EventLoopFuture<Response>` of the execution or cancellation of the execution.
public final class Task<Response> {
public final class Task<Response>: TaskProtocol {
/// The `EventLoop` the delegate will be executed on.
public let eventLoop: EventLoop

let promise: EventLoopPromise<Response>
var channel: Channel?
private var cancelled: Bool
private let lock: Lock
var completion: EventLoopFuture<Void>
var connection: ConnectionPool.Connection?
var cancelled: Bool
let lock: Lock
let id = UUID()

init(eventLoop: EventLoop) {
self.eventLoop = eventLoop
self.promise = eventLoop.makePromise()
self.completion = self.promise.futureResult.map { _ in }
self.cancelled = false
self.lock = Lock()
}

static func failedTask(eventLoop: EventLoop, error: Error) -> Task<Response> {
let task = self.init(eventLoop: eventLoop)
task.promise.fail(error)
return task
}

/// `EventLoopFuture` for the response returned by this request.
public var futureResult: EventLoopFuture<Response> {
return self.promise.futureResult
Expand All @@ -520,28 +530,74 @@ extension HTTPClient {
let channel: Channel? = self.lock.withLock {
if !cancelled {
cancelled = true
return self.channel
return self.connection?.channel
} else {
return nil
}
return nil
}
channel?.triggerUserOutboundEvent(TaskCancelEvent(), promise: nil)
}

@discardableResult
func setChannel(_ channel: Channel) -> Channel {
func setConnection(_ connection: ConnectionPool.Connection) -> ConnectionPool.Connection {
return self.lock.withLock {
self.channel = channel
return channel
self.connection = connection
if self.cancelled {
connection.channel.triggerUserOutboundEvent(TaskCancelEvent(), promise: nil)
}
return connection
}
}

func succeed<Delegate: HTTPClientResponseDelegate>(promise: EventLoopPromise<Response>?, with value: Response, delegateType: Delegate.Type) {
self.releaseAssociatedConnection(delegateType: delegateType).whenSuccess {
promise?.succeed(value)
}
}

func fail<Delegate: HTTPClientResponseDelegate>(with error: Error, delegateType: Delegate.Type) {
if let connection = self.connection {
connection.close().whenComplete { _ in
self.releaseAssociatedConnection(delegateType: delegateType).whenComplete { _ in
self.promise.fail(error)
}
}
}
}

func releaseAssociatedConnection<Delegate: HTTPClientResponseDelegate>(delegateType: Delegate.Type) -> EventLoopFuture<Void> {
if let connection = self.connection {
return connection.removeHandler(NIOHTTPResponseDecompressor.self).flatMap {
connection.removeHandler(IdleStateHandler.self)
}.flatMap {
connection.removeHandler(TaskHandler<Delegate>.self)
}.map {
connection.release()
}.flatMapError { error in
fatalError("Couldn't remove taskHandler: \(error)")
}

} else {
// TODO: This seems only reached in some internal unit test
// Maybe there could be a better handling in the future to make
// it an error outside of testing contexts
return self.eventLoop.makeSucceededFuture(())
}
}
}
}

internal struct TaskCancelEvent {}

internal protocol TaskProtocol {
func cancel()
var id: UUID { get }
var completion: EventLoopFuture<Void> { get }
}

// MARK: - TaskHandler

internal class TaskHandler<Delegate: HTTPClientResponseDelegate> {
internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChannelHandler {
enum State {
case idle
case sent
Expand Down Expand Up @@ -581,7 +637,7 @@ extension TaskHandler {
_ body: @escaping (HTTPClient.Task<Delegate.Response>, Err) -> Void) {
func doIt() {
body(self.task, error)
self.task.promise.fail(error)
self.task.fail(with: error, delegateType: Delegate.self)
}

if self.task.eventLoop.inEventLoop {
Expand Down Expand Up @@ -621,13 +677,14 @@ extension TaskHandler {
}

func callOutToDelegate<Response>(promise: EventLoopPromise<Response>? = nil,
_ body: @escaping (HTTPClient.Task<Delegate.Response>) throws -> Response) {
_ body: @escaping (HTTPClient.Task<Delegate.Response>) throws -> Response) where Response == Delegate.Response {
func doIt() {
do {
let result = try body(self.task)
promise?.succeed(result)

self.task.succeed(promise: promise, with: result, delegateType: Delegate.self)
} catch {
promise?.fail(error)
self.task.fail(with: error, delegateType: Delegate.self)
}
}

Expand All @@ -641,7 +698,7 @@ extension TaskHandler {
}

func callOutToDelegate<Response>(channelEventLoop: EventLoop,
_ body: @escaping (HTTPClient.Task<Delegate.Response>) throws -> Response) -> EventLoopFuture<Response> {
_ body: @escaping (HTTPClient.Task<Delegate.Response>) throws -> Response) -> EventLoopFuture<Response> where Response == Delegate.Response {
let promise = channelEventLoop.makePromise(of: Response.self)
self.callOutToDelegate(promise: promise, body)
return promise.futureResult
Expand Down Expand Up @@ -678,8 +735,6 @@ extension TaskHandler: ChannelDuplexHandler {
headers.add(name: "Host", value: request.host)
}

headers.add(name: "Connection", value: "close")

do {
try headers.validate(body: request.body)
} catch {
Expand All @@ -702,16 +757,10 @@ extension TaskHandler: ChannelDuplexHandler {
context.eventLoop.assertInEventLoop()
self.state = .sent
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)

let channel = context.channel
self.task.futureResult.whenComplete { _ in
channel.close(promise: nil)
}
}.flatMapErrorThrowing { error in
context.eventLoop.assertInEventLoop()
self.state = .end
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
context.close(promise: nil)
throw error
}.cascade(to: promise)
}
Expand Down Expand Up @@ -742,6 +791,16 @@ extension TaskHandler: ChannelDuplexHandler {
let response = self.unwrapInboundIn(data)
switch response {
case .head(let head):
if !head.isKeepAlive {
self.task.lock.withLock {
if let connection = self.task.connection {
connection.isClosing = true
} else {
preconditionFailure("There should always be a connection at this point")
}
}
}

if let redirectURL = redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
self.state = .redirected(head, redirectURL)
} else {
Expand All @@ -768,8 +827,9 @@ extension TaskHandler: ChannelDuplexHandler {
switch self.state {
case .redirected(let head, let redirectURL):
self.state = .end
self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise)
context.close(promise: nil)
self.task.releaseAssociatedConnection(delegateType: Delegate.self).whenSuccess {
self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise)
}
default:
self.state = .end
self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest)
Expand Down Expand Up @@ -845,6 +905,13 @@ extension TaskHandler: ChannelDuplexHandler {
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
}

func handlerAdded(context: ChannelHandlerContext) {
guard context.channel.isActive else {
self.failTaskAndNotifyDelegate(error: HTTPClientError.remoteConnectionClosed, self.delegate.didReceiveError)
return
}
}
}

// MARK: - RedirectHandler
Expand Down Expand Up @@ -931,9 +998,13 @@ internal struct RedirectHandler<ResponseType> {
do {
var newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body)
newRequest.redirectState = nextState
return self.execute(newRequest).futureResult.cascade(to: promise)
self.execute(newRequest).futureResult.whenComplete { result in
promise.futureResult.eventLoop.execute {
promise.completeWith(result)
}
}
} catch {
return promise.fail(error)
promise.fail(error)
}
}
}
51 changes: 51 additions & 0 deletions Sources/AsyncHTTPClient/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import NIO
import NIOHTTP1
import NIOHTTPCompression

internal extension String {
var isIPAddress: Bool {
Expand Down Expand Up @@ -44,3 +45,53 @@ public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate {
return ()
}
}

extension ClientBootstrap {
static func makeHTTPClientBootstrapBase(group: EventLoopGroup, host: String, port: Int, configuration: HTTPClient.Configuration, channelInitializer: ((Channel) -> EventLoopFuture<Void>)? = nil) -> ClientBootstrap {
return ClientBootstrap(group: group)
.channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)

.channelInitializer { channel in
let channelAddedFuture: EventLoopFuture<Void>
switch configuration.proxy {
case .none:
channelAddedFuture = group.next().makeSucceededFuture(())
case .some:
channelAddedFuture = channel.pipeline.addProxyHandler(host: host, port: port, authorization: configuration.proxy?.authorization)
}
return channelAddedFuture.flatMap { (_: Void) -> EventLoopFuture<Void> in
channelInitializer?(channel) ?? group.next().makeSucceededFuture(())
}
}
}
}

extension CircularBuffer {
@discardableResult
mutating func swapWithFirstAndRemove(at index: Index) -> Element? {
precondition(index >= self.startIndex && index < self.endIndex)
if !self.isEmpty {
self.swapAt(self.startIndex, index)
return self.removeFirst()
} else {
return nil
}
}

@discardableResult
mutating func swapWithFirstAndRemove(where predicate: (Element) throws -> Bool) rethrows -> Element? {
if let existingIndex = try self.firstIndex(where: predicate) {
return self.swapWithFirstAndRemove(at: existingIndex)
} else {
return nil
}
}
}

extension ConnectionPool.Connection {
func removeHandler<Handler: RemovableChannelHandler>(_ type: Handler.Type) -> EventLoopFuture<Void> {
return self.channel.pipeline.handler(type: type).flatMap { handler in
self.channel.pipeline.removeHandler(handler)
}.recover { _ in }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ extension HTTPClientInternalTests {
("testUploadStreamingBackpressure", testUploadStreamingBackpressure),
("testRequestURITrailingSlash", testRequestURITrailingSlash),
("testChannelAndDelegateOnDifferentEventLoops", testChannelAndDelegateOnDifferentEventLoops),
("testResponseConnectionCloseGet", testResponseConnectionCloseGet),
]
}
}
Loading