diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift new file mode 100644 index 000000000..8e01efc52 --- /dev/null +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -0,0 +1,653 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2019-2020 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation +import NIO +import NIOConcurrencyHelpers +import NIOHTTP1 +import NIOTLS + +/// A connection pool that manages and creates new connections to hosts respecting the specified preferences +/// +/// - Note: All `internal` methods of this class are thread safe +final class ConnectionPool { + /// The configuration used to bootstrap new HTTP connections + private let configuration: HTTPClient.Configuration + + /// The main data structure used by the `ConnectionPool` to retreive and create connections associated + /// to a given `Key` . + /// - Warning: This property should be accessed with proper synchronization, see `connectionProvidersLock` + private var connectionProviders: [Key: HTTP1ConnectionProvider] = [:] + + /// The lock used by the connection pool used to ensure correct synchronization of accesses to `_connectionProviders` + /// + /// + /// - Warning: This lock should always be acquired *before* `HTTP1ConnectionProvider`s `stateLock` if used in combination with it. + private let connectionProvidersLock = Lock() + + init(configuration: HTTPClient.Configuration) { + self.configuration = configuration + } + + /// Gets the `EventLoop` associated with the given `Key` if it exists + /// + /// This is part of an optimization used by the `.execute(...)` method when + /// a request has its `EventLoopPreference` property set to `.indifferent`. + /// Having a default `EventLoop` shared by the *channel* and the *delegate* avoids + /// loss of performance due to `EventLoop` hopping + func associatedEventLoop(for key: Key) -> EventLoop? { + return self.connectionProvidersLock.withLock { + self.connectionProviders[key]?.eventLoop + } + } + + /// This method asks the pool for a connection usable by the specified `request`, respecting the specified options. + /// + /// - parameter request: The request that needs a `Connection` + /// - parameter preference: The `EventLoopPreference` the connection pool will respect to lease a new connection + /// - parameter deadline: The connection timeout + /// - Returns: A connection corresponding to the specified parameters + /// + /// When the pool is asked for a new connection, it creates a `Key` from the url associated to the `request`. This key + /// is used to determine if there already exists an associated `HTTP1ConnectionProvider` in `connectionProviders`. + /// If there is, the connection provider then takes care of leasing a new connection. If a connection provider doesn't exist, it is created. + func getConnection(for request: HTTPClient.Request, preference: HTTPClient.EventLoopPreference, on eventLoop: EventLoop, deadline: NIODeadline?) -> EventLoopFuture { + let key = Key(request) + + let provider: HTTP1ConnectionProvider = self.connectionProvidersLock.withLock { + if let existing = self.connectionProviders[key] { + existing.stateLock.withLock { + existing.state.pending += 1 + } + return existing + } else { + let http1Provider = HTTP1ConnectionProvider(key: key, eventLoop: eventLoop, configuration: self.configuration, parentPool: self) + self.connectionProviders[key] = http1Provider + http1Provider.stateLock.withLock { + http1Provider.state.pending += 1 + } + return http1Provider + } + } + + return provider.getConnection(preference: preference) + } + + func release(_ connection: Connection) { + let connectionProvider = self.connectionProvidersLock.withLock { + self.connectionProviders[connection.key] + } + if let connectionProvider = connectionProvider { + connectionProvider.release(connection: connection) + } + } + + func prepareForClose() { + let connectionProviders = self.connectionProvidersLock.withLock { self.connectionProviders.values } + for connectionProvider in connectionProviders { + connectionProvider.prepareForClose() + } + } + + func syncClose() { + let connectionProviders = self.connectionProvidersLock.withLock { self.connectionProviders.values } + for connectionProvider in connectionProviders { + connectionProvider.syncClose() + } + self.connectionProvidersLock.withLock { + assert(self.connectionProviders.count == 0, "left-overs: \(self.connectionProviders)") + } + } + + var connectionProviderCount: Int { + return self.connectionProvidersLock.withLock { + self.connectionProviders.count + } + } + + /// Used by the `ConnectionPool` to index its `HTTP1ConnectionProvider`s + /// + /// A key is initialized from a `URL`, it uses the components to derive a hashed value + /// used by the `connectionProviders` dictionary to allow retrieving and creating + /// connection providers associated to a certain request in constant time. + struct Key: Hashable { + init(_ request: HTTPClient.Request) { + switch request.scheme { + case "http": + self.scheme = .http + case "https": + self.scheme = .https + case "unix": + self.scheme = .unix + self.unixPath = request.url.baseURL?.path ?? request.url.path + default: + fatalError("HTTPClient.Request scheme should already be a valid one") + } + self.port = request.port + self.host = request.host + } + + var scheme: Scheme + var host: String + var port: Int + var unixPath: String = "" + + enum Scheme: Hashable { + case http + case https + case unix + } + } + + /// A `Connection` represents a `Channel` in the context of the connection pool + /// + /// In the `ConnectionPool`, each `Channel` belongs to a given `HTTP1ConnectionProvider` + /// and has a certain "lease state" (see the `isLeased` property). + /// The role of `Connection` is to model this by storing a `Channel` alongside its associated properties + /// so that they can be passed around together. + /// + /// - Warning: `Connection` properties are not thread-safe and should be used with proper synchronization + class Connection: CustomStringConvertible { + init(key: Key, channel: Channel, parentPool: ConnectionPool) { + self.key = key + self.channel = channel + self.parentPool = parentPool + self.closePromise = channel.eventLoop.makePromise(of: Void.self) + self.closeFuture = self.closePromise.futureResult + } + + /// Release this `Connection` to its associated `HTTP1ConnectionProvider` in the parent `ConnectionPool` + /// + /// This is exactly equivalent to calling `.release(theProvider)` on `ConnectionPool` + /// + /// - Warning: This only releases the connection and doesn't take care of cleaning handlers in the + /// `Channel` pipeline. + func release() { + self.parentPool.release(self) + } + + func close() -> EventLoopFuture { + self.channel.close(promise: nil) + return self.closeFuture + } + + var description: String { + return "Connection { channel: \(self.channel) }" + } + + /// The connection pool this `Connection` belongs to. + /// + /// This enables calling methods like `release()` directly on a `Connection` instead of + /// calling `pool.release(connection)`. This gives a more object oriented feel to the API + /// and can avoid having to keep explicit references to the pool at call site. + let parentPool: ConnectionPool + + /// The `Key` of the `HTTP1ConnectionProvider` this `Connection` belongs to + /// + /// This lets `ConnectionPool` know the relationship between `Connection`s and `HTTP1ConnectionProvider`s + fileprivate let key: Key + + /// The `Channel` of this `Connection` + /// + /// - Warning: Requests that lease connections from the `ConnectionPool` are responsible + /// for removing the specific handlers they added to the `Channel` pipeline before releasing it to the pool. + let channel: Channel + + /// Wether the connection is currently leased or not + var isLeased: Bool = false + + /// Indicates that this connection is about to close + var isClosing: Bool = false + + /// Indicates wether the usual close callback should be run or not, this allows customizing what happens + /// on close in some cases such as for the `.replaceConnection` action + /// + /// - Warning: This should be accessed under the `stateLock` of `HTTP1ConnectionProvider` + fileprivate var mustRunDefaultCloseCallback: Bool = true + + /// Convenience property indicating wether the underlying `Channel` is active or not + var isActiveEstimation: Bool { + return self.channel.isActive + } + + fileprivate var closePromise: EventLoopPromise + + var closeFuture: EventLoopFuture + } + + /// A connection provider of `HTTP/1.1` connections with a given `Key` (host, scheme, port) + /// + /// On top of enabling connection reuse this provider it also facilitates the creation + /// of concurrent requests as it has built-in politeness regarding the maximum number + /// of concurrent requests to the server. + class HTTP1ConnectionProvider: CustomStringConvertible { + /// The default `EventLoop` for this provider + /// + /// The default event loop is used to create futures and is used + /// when creating `Channel`s for requests for which the + /// `EventLoopPreference` is set to `.indifferent` + let eventLoop: EventLoop + + /// The client configuration used to bootstrap new requests + private let configuration: HTTPClient.Configuration + + /// The key associated with this provider + private let key: ConnectionPool.Key + + /// The `State` of this provider + /// + /// This property holds data structures representing the current state of the provider + /// - Warning: This type isn't thread safe and should be accessed with proper + /// synchronization (see the `stateLock` property) + fileprivate var state: State + + /// The lock used to access and modify the `state` property + /// + /// - Warning: This lock should always be acquired *after* `ConnectionPool`s `connectionProvidersLock` if used in combination with it. + fileprivate let stateLock = Lock() + + /// The maximum number of concurrent connections to a given (host, scheme, port) + private let maximumConcurrentConnections: Int = 8 + + /// The pool this provider belongs to + private let parentPool: ConnectionPool + + /// Creates a new `HTTP1ConnectionProvider` + /// + /// - parameters: + /// - key: The `Key` (host, scheme, port) this provider is associated to + /// - configuration: The client configuration used globally by all requests + /// - initialConnection: The initial connection the pool initializes this provider with + /// - parentPool: The pool this provider belongs to + init(key: ConnectionPool.Key, eventLoop: EventLoop, configuration: HTTPClient.Configuration, parentPool: ConnectionPool) { + self.eventLoop = eventLoop + self.configuration = configuration + self.key = key + self.parentPool = parentPool + self.state = State(eventLoop: eventLoop, parentPool: parentPool, key: key) + } + + deinit { + assert(self.state.activity == .closed, "Non closed on deinit") + assert(self.state.availableConnections.isEmpty, "Available connections should be empty before deinit") + assert(self.state.leased == 0, "All leased connections should have been returned before deinit") + assert(self.state.waiters.count == 0, "Waiters on deinit: \(self.state.waiters)") + } + + var description: String { + return "HTTP1ConnectionProvider { key: \(self.key), state: \(self.state) }" + } + + func getConnection(preference: HTTPClient.EventLoopPreference) -> EventLoopFuture { + self.activityPrecondition(expected: [.opened]) + let action = self.stateLock.withLock { self.state.connectionAction(for: preference) } + switch action { + case .leaseConnection(let connection): + return connection.channel.eventLoop.makeSucceededFuture(connection) + case .makeConnection(let eventLoop): + return self.makeConnection(on: eventLoop) + case .leaseFutureConnection(let futureConnection): + return futureConnection + } + } + + func release(connection: Connection) { + self.activityPrecondition(expected: [.opened, .closing]) + let action = self.parentPool.connectionProvidersLock.withLock { + self.stateLock.withLock { self.state.releaseAction(for: connection) } + } + switch action { + case .succeed(let promise): + promise.succeed(connection) + + case .makeConnectionAndComplete(let eventLoop, let promise): + self.makeConnection(on: eventLoop).cascade(to: promise) + + case .replaceConnection(let eventLoop, let promise): + connection.close().flatMap { + self.makeConnection(on: eventLoop) + }.whenComplete { result in + switch result { + case .success(let connection): + promise.succeed(connection) + case .failure(let error): + promise.fail(error) + } + } + + case .none: + break + } + } + + private func makeConnection(on eventLoop: EventLoop) -> EventLoopFuture { + self.activityPrecondition(expected: [.opened]) + let handshakePromise = eventLoop.makePromise(of: Void.self) + let bootstrap = ClientBootstrap.makeHTTPClientBootstrapBase(group: eventLoop, host: self.key.host, port: self.key.port, configuration: self.configuration) + let address = HTTPClient.resolveAddress(host: self.key.host, port: self.key.port, proxy: self.configuration.proxy) + + let channel: EventLoopFuture + switch self.key.scheme { + case .http, .https: + channel = bootstrap.connect(host: address.host, port: address.port) + case .unix: + channel = bootstrap.connect(unixDomainSocketPath: self.key.unixPath) + } + + return channel.flatMap { channel -> EventLoopFuture in + channel.pipeline.addSSLHandlerIfNeeded(for: self.key, tlsConfiguration: self.configuration.tlsConfiguration, handshakePromise: handshakePromise).flatMap { + channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes) + }.map { + let connection = Connection(key: self.key, channel: channel, parentPool: self.parentPool) + connection.isLeased = true + return connection + } + }.flatMap { connection in + handshakePromise.futureResult.map { + self.configureCloseCallback(of: connection) + return connection + }.flatMapError { error in + connection.closePromise.succeed(()) + let action = self.parentPool.connectionProvidersLock.withLock { + self.stateLock.withLock { + self.state.failedConnectionAction() + } + } + switch action { + case .makeConnectionAndComplete(let el, let promise): + self.makeConnection(on: el).cascade(to: promise) + case .none: + break + } + return self.eventLoop.makeFailedFuture(error) + } + } + } + + /// Adds a callback on connection close that asks the `state` what to do about this + /// + /// The callback informs the state about the event, and the state returns a + /// `ClosedConnectionRemoveAction` which instructs it about what it should do. + private func configureCloseCallback(of connection: Connection) { + connection.channel.closeFuture.whenComplete { result in + let action: HTTP1ConnectionProvider.State.ClosedConnectionRemoveAction? = self.parentPool.connectionProvidersLock.withLock { + self.stateLock.withLock { + guard connection.mustRunDefaultCloseCallback else { + return nil + } + switch result { + case .success: + return self.state.removeClosedConnection(connection) + + case .failure(let error): + preconditionFailure("Connection close future failed with error: \(error)") + } + } + } + + if let action = action { + switch action { + case .makeConnectionAndComplete(let el, let promise): + self.makeConnection(on: el).cascade(to: promise) + case .none: + break + } + } + + connection.closePromise.succeed(()) + } + } + + /// Removes and fails all `waiters`, remove existing `availableConnections` and sets `state.activity` to `.closing` + func prepareForClose() { + assert(MultiThreadedEventLoopGroup.currentEventLoop == nil, + "HTTPClient shutdown on EventLoop unsupported") // calls .wait() so it would crash later anyway + let (waitersFutures, closeFutures) = self.stateLock.withLock { () -> ([EventLoopFuture], [EventLoopFuture]) in + assert(self.state.activity == .opened, "Invalid activity: \(self.state.activity)") + // Fail waiters + let waitersCopy = self.state.waiters + self.state.waiters.removeAll() + let waitersPromises = waitersCopy.map { $0.promise } + let waitersFutures = waitersPromises.map { $0.futureResult } + waitersPromises.forEach { $0.fail(HTTPClientError.cancelled) } + let closeFutures = self.state.availableConnections.map { $0.close() } + return (waitersFutures, closeFutures) + } + try? EventLoopFuture.andAllComplete(waitersFutures, on: self.eventLoop).wait() + try? EventLoopFuture.andAllComplete(closeFutures, on: self.eventLoop).wait() + + self.stateLock.withLock { + if self.state.leased == 0, self.state.availableConnections.isEmpty { + self.state.activity = .closed + } else { + self.state.activity = .closing + } + } + } + + func syncClose() { + assert(MultiThreadedEventLoopGroup.currentEventLoop == nil, + "HTTPClient shutdown on EventLoop unsupported") // calls .wait() so it would crash later anyway + let availableConnections = self.stateLock.withLock { () -> CircularBuffer in + assert(self.state.activity == .closing) + return self.state.availableConnections + } + try? EventLoopFuture.andAllComplete(availableConnections.map { $0.close() }, on: self.eventLoop).wait() + } + + private func activityPrecondition(expected: Set) { + self.stateLock.withLock { + precondition(expected.contains(self.state.activity), "Attempting to use HTTP1ConnectionProvider with unexpected state: \(self.state.activity) (expected: \(expected))") + } + } + + fileprivate struct State { + /// The default `EventLoop` to use for this `HTTP1ConnectionProvider` + private let defaultEventLoop: EventLoop + + /// The maximum number of connections to a certain (host, scheme, port) tuple. + private let maximumConcurrentConnections: Int = 8 + + /// Opened connections that are available + fileprivate var availableConnections: CircularBuffer = .init(initialCapacity: 8) + + /// The number of currently leased connections + fileprivate var leased: Int = 0 { + didSet { + assert((0...self.maximumConcurrentConnections).contains(self.leased), "Invalid number of leased connections (\(self.leased))") + } + } + + /// Consumers that weren't able to get a new connection without exceeding + /// `maximumConcurrentConnections` get a `Future` + /// whose associated promise is stored in `Waiter`. The promise is completed + /// as soon as possible by the provider, in FIFO order. + fileprivate var waiters: CircularBuffer = .init(initialCapacity: 8) + + fileprivate var activity: Activity = .opened + + fileprivate var pending: Int = 0 + + private let parentPool: ConnectionPool + + private let key: Key + + fileprivate init(eventLoop: EventLoop, parentPool: ConnectionPool, key: Key) { + self.defaultEventLoop = eventLoop + self.parentPool = parentPool + self.key = key + } + + fileprivate mutating func connectionAction(for preference: HTTPClient.EventLoopPreference) -> ConnectionGetAction { + self.pending -= 1 + let (channelEL, requiresSpecifiedEL) = self.resolvePreference(preference) + if self.leased < self.maximumConcurrentConnections { + self.leased += 1 + if let connection = availableConnections.swapWithFirstAndRemove(where: { $0.channel.eventLoop === channelEL }) { + connection.isLeased = true + return .leaseConnection(connection) + } else { + if requiresSpecifiedEL { + return .makeConnection(channelEL) + } else if let existingConnection = availableConnections.popFirst() { + return .leaseConnection(existingConnection) + } else { + return .makeConnection(self.defaultEventLoop) + } + } + } else { + let promise = channelEL.makePromise(of: Connection.self) + self.waiters.append(Waiter(promise: promise, preference: preference)) + return .leaseFutureConnection(promise.futureResult) + } + } + + fileprivate mutating func releaseAction(for connection: Connection) -> ConnectionReleaseAction { + if let firstWaiter = self.waiters.popFirst() { + let (channelEL, requiresSpecifiedEL) = self.resolvePreference(firstWaiter.preference) + + guard connection.isActiveEstimation, !connection.isClosing else { + return .makeConnectionAndComplete(channelEL, firstWaiter.promise) + } + + if connection.channel.eventLoop === channelEL { + return .succeed(firstWaiter.promise) + } else { + if requiresSpecifiedEL { + connection.mustRunDefaultCloseCallback = false + return .replaceConnection(channelEL, firstWaiter.promise) + } else { + return .makeConnectionAndComplete(channelEL, firstWaiter.promise) + } + } + + } else { + connection.isLeased = false + self.leased -= 1 + if connection.isActiveEstimation, !connection.isClosing { + self.availableConnections.append(connection) + } + + if self.providerMustClose() { + self.removeFromPool() + } + + return .none + } + } + + fileprivate mutating func removeClosedConnection(_ connection: Connection) -> ClosedConnectionRemoveAction { + if connection.isLeased { + if let firstWaiter = self.waiters.popFirst() { + let (el, _) = self.resolvePreference(firstWaiter.preference) + return .makeConnectionAndComplete(el, firstWaiter.promise) + } + } else { + self.availableConnections.swapWithFirstAndRemove(where: { $0 === connection }) + } + + if self.providerMustClose() { + self.removeFromPool() + } + + return .none + } + + fileprivate mutating func failedConnectionAction() -> ClosedConnectionRemoveAction { + if let firstWaiter = self.waiters.popFirst() { + let (el, _) = self.resolvePreference(firstWaiter.preference) + return .makeConnectionAndComplete(el, firstWaiter.promise) + } else { + self.leased -= 1 + if self.providerMustClose() { + self.removeFromPool() + } + return .none + } + } + + private func providerMustClose() -> Bool { + return self.pending == 0 && self.activity != .closed && self.leased == 0 && self.availableConnections.isEmpty && self.waiters.isEmpty + } + + /// - Warning: This should always be called from a critical section protected by `.connectionProvidersLock` + fileprivate mutating func removeFromPool() { + assert(self.parentPool.connectionProviders[self.key] != nil) + self.parentPool.connectionProviders[self.key] = nil + assert(self.activity != .closed) + self.activity = .closed + } + + private func resolvePreference(_ preference: HTTPClient.EventLoopPreference) -> (EventLoop, Bool) { + switch preference.preference { + case .indifferent: + return (self.defaultEventLoop, false) + case .delegate(let el): + return (el, false) + case .delegateAndChannel(let el), .testOnly_exact(let el, _): + return (el, true) + } + } + + fileprivate enum ConnectionGetAction { + case leaseConnection(Connection) + case makeConnection(EventLoop) + case leaseFutureConnection(EventLoopFuture) + } + + fileprivate enum ConnectionReleaseAction { + case succeed(EventLoopPromise) + case makeConnectionAndComplete(EventLoop, EventLoopPromise) + case replaceConnection(EventLoop, EventLoopPromise) + case none + } + + fileprivate enum ClosedConnectionRemoveAction { + case none + case makeConnectionAndComplete(EventLoop, EventLoopPromise) + } + + /// A `Waiter` represents a request that waits for a connection when none is + /// currently available + /// + /// `Waiter`s are created when `maximumConcurrentConnections` is reached + /// and we cannot create new connections anymore. + fileprivate struct Waiter { + /// The promise to complete once a connection is available + let promise: EventLoopPromise + + /// The event loop preference associated to this particular request + /// that the provider should respect + let preference: HTTPClient.EventLoopPreference + } + + enum Activity: Hashable, CustomStringConvertible { + case opened + case closing + case closed + + var description: String { + switch self { + case .opened: + return "opened" + case .closing: + return "closing" + case .closed: + return "closed" + } + } + } + } + } +} diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 2fcff89b5..99ec37f2e 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -18,6 +18,7 @@ import NIOConcurrencyHelpers import NIOHTTP1 import NIOHTTPCompression import NIOSSL +import NIOTLS /// HTTPClient class provides API for request execution. /// @@ -48,7 +49,10 @@ public class HTTPClient { public let eventLoopGroup: EventLoopGroup let eventLoopGroupProvider: EventLoopGroupProvider let configuration: Configuration - let isShutdown = NIOAtomic.makeAtomic(value: false) + let pool: ConnectionPool + var state: State + private var tasks = [UUID: TaskProtocol]() + private let stateLock = Lock() /// Create an `HTTPClient` with specified `EventLoopGroup` provider and configuration. /// @@ -64,24 +68,79 @@ public class HTTPClient { self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) } self.configuration = configuration + self.pool = ConnectionPool(configuration: configuration) + self.state = .upAndRunning } deinit { - assert(self.isShutdown.load(), "Client not shut down before the deinit. Please call client.syncShutdown() when no longer needed.") + assert(self.pool.connectionProviderCount == 0) + assert(self.state == .shutDown, "Client not shut down before the deinit. Please call client.syncShutdown() when no longer needed.") } /// Shuts down the client and `EventLoopGroup` if it was created by the client. public func syncShutdown() throws { - switch self.eventLoopGroupProvider { - case .shared: - self.isShutdown.store(true) - return - case .createNew: - if self.isShutdown.compareAndExchange(expected: false, desired: true) { - try self.eventLoopGroup.syncShutdownGracefully() - } else { + try self.syncShutdown(requiresCleanClose: false) + } + + /// Shuts down the client and `EventLoopGroup` if it was created by the client. + /// + /// - parameters: + /// - requiresCleanClose: Determine if the client should throw when it is shutdown in a non-clean state + /// + /// - Note: + /// The `requiresCleanClose` will let the client do additional checks about its internal consistency on shutdown and + /// throw the appropriate error if needed. For instance, if its internal connection pool has any non-released connections, + /// this indicate shutdown was called too early before tasks were completed or explicitly canceled. + /// In general, setting this parameter to `true` should make it easier and faster to catch related programming errors. + internal func syncShutdown(requiresCleanClose: Bool) throws { + var closeError: Error? + + let tasks = try self.stateLock.withLock { () -> Dictionary.Values in + if self.state != .upAndRunning { throw HTTPClientError.alreadyShutdown } + self.state = .shuttingDown + return self.tasks.values + } + + self.pool.prepareForClose() + + if !tasks.isEmpty, requiresCleanClose { + closeError = HTTPClientError.uncleanShutdown + } + + for task in tasks { + task.cancel() + } + + try? EventLoopFuture.andAllComplete((tasks.map { $0.completion }), on: self.eventLoopGroup.next()).wait() + + self.pool.syncClose() + + do { + try self.stateLock.withLock { + switch self.eventLoopGroupProvider { + case .shared: + self.state = .shutDown + return + case .createNew: + switch self.state { + case .shuttingDown: + self.state = .shutDown + try self.eventLoopGroup.syncShutdownGracefully() + case .shutDown, .upAndRunning: + assertionFailure("The only valid state at this point is \(State.shutDown)") + } + } + } + } catch { + if closeError == nil { + closeError = error + } + } + + if let closeError = closeError { + throw closeError } } @@ -188,8 +247,7 @@ public class HTTPClient { public func execute(request: Request, delegate: Delegate, deadline: NIODeadline? = nil) -> Task { - let eventLoop = self.eventLoopGroup.next() - return self.execute(request: request, delegate: delegate, eventLoop: eventLoop, deadline: deadline) + return self.execute(request: request, delegate: delegate, eventLoop: .indifferent, deadline: deadline) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -201,31 +259,35 @@ public class HTTPClient { /// - deadline: Point in time by which the request must complete. public func execute(request: Request, delegate: Delegate, - eventLoop: EventLoopPreference, + eventLoop eventLoopPreference: EventLoopPreference, deadline: NIODeadline? = nil) -> Task { - switch eventLoop.preference { + let taskEL: EventLoop + switch eventLoopPreference.preference { case .indifferent: - return self.execute(request: request, delegate: delegate, eventLoop: self.eventLoopGroup.next(), deadline: deadline) + taskEL = self.pool.associatedEventLoop(for: ConnectionPool.Key(request)) ?? self.eventLoopGroup.next() case .delegate(on: let eventLoop): precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") - return self.execute(request: request, delegate: delegate, eventLoop: eventLoop, deadline: deadline) + taskEL = eventLoop case .delegateAndChannel(on: let eventLoop): precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") - return self.execute(request: request, delegate: delegate, eventLoop: eventLoop, deadline: deadline) - case .testOnly_exact(channelOn: let channelEL, delegateOn: let delegateEL): - return self.execute(request: request, - delegate: delegate, - eventLoop: delegateEL, - channelEL: channelEL, - deadline: deadline) + taskEL = eventLoop + case .testOnly_exact(_, delegateOn: let delegateEL): + taskEL = delegateEL + } + + let failedTask: Task? = self.stateLock.withLock { + switch state { + case .upAndRunning: + return nil + case .shuttingDown, .shutDown: + return Task.failedTask(eventLoop: taskEL, error: HTTPClientError.alreadyShutdown) + } + } + + if let failedTask = failedTask { + return failedTask } - } - private func execute(request: Request, - delegate: Delegate, - eventLoop delegateEL: EventLoop, - channelEL: EventLoop? = nil, - deadline: NIODeadline? = nil) -> Task { let redirectHandler: RedirectHandler? switch self.configuration.redirectConfiguration.configuration { case .follow(let max, let allowCycles): @@ -236,72 +298,73 @@ public class HTTPClient { redirectHandler = RedirectHandler(request: request) { newRequest in self.execute(request: newRequest, delegate: delegate, - eventLoop: delegateEL, - channelEL: channelEL, + eventLoop: eventLoopPreference, deadline: deadline) } case .disallow: redirectHandler = nil } - let task = Task(eventLoop: delegateEL) + let task = Task(eventLoop: taskEL) + self.stateLock.withLock { + self.tasks[task.id] = task + } + let promise = task.promise - var bootstrap = ClientBootstrap(group: channelEL ?? delegateEL) - .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) - .channelInitializer { channel in - let encoder = HTTPRequestEncoder() - let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)) - return channel.pipeline.addHandlers([encoder, decoder], position: .first).flatMap { - switch self.configuration.proxy { - case .none: - return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: self.configuration.tlsConfiguration) - case .some(let proxy): - return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration, proxy: proxy) - } - }.flatMap { - switch self.configuration.decompression { - case .disabled: - return channel.eventLoop.makeSucceededFuture(()) - case .enabled(let limit): - return channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: limit)) - } - }.flatMap { - if let timeout = self.resolve(timeout: self.configuration.timeout.read, deadline: deadline) { - return channel.pipeline.addHandler(IdleStateHandler(readTimeout: timeout)) - } else { - return channel.eventLoop.makeSucceededFuture(()) - } - }.flatMap { - let taskHandler = TaskHandler(task: task, - kind: request.kind, - delegate: delegate, - redirectHandler: redirectHandler, - ignoreUncleanSSLShutdown: self.configuration.ignoreUncleanSSLShutdown) - return channel.pipeline.addHandler(taskHandler) - } + promise.futureResult.whenComplete { _ in + self.stateLock.withLock { + self.tasks[task.id] = nil } - - if let timeout = self.resolve(timeout: self.configuration.timeout.connect, deadline: deadline) { - bootstrap = bootstrap.connectTimeout(timeout) } - let eventLoopChannel: EventLoopFuture - switch request.kind { - case .unixSocket: - let socketPath = request.url.baseURL?.path ?? request.url.path - eventLoopChannel = bootstrap.connect(unixDomainSocketPath: socketPath) - case .host: - let address = self.resolveAddress(request: request, proxy: self.configuration.proxy) - eventLoopChannel = bootstrap.connect(host: address.host, port: address.port) - } + let connection = self.pool.getConnection(for: request, preference: eventLoopPreference, on: taskEL, deadline: deadline) - eventLoopChannel.map { channel in - task.setChannel(channel) - } - .flatMap { channel in - channel.writeAndFlush(request) - } - .cascadeFailure(to: task.promise) + connection.flatMap { connection -> EventLoopFuture in + let channel = connection.channel + let addedFuture: EventLoopFuture + + switch self.configuration.decompression { + case .disabled: + addedFuture = channel.eventLoop.makeSucceededFuture(()) + case .enabled(let limit): + let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) + addedFuture = channel.pipeline.addHandler(decompressHandler) + } + + return addedFuture.flatMap { + if let timeout = self.resolve(timeout: self.configuration.timeout.read, deadline: deadline) { + return channel.pipeline.addHandler(IdleStateHandler(readTimeout: timeout)) + } else { + return channel.eventLoop.makeSucceededFuture(()) + } + }.flatMap { + let taskHandler = TaskHandler(task: task, + kind: request.kind, + delegate: delegate, + redirectHandler: redirectHandler, + ignoreUncleanSSLShutdown: self.configuration.ignoreUncleanSSLShutdown) + return channel.pipeline.addHandler(taskHandler) + }.flatMap { + task.setConnection(connection) + + let isCancelled = task.lock.withLock { + task.cancelled + } + + if !isCancelled { + return channel.writeAndFlush(request).flatMapError { _ in + // At this point the `TaskHandler` will already be present + // to handle the failure and pass it to the `promise` + channel.eventLoop.makeSucceededFuture(()) + } + } else { + return channel.eventLoop.makeSucceededFuture(()) + } + }.flatMapError { error in + connection.release() + return channel.eventLoop.makeFailedFuture(error) + } + }.cascadeFailure(to: promise) return task } @@ -319,10 +382,10 @@ public class HTTPClient { } } - private func resolveAddress(request: Request, proxy: Configuration.Proxy?) -> (host: String, port: Int) { - switch self.configuration.proxy { + static func resolveAddress(host: String, port: Int, proxy: Configuration.Proxy?) -> (host: String, port: Int) { + switch proxy { case .none: - return (request.host, request.port) + return (host, port) case .some(let proxy): return (proxy.host, proxy.port) } @@ -436,6 +499,12 @@ public class HTTPClient { /// Decompression is enabled. case enabled(limit: NIOHTTPDecompression.DecompressionLimit) } + + enum State { + case upAndRunning + case shuttingDown + case shutDown + } } extension HTTPClient.Configuration { @@ -490,37 +559,84 @@ extension HTTPClient.Configuration { } } -private extension ChannelPipeline { - func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?, proxy: HTTPClient.Configuration.Proxy?) -> EventLoopFuture { - let handler = HTTPClientProxyHandler(host: request.host, port: request.port, authorization: proxy?.authorization, onConnect: { channel in - channel.pipeline.removeHandler(decoder).flatMap { - channel.pipeline.addHandler( - ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)), - position: .after(encoder) - ) - }.flatMap { - channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: tlsConfiguration) +extension ChannelPipeline { + func addProxyHandler(host: String, port: Int, authorization: HTTPClient.Authorization?) -> EventLoopFuture { + let encoder = HTTPRequestEncoder() + let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)) + let handler = HTTPClientProxyHandler(host: host, port: port, authorization: authorization) { channel in + let encoderRemovePromise = self.eventLoop.next().makePromise(of: Void.self) + channel.pipeline.removeHandler(encoder, promise: encoderRemovePromise) + return encoderRemovePromise.futureResult.flatMap { + channel.pipeline.removeHandler(decoder) } - }) - return self.addHandler(handler) + } + return addHandlers([encoder, decoder, handler]) } - func addSSLHandlerIfNeeded(for request: HTTPClient.Request, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture { - guard request.useTLS else { + func addSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, handshakePromise: EventLoopPromise) -> EventLoopFuture { + guard key.scheme == .https else { + handshakePromise.succeed(()) return self.eventLoop.makeSucceededFuture(()) } do { let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient() let context = try NIOSSLContext(configuration: tlsConfiguration) - return self.addHandler(try NIOSSLClientHandler(context: context, serverHostname: request.host.isIPAddress ? nil : request.host), - position: .first) + let handlers: [ChannelHandler] = [ + try NIOSSLClientHandler(context: context, serverHostname: key.host.isIPAddress ? nil : key.host), + TLSEventsHandler(completionPromise: handshakePromise), + ] + + return self.addHandlers(handlers) } catch { return self.eventLoop.makeFailedFuture(error) } } } +class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = NIOAny + + var completionPromise: EventLoopPromise? + + init(completionPromise: EventLoopPromise) { + self.completionPromise = completionPromise + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + if let tlsEvent = event as? TLSUserEvent { + switch tlsEvent { + case .handshakeCompleted: + self.completionPromise?.succeed(()) + self.completionPromise = nil + context.pipeline.removeHandler(self, promise: nil) + case .shutdownCompleted: + break + } + } + context.fireUserInboundEventTriggered(event) + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + if let sslError = error as? NIOSSLError { + switch sslError { + case .handshakeFailed: + self.completionPromise?.fail(error) + self.completionPromise = nil + context.pipeline.removeHandler(self, promise: nil) + default: + break + } + } + context.fireErrorCaught(error) + } + + func handlerRemoved(context: ChannelHandlerContext) { + struct NoResult: Error {} + self.completionPromise?.fail(NoResult()) + } +} + /// Possible client errors. public struct HTTPClientError: Error, Equatable, CustomStringConvertible { private enum Code: Equatable { @@ -539,6 +655,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case proxyAuthenticationRequired case redirectLimitReached case redirectCycleDetected + case uncleanShutdown } private var code: Code @@ -581,4 +698,6 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let redirectLimitReached = HTTPClientError(code: .redirectLimitReached) /// Redirect Cycle detected. public static let redirectCycleDetected = HTTPClientError(code: .redirectCycleDetected) + /// Unclean shutdown + public static let uncleanShutdown = HTTPClientError(code: .uncleanShutdown) } diff --git a/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift b/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift index 05c1f9cf6..ebdfbfa24 100644 --- a/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift @@ -69,6 +69,7 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan case awaitingResponse case connecting case connected + case failed } private let host: String @@ -102,6 +103,7 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan // blank line that concludes the successful response's header section break case 407: + self.readState = .failed context.fireErrorCaught(HTTPClientError.proxyAuthenticationRequired) default: // Any response other than a successful response @@ -119,6 +121,8 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan self.readBuffer.append(data) case .connected: context.fireChannelRead(data) + case .failed: + break } } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index eeb675f02..6320266c9 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -17,6 +17,7 @@ import NIO import NIOConcurrencyHelpers import NIOFoundationCompat import NIOHTTP1 +import NIOHTTPCompression import NIOSSL extension HTTPClient { @@ -486,22 +487,31 @@ extension URL { extension HTTPClient { /// Response execution context. Will be created by the library and could be used for obtaining /// `EventLoopFuture` of the execution or cancellation of the execution. - public final class Task { + public final class Task: TaskProtocol { /// The `EventLoop` the delegate will be executed on. public let eventLoop: EventLoop let promise: EventLoopPromise - var channel: Channel? - private var cancelled: Bool - private let lock: Lock + var completion: EventLoopFuture + var connection: ConnectionPool.Connection? + var cancelled: Bool + let lock: Lock + let id = UUID() init(eventLoop: EventLoop) { self.eventLoop = eventLoop self.promise = eventLoop.makePromise() + self.completion = self.promise.futureResult.map { _ in } self.cancelled = false self.lock = Lock() } + static func failedTask(eventLoop: EventLoop, error: Error) -> Task { + let task = self.init(eventLoop: eventLoop) + task.promise.fail(error) + return task + } + /// `EventLoopFuture` for the response returned by this request. public var futureResult: EventLoopFuture { return self.promise.futureResult @@ -520,18 +530,58 @@ extension HTTPClient { let channel: Channel? = self.lock.withLock { if !cancelled { cancelled = true - return self.channel + return self.connection?.channel + } else { + return nil } - return nil } channel?.triggerUserOutboundEvent(TaskCancelEvent(), promise: nil) } @discardableResult - func setChannel(_ channel: Channel) -> Channel { + func setConnection(_ connection: ConnectionPool.Connection) -> ConnectionPool.Connection { return self.lock.withLock { - self.channel = channel - return channel + self.connection = connection + if self.cancelled { + connection.channel.triggerUserOutboundEvent(TaskCancelEvent(), promise: nil) + } + return connection + } + } + + func succeed(promise: EventLoopPromise?, with value: Response, delegateType: Delegate.Type) { + self.releaseAssociatedConnection(delegateType: delegateType).whenSuccess { + promise?.succeed(value) + } + } + + func fail(with error: Error, delegateType: Delegate.Type) { + if let connection = self.connection { + connection.close().whenComplete { _ in + self.releaseAssociatedConnection(delegateType: delegateType).whenComplete { _ in + self.promise.fail(error) + } + } + } + } + + func releaseAssociatedConnection(delegateType: Delegate.Type) -> EventLoopFuture { + if let connection = self.connection { + return connection.removeHandler(NIOHTTPResponseDecompressor.self).flatMap { + connection.removeHandler(IdleStateHandler.self) + }.flatMap { + connection.removeHandler(TaskHandler.self) + }.map { + connection.release() + }.flatMapError { error in + fatalError("Couldn't remove taskHandler: \(error)") + } + + } else { + // TODO: This seems only reached in some internal unit test + // Maybe there could be a better handling in the future to make + // it an error outside of testing contexts + return self.eventLoop.makeSucceededFuture(()) } } } @@ -539,9 +589,15 @@ extension HTTPClient { internal struct TaskCancelEvent {} +internal protocol TaskProtocol { + func cancel() + var id: UUID { get } + var completion: EventLoopFuture { get } +} + // MARK: - TaskHandler -internal class TaskHandler { +internal class TaskHandler: RemovableChannelHandler { enum State { case idle case sent @@ -581,7 +637,7 @@ extension TaskHandler { _ body: @escaping (HTTPClient.Task, Err) -> Void) { func doIt() { body(self.task, error) - self.task.promise.fail(error) + self.task.fail(with: error, delegateType: Delegate.self) } if self.task.eventLoop.inEventLoop { @@ -621,13 +677,14 @@ extension TaskHandler { } func callOutToDelegate(promise: EventLoopPromise? = nil, - _ body: @escaping (HTTPClient.Task) throws -> Response) { + _ body: @escaping (HTTPClient.Task) throws -> Response) where Response == Delegate.Response { func doIt() { do { let result = try body(self.task) - promise?.succeed(result) + + self.task.succeed(promise: promise, with: result, delegateType: Delegate.self) } catch { - promise?.fail(error) + self.task.fail(with: error, delegateType: Delegate.self) } } @@ -641,7 +698,7 @@ extension TaskHandler { } func callOutToDelegate(channelEventLoop: EventLoop, - _ body: @escaping (HTTPClient.Task) throws -> Response) -> EventLoopFuture { + _ body: @escaping (HTTPClient.Task) throws -> Response) -> EventLoopFuture where Response == Delegate.Response { let promise = channelEventLoop.makePromise(of: Response.self) self.callOutToDelegate(promise: promise, body) return promise.futureResult @@ -678,8 +735,6 @@ extension TaskHandler: ChannelDuplexHandler { headers.add(name: "Host", value: request.host) } - headers.add(name: "Connection", value: "close") - do { try headers.validate(body: request.body) } catch { @@ -702,16 +757,10 @@ extension TaskHandler: ChannelDuplexHandler { context.eventLoop.assertInEventLoop() self.state = .sent self.callOutToDelegateFireAndForget(self.delegate.didSendRequest) - - let channel = context.channel - self.task.futureResult.whenComplete { _ in - channel.close(promise: nil) - } }.flatMapErrorThrowing { error in context.eventLoop.assertInEventLoop() self.state = .end self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) - context.close(promise: nil) throw error }.cascade(to: promise) } @@ -742,6 +791,16 @@ extension TaskHandler: ChannelDuplexHandler { let response = self.unwrapInboundIn(data) switch response { case .head(let head): + if !head.isKeepAlive { + self.task.lock.withLock { + if let connection = self.task.connection { + connection.isClosing = true + } else { + preconditionFailure("There should always be a connection at this point") + } + } + } + if let redirectURL = redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { self.state = .redirected(head, redirectURL) } else { @@ -768,8 +827,9 @@ extension TaskHandler: ChannelDuplexHandler { switch self.state { case .redirected(let head, let redirectURL): self.state = .end - self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise) - context.close(promise: nil) + self.task.releaseAssociatedConnection(delegateType: Delegate.self).whenSuccess { + self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise) + } default: self.state = .end self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest) @@ -845,6 +905,13 @@ extension TaskHandler: ChannelDuplexHandler { self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } } + + func handlerAdded(context: ChannelHandlerContext) { + guard context.channel.isActive else { + self.failTaskAndNotifyDelegate(error: HTTPClientError.remoteConnectionClosed, self.delegate.didReceiveError) + return + } + } } // MARK: - RedirectHandler @@ -931,9 +998,13 @@ internal struct RedirectHandler { do { var newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body) newRequest.redirectState = nextState - return self.execute(newRequest).futureResult.cascade(to: promise) + self.execute(newRequest).futureResult.whenComplete { result in + promise.futureResult.eventLoop.execute { + promise.completeWith(result) + } + } } catch { - return promise.fail(error) + promise.fail(error) } } } diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index 12f921a06..6e2fedf53 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -14,6 +14,7 @@ import NIO import NIOHTTP1 +import NIOHTTPCompression internal extension String { var isIPAddress: Bool { @@ -44,3 +45,53 @@ public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { return () } } + +extension ClientBootstrap { + static func makeHTTPClientBootstrapBase(group: EventLoopGroup, host: String, port: Int, configuration: HTTPClient.Configuration, channelInitializer: ((Channel) -> EventLoopFuture)? = nil) -> ClientBootstrap { + return ClientBootstrap(group: group) + .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) + + .channelInitializer { channel in + let channelAddedFuture: EventLoopFuture + switch configuration.proxy { + case .none: + channelAddedFuture = group.next().makeSucceededFuture(()) + case .some: + channelAddedFuture = channel.pipeline.addProxyHandler(host: host, port: port, authorization: configuration.proxy?.authorization) + } + return channelAddedFuture.flatMap { (_: Void) -> EventLoopFuture in + channelInitializer?(channel) ?? group.next().makeSucceededFuture(()) + } + } + } +} + +extension CircularBuffer { + @discardableResult + mutating func swapWithFirstAndRemove(at index: Index) -> Element? { + precondition(index >= self.startIndex && index < self.endIndex) + if !self.isEmpty { + self.swapAt(self.startIndex, index) + return self.removeFirst() + } else { + return nil + } + } + + @discardableResult + mutating func swapWithFirstAndRemove(where predicate: (Element) throws -> Bool) rethrows -> Element? { + if let existingIndex = try self.firstIndex(where: predicate) { + return self.swapWithFirstAndRemove(at: existingIndex) + } else { + return nil + } + } +} + +extension ConnectionPool.Connection { + func removeHandler(_ type: Handler.Type) -> EventLoopFuture { + return self.channel.pipeline.handler(type: type).flatMap { handler in + self.channel.pipeline.removeHandler(handler) + }.recover { _ in } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index d9778df09..8139aff49 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -33,6 +33,7 @@ extension HTTPClientInternalTests { ("testUploadStreamingBackpressure", testUploadStreamingBackpressure), ("testRequestURITrailingSlash", testRequestURITrailingSlash), ("testChannelAndDelegateOnDifferentEventLoops", testChannelAndDelegateOnDifferentEventLoops), + ("testResponseConnectionCloseGet", testResponseConnectionCloseGet), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 917b3622f..e23f9684b 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -44,7 +44,6 @@ class HTTPClientInternalTests: XCTestCase { head.headers.add(name: "X-Test-Header", value: "X-Test-Value") head.headers.add(name: "Host", value: "localhost") head.headers.add(name: "Content-Length", value: "4") - head.headers.add(name: "Connection", value: "close") XCTAssertEqual(HTTPClientRequestPart.head(head), recorder.writes[0]) let buffer = ByteBuffer.of(string: "1234") XCTAssertEqual(HTTPClientRequestPart.body(.byteBuffer(buffer)), recorder.writes[1]) @@ -107,7 +106,7 @@ class HTTPClientInternalTests: XCTestCase { let httpBin = HTTPBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -137,7 +136,7 @@ class HTTPClientInternalTests: XCTestCase { let httpBin = HTTPBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -197,8 +196,8 @@ class HTTPClientInternalTests: XCTestCase { func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { // This is to force NIO to send only 1 byte at a time. - let future = task.channel!.setOption(ChannelOptions.maxMessagesPerRead, value: 1).flatMap { - task.channel!.setOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) + let future = task.connection!.channel.setOption(ChannelOptions.maxMessagesPerRead, value: 1).flatMap { + task.connection!.channel.setOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) } future.cascade(to: self.optionsApplied) return future @@ -222,7 +221,7 @@ class HTTPClientInternalTests: XCTestCase { let httpBin = HTTPBin(channelPromise: promise) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -231,6 +230,7 @@ class HTTPClientInternalTests: XCTestCase { let future = httpClient.execute(request: request, delegate: delegate).futureResult let channel = try promise.futureResult.wait() + // We need to wait for channel options that limit NIO to sending only one byte at a time. try delegate.optionsApplied.futureResult.wait() @@ -363,7 +363,7 @@ class HTTPClientInternalTests: XCTestCase { let promise: EventLoopPromise = httpClient.eventLoopGroup.next().makePromise() let httpBin = HTTPBin(channelPromise: promise) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -436,4 +436,21 @@ class HTTPClientInternalTests: XCTestCase { XCTFail("wrong message") } } + + func testResponseConnectionCloseGet() throws { + let httpBin = HTTPBin(ssl: false) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["Connection": "close"], body: nil) + _ = try! httpClient.execute(request: req).wait() + let el = httpClient.eventLoopGroup.next() + try! el.scheduleTask(in: .milliseconds(500)) { + XCTAssertEqual(httpClient.pool.connectionProviderCount, 0) + }.futureResult.wait() + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 81926f1e5..e0c843ca4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -192,7 +192,10 @@ internal final class HTTPBin { compress: Bool = false, bindTarget: BindTarget = .localhostIPv4RandomPort, simulateProxy: HTTPProxySimulator.Option? = nil, - channelPromise: EventLoopPromise? = nil) { + channelPromise: EventLoopPromise? = nil, + connectionDelay: TimeAmount = .seconds(0), + maxChannelAge: TimeAmount? = nil, + refusesConnections: Bool = false) { let socketAddress: SocketAddress switch bindTarget { case .localhostIPv4RandomPort: @@ -200,12 +203,16 @@ internal final class HTTPBin { case .unixDomainSocket(let path): socketAddress = try! SocketAddress(unixDomainSocketPath: path) } + self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true) - .flatMap { + guard !refusesConnections else { + return channel.eventLoop.makeFailedFuture(HTTPBinError.refusedConnection) + } + return channel.eventLoop.scheduleTask(in: connectionDelay) {}.futureResult.flatMap { + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap { if compress { return channel.pipeline.addHandler(HTTPResponseCompressor()) } else { @@ -221,8 +228,7 @@ internal final class HTTPBin { } else { return channel.eventLoop.makeSucceededFuture(()) } - } - .flatMap { + }.flatMap { if ssl { return HTTPBin.configureTLS(channel: channel).flatMap { channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise)) @@ -231,8 +237,8 @@ internal final class HTTPBin { return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise)) } } - } - .bind(to: socketAddress).wait() + } + }.bind(to: socketAddress).wait() } func shutdown() throws { @@ -245,6 +251,10 @@ internal final class HTTPBin { } } +enum HTTPBinError: Error { + case refusedConnection +} + final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = HTTPServerRequestPart typealias InboundOut = HTTPServerResponsePart @@ -327,14 +337,53 @@ internal final class HttpBinHandler: ChannelInboundHandler { let channelPromise: EventLoopPromise? var resps = CircularBuffer() - - init(channelPromise: EventLoopPromise? = nil) { + var closeAfterResponse = false + var delay: TimeAmount = .seconds(0) + let creationDate = Date() + let maxChannelAge: TimeAmount? + var shouldClose = false + var isServingRequest = false + + init(channelPromise: EventLoopPromise? = nil, maxChannelAge: TimeAmount? = nil) { self.channelPromise = channelPromise + self.maxChannelAge = maxChannelAge + } + + func handlerAdded(context: ChannelHandlerContext) { + if let maxChannelAge = self.maxChannelAge { + context.eventLoop.scheduleTask(in: maxChannelAge) { + if !self.isServingRequest { + context.close(promise: nil) + } else { + self.shouldClose = true + } + } + } + } + + func parseAndSetOptions(from head: HTTPRequestHead) { + if let delay = head.headers["X-internal-delay"].first { + if let milliseconds = Int64(delay) { + self.delay = TimeAmount.milliseconds(milliseconds) + } else { + assertionFailure("Invalid interval format") + } + } else { + self.delay = .nanoseconds(0) + } + + if let connection = head.headers["Connection"].first { + self.closeAfterResponse = (connection == "close") + } else { + self.closeAfterResponse = false + } } func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.isServingRequest = true switch self.unwrapInboundIn(data) { case .head(let req): + self.parseAndSetOptions(from: req) let url = URL(string: req.uri)! switch url.path { case "/": @@ -454,7 +503,19 @@ internal final class HttpBinHandler: ChannelInboundHandler { responseBody.writeBytes(serialized) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + context.eventLoop.scheduleTask(in: self.delay) { + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenComplete { result in + self.isServingRequest = false + switch result { + case .success: + if self.closeAfterResponse || self.shouldClose { + context.close(promise: nil) + } + case .failure(let error): + assertionFailure("\(error)") + } + } + } } } @@ -587,6 +648,29 @@ extension ByteBuffer { } } +struct EventLoopFutureTimeoutError: Error {} + +extension EventLoopFuture { + func timeout(after failDelay: TimeAmount) -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Value.self) + + self.whenComplete { result in + switch result { + case .success(let value): + promise.succeed(value) + case .failure(let error): + promise.fail(error) + } + } + + self.eventLoop.scheduleTask(in: failDelay) { + promise.fail(EventLoopFutureTimeoutError()) + } + + return promise.futureResult + } +} + private let cert = """ -----BEGIN CERTIFICATE----- MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index e35b9b679..7cf747406 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -43,6 +43,7 @@ extension HTTPClientTests { ("testReadTimeout", testReadTimeout), ("testDeadline", testDeadline), ("testCancel", testCancel), + ("testStressCancel", testStressCancel), ("testHTTPClientAuthorization", testHTTPClientAuthorization), ("testProxyPlaintext", testProxyPlaintext), ("testProxyTLS", testProxyTLS), @@ -58,6 +59,7 @@ extension HTTPClientTests { ("testWrongContentLengthForSSLUncleanShutdown", testWrongContentLengthForSSLUncleanShutdown), ("testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown", testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown), ("testEventLoopArgument", testEventLoopArgument), + ("testResponseFutureIsOnCorrectEL", testResponseFutureIsOnCorrectEL), ("testDecompression", testDecompression), ("testDecompressionLimit", testDecompressionLimit), ("testLoopDetectionRedirectLimit", testLoopDetectionRedirectLimit), @@ -68,10 +70,28 @@ extension HTTPClientTests { ("testWorksWhenServerClosesConnectionAfterReceivingRequest", testWorksWhenServerClosesConnectionAfterReceivingRequest), ("testSubsequentRequestsWorkWithServerSendingConnectionClose", testSubsequentRequestsWorkWithServerSendingConnectionClose), ("testSubsequentRequestsWorkWithServerAlternatingBetweenKeepAliveAndClose", testSubsequentRequestsWorkWithServerAlternatingBetweenKeepAliveAndClose), + ("testStressGetHttps", testStressGetHttps), + ("testStressGetHttpsSSLError", testStressGetHttpsSSLError), + ("testUncleanCloseThrows", testUncleanCloseThrows), + ("testFailingConnectionIsReleased", testFailingConnectionIsReleased), + ("testResponseDelayGet", testResponseDelayGet), + ("testIdleTimeoutNoReuse", testIdleTimeoutNoReuse), + ("testStressGetClose", testStressGetClose), ("testManyConcurrentRequestsWork", testManyConcurrentRequestsWork), ("testRepeatedRequestsWorkWhenServerAlwaysCloses", testRepeatedRequestsWorkWhenServerAlwaysCloses), + ("testShutdownBeforeTasksCompletion", testShutdownBeforeTasksCompletion), + ("testUncleanShutdownActuallyShutsDown", testUncleanShutdownActuallyShutsDown), + ("testUncleanShutdownCancelsTasks", testUncleanShutdownCancelsTasks), + ("testDoubleShutdown", testDoubleShutdown), + ("testTaskFailsWhenClientIsShutdown", testTaskFailsWhenClientIsShutdown), + ("testRaceNewRequestsVsShutdown", testRaceNewRequestsVsShutdown), + ("testVaryingLoopPreference", testVaryingLoopPreference), + ("testMakeSecondRequestDuringCancelledCallout", testMakeSecondRequestDuringCancelledCallout), + ("testMakeSecondRequestDuringSuccessCallout", testMakeSecondRequestDuringSuccessCallout), + ("testMakeSecondRequestWhilstFirstIsOngoing", testMakeSecondRequestWhilstFirstIsOngoing), ("testUDSBasic", testUDSBasic), ("testUDSSocketAndPath", testUDSSocketAndPath), + ("testUseExistingConnectionOnDifferentEL", testUseExistingConnectionOnDifferentEL), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 96b3761c0..f2ca8353e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -import AsyncHTTPClient +@testable import AsyncHTTPClient import NIO import NIOFoundationCompat import NIOHTTP1 @@ -77,7 +77,7 @@ class HTTPClientTests: XCTestCase { let httpBin = HTTPBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -91,7 +91,7 @@ class HTTPClientTests: XCTestCase { let external = MultiThreadedEventLoopGroup(numberOfThreads: 1) let httpClient = HTTPClient(eventLoopGroupProvider: .shared(loopGroup)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try loopGroup.syncShutdownGracefully()) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -105,7 +105,7 @@ class HTTPClientTests: XCTestCase { let httpBin = HTTPBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -122,7 +122,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -135,7 +135,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -148,7 +148,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -169,7 +169,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) XCTAssertNoThrow(try httpsBin.shutdown()) } @@ -187,7 +187,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -209,7 +209,7 @@ class HTTPClientTests: XCTestCase { let httpBin = HTTPBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -220,7 +220,7 @@ class HTTPClientTests: XCTestCase { func testMultipleContentLengthHeaders() throws { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } let httpBin = HTTPBin() defer { @@ -241,7 +241,7 @@ class HTTPClientTests: XCTestCase { let httpBin = HTTPBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -259,7 +259,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -275,7 +275,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(read: .milliseconds(150)))) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -291,7 +291,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -307,7 +307,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -326,6 +326,36 @@ class HTTPClientTests: XCTestCase { } } + func testStressCancel() throws { + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let request = try Request(url: "http://localhost:\(httpBin.port)/wait", method: .GET) + let tasks = (1...100).map { _ -> HTTPClient.Task in + let task = httpClient.execute(request: request, delegate: TestHTTPDelegate()) + task.cancel() + return task + } + + for task in tasks { + switch (Result { try task.futureResult.timeout(after: .seconds(10)).wait() }) { + case .success: + XCTFail("Shouldn't succeed") + return + case .failure(let error): + guard let clientError = error as? HTTPClientError, clientError == .cancelled else { + XCTFail("Unexpected error: \(error)") + return + } + } + } + } + func testHTTPClientAuthorization() { var authorization = HTTPClient.Authorization.basic(username: "aladdin", password: "opensesame") XCTAssertEqual(authorization.headerValue, "Basic YWxhZGRpbjpvcGVuc2VzYW1l") @@ -341,7 +371,7 @@ class HTTPClientTests: XCTestCase { configuration: .init(proxy: .server(host: "localhost", port: httpBin.port)) ) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } let res = try httpClient.get(url: "http://test/ok").wait() @@ -358,7 +388,7 @@ class HTTPClientTests: XCTestCase { ) ) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } let res = try httpClient.get(url: "https://test/ok").wait() @@ -372,7 +402,7 @@ class HTTPClientTests: XCTestCase { configuration: .init(proxy: .server(host: "localhost", port: httpBin.port, authorization: .basic(username: "aladdin", password: "opensesame"))) ) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } let res = try httpClient.get(url: "http://test/ok").wait() @@ -386,7 +416,7 @@ class HTTPClientTests: XCTestCase { configuration: .init(proxy: .server(host: "localhost", port: httpBin.port, authorization: .basic(username: "aladdin", password: "opensesamefoo"))) ) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } XCTAssertThrowsError(try httpClient.get(url: "http://test/ok").wait(), "Should fail") { error in @@ -400,7 +430,7 @@ class HTTPClientTests: XCTestCase { let httpBin = HTTPBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -426,7 +456,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) httpBin.shutdown() } @@ -443,7 +473,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none, ignoreUncleanSSLShutdown: true)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) httpBin.shutdown() } @@ -461,7 +491,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) httpBin.shutdown() } @@ -479,7 +509,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) httpBin.shutdown() } @@ -495,7 +525,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) httpBin.shutdown() } @@ -512,7 +542,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none, ignoreUncleanSSLShutdown: true)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) httpBin.shutdown() } @@ -529,7 +559,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) httpBin.shutdown() } @@ -546,7 +576,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none, ignoreUncleanSSLShutdown: true)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) httpBin.shutdown() } @@ -563,7 +593,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup), configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true))) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -600,12 +630,36 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(true, response) } + func testResponseFutureIsOnCorrectEL() throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 4) + let client = HTTPClient(eventLoopGroupProvider: .shared(group)) + let httpBin = HTTPBin() + defer { + XCTAssertNoThrow(try client.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get") + var futures = [EventLoopFuture]() + for _ in 1...100 { + let el = group.next() + let req1 = client.execute(request: request, eventLoop: .delegate(on: el)) + let req2 = client.execute(request: request, eventLoop: .delegateAndChannel(on: el)) + let req3 = client.execute(request: request, eventLoop: .init(.testOnly_exact(channelOn: el, delegateOn: el))) + XCTAssert(req1.eventLoop === el) + XCTAssert(req2.eventLoop === el) + XCTAssert(req3.eventLoop === el) + futures.append(contentsOf: [req1, req2, req3]) + } + try EventLoopFuture.andAllComplete(futures, on: group.next()).wait() + } + func testDecompression() throws { let httpBin = HTTPBin(compress: true) let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .none))) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -641,7 +695,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .ratio(10)))) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -670,7 +724,7 @@ class HTTPClientTests: XCTestCase { configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: false))) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -682,10 +736,10 @@ class HTTPClientTests: XCTestCase { func testCountRedirectLimit() throws { let httpBin = HTTPBin(ssl: true) let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: true))) + configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 1000, allowCycles: true))) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) } @@ -695,7 +749,7 @@ class HTTPClientTests: XCTestCase { } func testMultipleConcurrentRequests() throws { - let numberOfRequestsPerThread = 100 + let numberOfRequestsPerThread = 1000 let numberOfParallelWorkers = 5 final class HTTPServer: ChannelInboundHandler { @@ -736,7 +790,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } let g = DispatchGroup() @@ -751,7 +805,13 @@ class HTTPClientTests: XCTestCase { } } } - g.wait() + let timeout = DispatchTime.now() + .seconds(180) + switch g.wait(timeout: timeout) { + case .success: + break + case .timedOut: + XCTFail("Timed out") + } } func testWorksWith500Error() { @@ -762,7 +822,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.group)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } let result = httpClient.get(url: "http://localhost:\(web.serverPort)/foo") @@ -770,9 +830,6 @@ class HTTPClientTests: XCTestCase { method: .GET, uri: "/foo", headers: HTTPHeaders([("Host", "localhost"), - // The following line can be removed once we - // have a connection pool. - ("Connection", "close"), ("Content-Length", "0")]))), try web.readInbound())) XCTAssertNoThrow(XCTAssertEqual(.end(nil), @@ -795,7 +852,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.group)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } let result = httpClient.get(url: "http://localhost:\(web.serverPort)/foo") @@ -803,9 +860,6 @@ class HTTPClientTests: XCTestCase { method: .GET, uri: "/foo", headers: HTTPHeaders([("Host", "localhost"), - // The following line can be removed once we - // have a connection pool. - ("Connection", "close"), ("Content-Length", "0")]))), try web.readInbound())) XCTAssertNoThrow(XCTAssertEqual(.end(nil), @@ -825,7 +879,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.group)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } let result = httpClient.get(url: "http://localhost:\(web.serverPort)/foo") @@ -833,9 +887,6 @@ class HTTPClientTests: XCTestCase { method: .GET, uri: "/foo", headers: HTTPHeaders([("Host", "localhost"), - // The following line can be removed once we - // have a connection pool. - ("Connection", "close"), ("Content-Length", "0")]))), try web.readInbound())) XCTAssertNoThrow(XCTAssertEqual(.end(nil), @@ -855,7 +906,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.group)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } for _ in 0..<10 { @@ -865,9 +916,6 @@ class HTTPClientTests: XCTestCase { method: .GET, uri: "/foo", headers: HTTPHeaders([("Host", "localhost"), - // The following line can be removed once - // we have a connection pool. - ("Connection", "close"), ("Content-Length", "0")]))), try web.readInbound())) XCTAssertNoThrow(XCTAssertEqual(.end(nil), @@ -892,7 +940,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.group)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } for i in 0..<10 { @@ -902,9 +950,6 @@ class HTTPClientTests: XCTestCase { method: .GET, uri: "/foo", headers: HTTPHeaders([("Host", "localhost"), - // The following line can be removed once - // we have a connection pool. - ("Connection", "close"), ("Content-Length", "0")]))), try web.readInbound())) XCTAssertNoThrow(XCTAssertEqual(.end(nil), @@ -922,6 +967,142 @@ class HTTPClientTests: XCTestCase { } } + func testStressGetHttps() throws { + let httpBin = HTTPBin(ssl: true) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let eventLoop = httpClient.eventLoopGroup.next() + let requestCount = 200 + var futureResults = [EventLoopFuture]() + for _ in 1...requestCount { + let req = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)/get", method: .GET, headers: ["X-internal-delay": "100"]) + futureResults.append(httpClient.execute(request: req)) + } + XCTAssertNoThrow(try EventLoopFuture.andAllSucceed(futureResults, on: eventLoop).wait()) + } + + func testStressGetHttpsSSLError() throws { + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let request = try Request(url: "https://localhost:\(httpBin.port)/wait", method: .GET) + let tasks = (1...100).map { _ -> HTTPClient.Task in + httpClient.execute(request: request, delegate: TestHTTPDelegate()) + } + + let results = try EventLoopFuture.whenAllComplete(tasks.map { $0.futureResult }, on: httpClient.eventLoopGroup.next()).wait() + + for result in results { + switch result { + case .success: + XCTFail("Shouldn't succeed") + continue + case .failure(let error): + guard let clientError = error as? NIOSSLError, case NIOSSLError.handshakeFailed = clientError else { + XCTFail("Unexpected error: \(error)") + continue + } + } + } + } + + func testUncleanCloseThrows() { + let httpBin = HTTPBin() + defer { + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + _ = httpClient.get(url: "http://localhost:\(httpBin.port)/wait") + do { + try httpClient.syncShutdown(requiresCleanClose: true) + XCTFail("There should be an error on shutdown") + } catch { + guard let clientError = error as? HTTPClientError, clientError == .uncleanShutdown else { + XCTFail("Unexpected shutdown error: \(error)") + return + } + } + } + + func testFailingConnectionIsReleased() { + let httpBin = HTTPBin(refusesConnections: true) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + do { + _ = try httpClient.get(url: "http://localhost:\(httpBin.port)/get").timeout(after: .seconds(5)).wait() + XCTFail("Shouldn't succeed") + } catch { + guard !(error is EventLoopFutureTimeoutError) else { + XCTFail("Timed out but should have failed immediately") + return + } + } + } + + func testResponseDelayGet() throws { + let httpBin = HTTPBin(ssl: false) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["X-internal-delay": "2000"], body: nil) + let start = Date() + let response = try! httpClient.execute(request: req).wait() + XCTAssertEqual(Date().timeIntervalSince(start), 2, accuracy: 0.25) + XCTAssertEqual(response.status, .ok) + } + + func testIdleTimeoutNoReuse() throws { + let httpBin = HTTPBin(ssl: false) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + var req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET) + XCTAssertNoThrow(try httpClient.execute(request: req, deadline: .now() + .seconds(2)).wait()) + req.headers.add(name: "X-internal-delay", value: "2500") + try httpClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(250)) {}.futureResult.wait() + XCTAssertNoThrow(try httpClient.execute(request: req).timeout(after: .seconds(10)).wait()) + } + + func testStressGetClose() throws { + let httpBin = HTTPBin(ssl: false) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let eventLoop = httpClient.eventLoopGroup.next() + let requestCount = 200 + var futureResults = [EventLoopFuture]() + for _ in 1...requestCount { + let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["X-internal-delay": "5", "Connection": "close"]) + futureResults.append(httpClient.execute(request: req)) + } + XCTAssertNoThrow(try EventLoopFuture.andAllComplete(futureResults, on: eventLoop).timeout(after: .seconds(10)).wait()) + } + func testManyConcurrentRequestsWork() { let numberOfWorkers = 20 let numberOfRequestsPerWorkers = 20 @@ -976,7 +1157,7 @@ class HTTPClientTests: XCTestCase { let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.group)) defer { - XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } for _ in 0..<10 { @@ -985,9 +1166,6 @@ class HTTPClientTests: XCTestCase { method: .GET, uri: "/foo", headers: HTTPHeaders([("Host", "localhost"), - // The following line can be removed once - // we have a connection pool. - ("Connection", "close"), ("Content-Length", "0")]))), try web.readInbound())) XCTAssertNoThrow(XCTAssertEqual(.end(nil), @@ -1004,6 +1182,301 @@ class HTTPClientTests: XCTestCase { } } + func testShutdownBeforeTasksCompletion() throws { + let httpBin = HTTPBin() + let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let client = HTTPClient(eventLoopGroupProvider: .shared(elg)) + let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["X-internal-delay": "500"]) + let res = client.execute(request: req) + XCTAssertNoThrow(try client.syncShutdown(requiresCleanClose: false)) + _ = try? res.timeout(after: .seconds(2)).wait() + try httpBin.shutdown() + try elg.syncShutdownGracefully() + } + + /// This test would cause an assertion failure on `HTTPClient` deinit if client doesn't actually shutdown + func testUncleanShutdownActuallyShutsDown() throws { + let httpBin = HTTPBin() + let client = HTTPClient(eventLoopGroupProvider: .createNew) + let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["X-internal-delay": "200"]) + _ = client.execute(request: req) + try? client.syncShutdown(requiresCleanClose: true) + try httpBin.shutdown() + } + + func testUncleanShutdownCancelsTasks() throws { + let httpBin = HTTPBin() + let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let client = HTTPClient(eventLoopGroupProvider: .shared(elg)) + + defer { + XCTAssertNoThrow(try httpBin.shutdown()) + XCTAssertNoThrow(try elg.syncShutdownGracefully()) + } + + let responses = (1...100).map { _ in + client.get(url: "http://localhost:\(httpBin.port)/wait") + } + + try client.syncShutdown(requiresCleanClose: false) + + let results = try EventLoopFuture.whenAllComplete(responses, on: elg.next()).timeout(after: .seconds(100)).wait() + + for result in results { + switch result { + case .success: + XCTFail("Shouldn't succeed") + case .failure(let error): + if let clientError = error as? HTTPClientError, clientError == .cancelled { + continue + } else { + XCTFail("Unexpected error: \(error)") + } + } + } + } + + func testDoubleShutdown() { + let client = HTTPClient(eventLoopGroupProvider: .createNew) + XCTAssertNoThrow(try client.syncShutdown()) + do { + try client.syncShutdown() + XCTFail("Shutdown should fail with \(HTTPClientError.alreadyShutdown)") + } catch { + guard let clientError = error as? HTTPClientError, clientError == .alreadyShutdown else { + XCTFail("Unexpected error: \(error) instead of \(HTTPClientError.alreadyShutdown)") + return + } + } + } + + func testTaskFailsWhenClientIsShutdown() { + let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try elg.syncShutdownGracefully()) + } + let client = HTTPClient(eventLoopGroupProvider: .shared(elg)) + XCTAssertNoThrow(try client.syncShutdown(requiresCleanClose: true)) + do { + _ = try client.get(url: "http://localhost/").wait() + XCTFail("Request shouldn't succeed") + } catch { + if let error = error as? HTTPClientError, error == .alreadyShutdown { + return + } else { + XCTFail("Unexpected error: \(error)") + } + } + } + + func testRaceNewRequestsVsShutdown() { + let numberOfWorkers = 20 + let allWorkersReady = DispatchSemaphore(value: 0) + let allWorkersGo = DispatchSemaphore(value: 0) + let allDone = DispatchGroup() + + let httpBin = HTTPBin() + defer { + XCTAssertNoThrow(try httpBin.shutdown()) + } + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + defer { + XCTAssertThrowsError(try httpClient.syncShutdown()) { error in + XCTAssertEqual(.alreadyShutdown, error as? HTTPClientError) + } + } + + let url = "http://localhost:\(httpBin.port)/get" + XCTAssertNoThrow(XCTAssertEqual(.ok, try httpClient.get(url: url).wait().status)) + + for w in 0..]() + for i in 1...100 { + let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["X-internal-delay": "10"]) + let preference: HTTPClient.EventLoopPreference + if i <= 50 { + preference = .delegateAndChannel(on: first) + } else { + preference = .delegateAndChannel(on: second) + } + futureResults.append(client.execute(request: request, eventLoop: preference)) + } + + let results = try EventLoopFuture.whenAllComplete(futureResults, on: elg.next()).wait() + + for result in results { + switch result { + case .success: + break + case .failure(let error): + XCTFail("Unexpected error: \(error)") + } + } + } + + func testMakeSecondRequestDuringCancelledCallout() { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) // needs to be 1 thread! + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + let el = group.next() + + let web = NIOHTTP1TestServer(group: el) + defer { + // This will throw as we've started the request but haven't fulfilled it. + XCTAssertThrowsError(try web.stop()) + } + + let url = "http://127.0.0.1:\(web.serverPort)" + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(el)) + defer { + XCTAssertThrowsError(try httpClient.syncShutdown(requiresCleanClose: true)) { error in + XCTAssertEqual(.alreadyShutdown, error as? HTTPClientError) + } + } + + let seenError = DispatchGroup() + seenError.enter() + var maybeSecondRequest: EventLoopFuture? + XCTAssertNoThrow(maybeSecondRequest = try el.submit { + let neverSucceedingRequest = httpClient.get(url: url) + let secondRequest = neverSucceedingRequest.flatMapError { error in + XCTAssertEqual(.cancelled, error as? HTTPClientError) + seenError.leave() + return httpClient.get(url: url) // <== this is the main part, during the error callout, we call back in + } + return secondRequest + }.wait()) + + guard let secondRequest = maybeSecondRequest else { + XCTFail("couldn't get request future") + return + } + + // Let's pull out the request .head so we know the request has started (but nothing else) + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) + + XCTAssertNoThrow(try httpClient.syncShutdown()) + + seenError.wait() + XCTAssertThrowsError(try secondRequest.wait()) { error in + XCTAssertEqual(.alreadyShutdown, error as? HTTPClientError) + } + } + + func testMakeSecondRequestDuringSuccessCallout() { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) // needs to be 1 thread! + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + let el = group.next() + + let web = HTTPBin() + defer { + XCTAssertNoThrow(try web.shutdown()) + } + + let url = "http://127.0.0.1:\(web.port)/get" + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(el)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + } + + XCTAssertNoThrow(XCTAssertEqual(.ok, + try el.flatSubmit { () -> EventLoopFuture in + httpClient.get(url: url).flatMap { firstResponse in + XCTAssertEqual(.ok, firstResponse.status) + return httpClient.get(url: url) // <== interesting bit here + } + }.wait().status)) + } + + func testMakeSecondRequestWhilstFirstIsOngoing() { + let web = NIOHTTP1TestServer(group: self.group) + defer { + XCTAssertNoThrow(try web.stop()) + } + + let client = HTTPClient(eventLoopGroupProvider: .shared(self.group)) + defer { + XCTAssertNoThrow(try client.syncShutdown()) + } + + let url = "http://127.0.0.1:\(web.serverPort)" + let firstRequest = client.get(url: url) + + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head + + // Now, the first request is ongoing but not complete, let's start a second one + let secondRequest = client.get(url: url) + XCTAssertNoThrow(XCTAssertEqual(.end(nil), try web.readInbound())) // first request: .end + + XCTAssertNoThrow(try web.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) + XCTAssertNoThrow(try web.writeOutbound(.end(nil))) + + XCTAssertNoThrow(XCTAssertEqual(.ok, try firstRequest.wait().status)) + + // Okay, first request done successfully, let's do the second one too. + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head + XCTAssertNoThrow(XCTAssertEqual(.end(nil), try web.readInbound())) // first request: .end + + XCTAssertNoThrow(try web.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .created)))) + XCTAssertNoThrow(try web.writeOutbound(.end(nil))) + XCTAssertNoThrow(XCTAssertEqual(.created, try secondRequest.wait().status)) + } + func testUDSBasic() { // This tests just connecting to a URL where the whole URL is the UNIX domain socket path like // unix:///this/is/my/socket.sock @@ -1048,4 +1521,29 @@ class HTTPClientTests: XCTestCase { try httpClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"])) }) } + + func testUseExistingConnectionOnDifferentEL() throws { + let threadCount = 16 + let elg = MultiThreadedEventLoopGroup(numberOfThreads: threadCount) + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(elg)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + XCTAssertNoThrow(try elg.syncShutdownGracefully()) + } + + let eventLoops = (1...threadCount).map { _ in elg.next() } + let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get") + let closingRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", headers: ["Connection": "close"]) + + for (index, el) in eventLoops.enumerated() { + if index.isMultiple(of: 2) { + XCTAssertNoThrow(try httpClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait()) + } else { + XCTAssertNoThrow(try httpClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait()) + XCTAssertNoThrow(try httpClient.execute(request: closingRequest, eventLoop: .indifferent).wait()) + } + } + } }