diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index 63be3aa37..4aecf6d17 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -18,6 +18,7 @@ import NIO import NIOConcurrencyHelpers import NIOHTTP1 import NIOHTTPCompression +import NIOSSL import NIOTLS import NIOTransportServices @@ -41,6 +42,8 @@ final class ConnectionPool { private let backgroundActivityLogger: Logger + let sslContextCache = SSLContextCache() + init(configuration: HTTPClient.Configuration, backgroundActivityLogger: Logger) { self.configuration = configuration self.backgroundActivityLogger = backgroundActivityLogger @@ -106,6 +109,8 @@ final class ConnectionPool { self.providers.values } + self.sslContextCache.shutdown() + return EventLoopFuture.reduce(true, providers.map { $0.close() }, on: eventLoop) { $0 && $1 } } @@ -148,7 +153,7 @@ final class ConnectionPool { var host: String var port: Int var unixPath: String - var tlsConfiguration: BestEffortHashableTLSConfiguration? + private var tlsConfiguration: BestEffortHashableTLSConfiguration? enum Scheme: Hashable { case http @@ -249,14 +254,15 @@ class HTTP1ConnectionProvider { } else { logger.trace("opening fresh connection (found matching but inactive connection)", metadata: ["ahc-dead-connection": "\(connection)"]) - self.makeChannel(preference: waiter.preference).whenComplete { result in + self.makeChannel(preference: waiter.preference, + logger: logger).whenComplete { result in self.connect(result, waiter: waiter, logger: logger) } } } case .create(let waiter): logger.trace("opening fresh connection (no connections to reuse available)") - self.makeChannel(preference: waiter.preference).whenComplete { result in + self.makeChannel(preference: waiter.preference, logger: logger).whenComplete { result in self.connect(result, waiter: waiter, logger: logger) } case .replace(let connection, let waiter): @@ -266,7 +272,7 @@ class HTTP1ConnectionProvider { logger.trace("opening fresh connection (replacing exising connection)", metadata: ["ahc-old-connection": "\(connection)", "ahc-waiter": "\(waiter)"]) - self.makeChannel(preference: waiter.preference).whenComplete { result in + self.makeChannel(preference: waiter.preference, logger: logger).whenComplete { result in self.connect(result, waiter: waiter, logger: logger) } } @@ -434,8 +440,14 @@ class HTTP1ConnectionProvider { return self.closePromise.futureResult.map { true } } - private func makeChannel(preference: HTTPClient.EventLoopPreference) -> EventLoopFuture { - return NIOClientTCPBootstrap.makeHTTP1Channel(destination: self.key, eventLoop: self.eventLoop, configuration: self.configuration, preference: preference) + private func makeChannel(preference: HTTPClient.EventLoopPreference, + logger: Logger) -> EventLoopFuture { + return NIOClientTCPBootstrap.makeHTTP1Channel(destination: self.key, + eventLoop: self.eventLoop, + configuration: self.configuration, + sslContextCache: self.pool.sslContextCache, + preference: preference, + logger: logger) } /// A `Waiter` represents a request that waits for a connection when none is diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 40b07f67b..ec549d993 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -900,7 +900,9 @@ extension ChannelPipeline { try sync.addHandler(handler) } - func syncAddLateSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, handshakePromise: EventLoopPromise) { + func syncAddLateSSLHandlerIfNeeded(for key: ConnectionPool.Key, + sslContext: NIOSSLContext, + handshakePromise: EventLoopPromise) { precondition(key.scheme.requiresTLS) do { @@ -913,10 +915,9 @@ extension ChannelPipeline { try synchronousPipelineView.addHandler(eventsHandler, name: TLSEventsHandler.handlerName) // Then we add the SSL handler. - let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient() - let context = try NIOSSLContext(configuration: tlsConfiguration) try synchronousPipelineView.addHandler( - try NIOSSLClientHandler(context: context, serverHostname: (key.host.isIPAddress || key.host.isEmpty) ? nil : key.host), + try NIOSSLClientHandler(context: sslContext, + serverHostname: (key.host.isIPAddress || key.host.isEmpty) ? nil : key.host), position: .before(eventsHandler) ) } catch { diff --git a/Sources/AsyncHTTPClient/LRUCache.swift b/Sources/AsyncHTTPClient/LRUCache.swift new file mode 100644 index 000000000..0a01da0d2 --- /dev/null +++ b/Sources/AsyncHTTPClient/LRUCache.swift @@ -0,0 +1,104 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +struct LRUCache { + private typealias Generation = UInt64 + private struct Element { + var generation: Generation + var key: Key + var value: Value + } + + private let capacity: Int + private var generation: Generation = 0 + private var elements: [Element] + + init(capacity: Int = 8) { + precondition(capacity > 0, "capacity needs to be > 0") + self.capacity = capacity + self.elements = [] + self.elements.reserveCapacity(capacity) + } + + private mutating func bumpGenerationAndFindIndex(key: Key) -> Int? { + self.generation += 1 + + let found = self.elements.firstIndex { element in + element.key == key + } + + return found + } + + mutating func find(key: Key) -> Value? { + if let found = self.bumpGenerationAndFindIndex(key: key) { + self.elements[found].generation = self.generation + return self.elements[found].value + } else { + return nil + } + } + + @discardableResult + mutating func append(key: Key, value: Value) -> Value { + let newElement = Element(generation: self.generation, + key: key, + value: value) + if let found = self.bumpGenerationAndFindIndex(key: key) { + self.elements[found] = newElement + return value + } + + if self.elements.count < self.capacity { + self.elements.append(newElement) + return value + } + assert(self.elements.count == self.capacity) + assert(self.elements.count > 0) + + let minIndex = self.elements.minIndex { l, r in + l.generation < r.generation + }! + + self.elements.swapAt(minIndex, self.elements.endIndex - 1) + self.elements.removeLast() + self.elements.append(newElement) + + return value + } + + mutating func findOrAppend(key: Key, _ valueGenerator: (Key) -> Value) -> Value { + if let found = self.find(key: key) { + return found + } + + return self.append(key: key, value: valueGenerator(key)) + } +} + +extension Array { + func minIndex(by areInIncreasingOrder: (Element, Element) throws -> Bool) rethrows -> Index? { + guard var minSoFar: (Index, Element) = self.first.map({ (0, $0) }) else { + return nil + } + + for indexElement in self.enumerated() { + if try areInIncreasingOrder(indexElement.1, minSoFar.1) { + minSoFar = indexElement + } + } + + return minSoFar.0 + } +} diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift b/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift index 1f9dceb88..16e9d4717 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift @@ -13,13 +13,14 @@ //===----------------------------------------------------------------------===// #if canImport(Network) - import Network - import NIO - import NIOHTTP1 - import NIOTransportServices +#endif +import NIO +import NIOHTTP1 +import NIOTransportServices - extension HTTPClient { +extension HTTPClient { + #if canImport(Network) public struct NWPOSIXError: Error, CustomStringConvertible { /// POSIX error code (enum) public let errorCode: POSIXErrorCode @@ -57,28 +58,35 @@ public var description: String { return self.reason } } + #endif - @available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) - class NWErrorHandler: ChannelInboundHandler { - typealias InboundIn = HTTPClientResponsePart + class NWErrorHandler: ChannelInboundHandler { + typealias InboundIn = HTTPClientResponsePart - func errorCaught(context: ChannelHandlerContext, error: Error) { - context.fireErrorCaught(NWErrorHandler.translateError(error)) - } + func errorCaught(context: ChannelHandlerContext, error: Error) { + context.fireErrorCaught(NWErrorHandler.translateError(error)) + } - static func translateError(_ error: Error) -> Error { - if let error = error as? NWError { - switch error { - case .tls(let status): - return NWTLSError(status, reason: error.localizedDescription) - case .posix(let errorCode): - return NWPOSIXError(errorCode, reason: error.localizedDescription) - default: - return error + static func translateError(_ error: Error) -> Error { + #if canImport(Network) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + if let error = error as? NWError { + switch error { + case .tls(let status): + return NWTLSError(status, reason: error.localizedDescription) + case .posix(let errorCode): + return NWPOSIXError(errorCode, reason: error.localizedDescription) + default: + return error + } } + return error + } else { + preconditionFailure("\(self) used on a non-NIOTS Channel") } - return error - } + #else + preconditionFailure("\(self) used on a non-NIOTS Channel") + #endif } } -#endif +} diff --git a/Sources/AsyncHTTPClient/SSLContextCache.swift b/Sources/AsyncHTTPClient/SSLContextCache.swift new file mode 100644 index 000000000..582d2cee8 --- /dev/null +++ b/Sources/AsyncHTTPClient/SSLContextCache.swift @@ -0,0 +1,104 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO +import NIOConcurrencyHelpers +import NIOSSL + +class SSLContextCache { + private var state = State.activeNoThread + private let lock = Lock() + private var sslContextCache = LRUCache() + private let threadPool = NIOThreadPool(numberOfThreads: 1) + + enum State { + case activeNoThread + case active + case shutDown + } + + init() {} + + func shutdown() { + self.lock.withLock { () -> Void in + switch self.state { + case .activeNoThread: + self.state = .shutDown + case .active: + self.state = .shutDown + self.threadPool.shutdownGracefully { maybeError in + precondition(maybeError == nil, "\(maybeError!)") + } + case .shutDown: + preconditionFailure("SSLContextCache shut down twice") + } + } + } + + deinit { + assert(self.state == .shutDown) + } +} + +extension SSLContextCache { + private struct SSLContextCacheShutdownError: Error {} + + func sslContext(tlsConfiguration: TLSConfiguration, + eventLoop: EventLoop, + logger: Logger) -> EventLoopFuture { + let earlyExitError: Error? = self.lock.withLock { () -> Error? in + switch self.state { + case .activeNoThread: + self.state = .active + self.threadPool.start() + return nil + case .active: + return nil + case .shutDown: + return SSLContextCacheShutdownError() + } + } + + if let error = earlyExitError { + return eventLoop.makeFailedFuture(error) + } + + let eqTLSConfiguration = BestEffortHashableTLSConfiguration(wrapping: tlsConfiguration) + let sslContext = self.lock.withLock { + self.sslContextCache.find(key: eqTLSConfiguration) + } + + if let sslContext = sslContext { + logger.debug("found SSL context in cache", + metadata: ["ahc-tls-config": "\(tlsConfiguration)"]) + return eventLoop.makeSucceededFuture(sslContext) + } + + logger.debug("creating new SSL context", + metadata: ["ahc-tls-config": "\(tlsConfiguration)"]) + let newSSLContext = self.threadPool.runIfActive(eventLoop: eventLoop) { + try NIOSSLContext(configuration: tlsConfiguration) + } + + newSSLContext.whenSuccess { (newSSLContext: NIOSSLContext) -> Void in + self.lock.withLock { () -> Void in + self.sslContextCache.append(key: eqTLSConfiguration, + value: newSSLContext) + } + } + + return newSSLContext + } +} diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index 24ec338ab..1af7899b2 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -16,6 +16,7 @@ import Foundation #if canImport(Network) import Network #endif +import Logging import NIO import NIOHTTP1 import NIOHTTPCompression @@ -52,174 +53,207 @@ public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { } } -extension ClientBootstrap { - fileprivate func makeClientTCPBootstrap( - host: String, - requiresTLS: Bool, - configuration: HTTPClient.Configuration - ) throws -> NIOClientTCPBootstrap { - // if there is a proxy don't create TLS provider as it will be added at a later point - if configuration.proxy != nil { - return NIOClientTCPBootstrap(self, tls: NIOInsecureNoTLS()) +extension NIOClientTCPBootstrap { + static func makeHTTP1Channel(destination: ConnectionPool.Key, + eventLoop: EventLoop, + configuration: HTTPClient.Configuration, + sslContextCache: SSLContextCache, + preference: HTTPClient.EventLoopPreference, + logger: Logger) -> EventLoopFuture { + let channelEventLoop = preference.bestEventLoop ?? eventLoop + + let key = destination + let requiresTLS = key.scheme.requiresTLS + let sslContext: EventLoopFuture + if key.scheme.requiresTLS, configuration.proxy != nil { + // If we use a proxy & also require TLS, then we always use NIOSSL (and not Network.framework TLS because + // it can't be added later) and therefore require a `NIOSSLContext`. + // In this case, `makeAndConfigureBootstrap` will not create another `NIOSSLContext`. + // + // Note that TLS proxies are not supported at the moment. This means that we will always speak + // plaintext to the proxy but we do support sending HTTPS traffic through the proxy. + sslContext = sslContextCache.sslContext(tlsConfiguration: configuration.tlsConfiguration ?? .forClient(), + eventLoop: eventLoop, + logger: logger).map { $0 } } else { - let tlsConfiguration = configuration.tlsConfiguration ?? TLSConfiguration.forClient() - let sslContext = try NIOSSLContext(configuration: tlsConfiguration) - let hostname = (!requiresTLS || host.isIPAddress || host.isEmpty) ? nil : host - let tlsProvider = try NIOSSLClientTLSProvider(context: sslContext, serverHostname: hostname) - return NIOClientTCPBootstrap(self, tls: tlsProvider) + sslContext = eventLoop.makeSucceededFuture(nil) } - } -} -extension NIOClientTCPBootstrap { - /// create a TCP Bootstrap based off what type of `EventLoop` has been passed to the function. - fileprivate static func makeBootstrap( - on eventLoop: EventLoop, - host: String, - requiresTLS: Bool, - configuration: HTTPClient.Configuration - ) throws -> NIOClientTCPBootstrap { - var bootstrap: NIOClientTCPBootstrap - #if canImport(Network) - // if eventLoop is compatible with NIOTransportServices create a NIOTSConnectionBootstrap - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { - // if there is a proxy don't create TLS provider as it will be added at a later point - if configuration.proxy != nil { - bootstrap = NIOClientTCPBootstrap(tsBootstrap, tls: NIOInsecureNoTLS()) + let bootstrap = NIOClientTCPBootstrap.makeAndConfigureBootstrap(on: channelEventLoop, + host: key.host, + port: key.port, + requiresTLS: requiresTLS, + configuration: configuration, + sslContextCache: sslContextCache, + logger: logger) + return bootstrap.flatMap { bootstrap -> EventLoopFuture in + let channel: EventLoopFuture + switch key.scheme { + case .http, .https: + let address = HTTPClient.resolveAddress(host: key.host, port: key.port, proxy: configuration.proxy) + channel = bootstrap.connect(host: address.host, port: address.port) + case .unix, .http_unix, .https_unix: + channel = bootstrap.connect(unixDomainSocketPath: key.unixPath) + } + + return channel.flatMap { channel -> EventLoopFuture<(Channel, NIOSSLContext?)> in + sslContext.map { sslContext -> (Channel, NIOSSLContext?) in + (channel, sslContext) + } + }.flatMap { channel, sslContext in + configureChannelPipeline(channel, + isNIOTS: bootstrap.isNIOTS, + sslContext: sslContext, + configuration: configuration, + key: key) + }.flatMapErrorThrowing { error in + if bootstrap.isNIOTS { + throw HTTPClient.NWErrorHandler.translateError(error) } else { - // create NIOClientTCPBootstrap with NIOTS TLS provider - let tlsConfiguration = configuration.tlsConfiguration ?? TLSConfiguration.forClient() - let parameters = tlsConfiguration.getNWProtocolTLSOptions() - let tlsProvider = NIOTSClientTLSProvider(tlsOptions: parameters) - bootstrap = NIOClientTCPBootstrap(tsBootstrap, tls: tlsProvider) + throw error } - } else if let clientBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - bootstrap = try clientBootstrap.makeClientTCPBootstrap(host: host, requiresTLS: requiresTLS, configuration: configuration) - } else { - preconditionFailure("Cannot create bootstrap for the supplied EventLoop") } - #else - if let clientBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - bootstrap = try clientBootstrap.makeClientTCPBootstrap(host: host, requiresTLS: requiresTLS, configuration: configuration) - } else { - preconditionFailure("Cannot create bootstrap for the supplied EventLoop") - } - #endif - - if let timeout = configuration.timeout.connect { - bootstrap = bootstrap.connectTimeout(timeout) - } - - // don't enable TLS if we have a proxy, this will be enabled later on - if requiresTLS, configuration.proxy == nil { - return bootstrap.enableTLS() } - - return bootstrap } - static func makeHTTPClientBootstrapBase( + /// Creates and configures a bootstrap given the `eventLoop`, if TLS/a proxy is being used. + private static func makeAndConfigureBootstrap( on eventLoop: EventLoop, host: String, port: Int, requiresTLS: Bool, - configuration: HTTPClient.Configuration - ) throws -> NIOClientTCPBootstrap { - return try self.makeBootstrap(on: eventLoop, host: host, requiresTLS: requiresTLS, configuration: configuration) - .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) - .channelInitializer { channel in - do { - if let proxy = configuration.proxy { - try channel.pipeline.syncAddProxyHandler(host: host, port: port, authorization: proxy.authorization) - } else if requiresTLS { - // We only add the handshake verifier if we need TLS and we're not going through a proxy. If we're going - // through a proxy we add it later. - let completionPromise = channel.eventLoop.makePromise(of: Void.self) - try channel.pipeline.syncOperations.addHandler(TLSEventsHandler(completionPromise: completionPromise), name: TLSEventsHandler.handlerName) + configuration: HTTPClient.Configuration, + sslContextCache: SSLContextCache, + logger: Logger + ) -> EventLoopFuture { + return self.makeBestBootstrap(host: host, + eventLoop: eventLoop, + requiresTLS: requiresTLS, + sslContextCache: sslContextCache, + tlsConfiguration: configuration.tlsConfiguration ?? .forClient(), + useProxy: configuration.proxy != nil, + logger: logger) + .map { bootstrap -> NIOClientTCPBootstrap in + var bootstrap = bootstrap + + if let timeout = configuration.timeout.connect { + bootstrap = bootstrap.connectTimeout(timeout) + } + + // Don't enable TLS if we have a proxy, this will be enabled later on (outside of this method). + if requiresTLS, configuration.proxy == nil { + bootstrap = bootstrap.enableTLS() + } + + return bootstrap.channelInitializer { channel in + do { + if let proxy = configuration.proxy { + try channel.pipeline.syncAddProxyHandler(host: host, + port: port, + authorization: proxy.authorization) + } else if requiresTLS { + // We only add the handshake verifier if we need TLS and we're not going through a proxy. + // If we're going through a proxy we add it later (outside of this method). + let completionPromise = channel.eventLoop.makePromise(of: Void.self) + try channel.pipeline.syncOperations.addHandler(TLSEventsHandler(completionPromise: completionPromise), + name: TLSEventsHandler.handlerName) + } + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) } - return channel.eventLoop.makeSucceededVoidFuture() - } catch { - return channel.eventLoop.makeFailedFuture(error) } } } - static func makeHTTP1Channel(destination: ConnectionPool.Key, eventLoop: EventLoop, configuration: HTTPClient.Configuration, preference: HTTPClient.EventLoopPreference) -> EventLoopFuture { - let channelEventLoop = preference.bestEventLoop ?? eventLoop + /// Creates the best-suited bootstrap given an `EventLoop` and pairs it with an appropriate TLS provider. + private static func makeBestBootstrap( + host: String, + eventLoop: EventLoop, + requiresTLS: Bool, + sslContextCache: SSLContextCache, + tlsConfiguration: TLSConfiguration, + useProxy: Bool, + logger: Logger + ) -> EventLoopFuture { + #if canImport(Network) + // if eventLoop is compatible with NIOTransportServices create a NIOTSConnectionBootstrap + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { + // create NIOClientTCPBootstrap with NIOTS TLS provider + let parameters = tlsConfiguration.getNWProtocolTLSOptions() + let tlsProvider = NIOTSClientTLSProvider(tlsOptions: parameters) + return eventLoop.makeSucceededFuture(NIOClientTCPBootstrap(tsBootstrap, tls: tlsProvider)) + } + #endif - let key = destination + if let clientBootstrap = ClientBootstrap(validatingGroup: eventLoop) { + // If there is a proxy don't create TLS provider as it will be added at a later point. + if !requiresTLS || useProxy { + return eventLoop.makeSucceededFuture(NIOClientTCPBootstrap(clientBootstrap, + tls: NIOInsecureNoTLS())) + } else { + return sslContextCache.sslContext(tlsConfiguration: tlsConfiguration, + eventLoop: eventLoop, + logger: logger) + .flatMapThrowing { sslContext in + let hostname = (host.isIPAddress || host.isEmpty) ? nil : host + let tlsProvider = try NIOSSLClientTLSProvider(context: sslContext, serverHostname: hostname) + return NIOClientTCPBootstrap(clientBootstrap, tls: tlsProvider) + } + } + } - let requiresTLS = key.scheme.requiresTLS - let bootstrap: NIOClientTCPBootstrap + preconditionFailure("Cannot create bootstrap for event loop \(eventLoop)") + } +} + +private func configureChannelPipeline(_ channel: Channel, + isNIOTS: Bool, + sslContext: NIOSSLContext?, + configuration: HTTPClient.Configuration, + key: ConnectionPool.Key) -> EventLoopFuture { + let requiresTLS = key.scheme.requiresTLS + let handshakeFuture: EventLoopFuture + + if requiresTLS, configuration.proxy != nil { + let handshakePromise = channel.eventLoop.makePromise(of: Void.self) + channel.pipeline.syncAddLateSSLHandlerIfNeeded(for: key, + sslContext: sslContext!, + handshakePromise: handshakePromise) + handshakeFuture = handshakePromise.futureResult + } else if requiresTLS { do { - bootstrap = try NIOClientTCPBootstrap.makeHTTPClientBootstrapBase(on: channelEventLoop, host: key.host, port: key.port, requiresTLS: requiresTLS, configuration: configuration) + handshakeFuture = try channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self).completionPromise.futureResult } catch { - return channelEventLoop.makeFailedFuture(error) - } - - let channel: EventLoopFuture - switch key.scheme { - case .http, .https: - let address = HTTPClient.resolveAddress(host: key.host, port: key.port, proxy: configuration.proxy) - channel = bootstrap.connect(host: address.host, port: address.port) - case .unix, .http_unix, .https_unix: - channel = bootstrap.connect(unixDomainSocketPath: key.unixPath) + return channel.eventLoop.makeFailedFuture(error) } + } else { + handshakeFuture = channel.eventLoop.makeSucceededVoidFuture() + } - return channel.flatMap { channel in - let requiresTLS = key.scheme.requiresTLS - let requiresLateSSLHandler = configuration.proxy != nil && requiresTLS - let handshakeFuture: EventLoopFuture - - if requiresLateSSLHandler { - let handshakePromise = channel.eventLoop.makePromise(of: Void.self) - channel.pipeline.syncAddLateSSLHandlerIfNeeded(for: key, tlsConfiguration: configuration.tlsConfiguration, handshakePromise: handshakePromise) - handshakeFuture = handshakePromise.futureResult - } else if requiresTLS { - do { - handshakeFuture = try channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self).completionPromise.futureResult - } catch { - return channel.eventLoop.makeFailedFuture(error) - } - } else { - handshakeFuture = channel.eventLoop.makeSucceededVoidFuture() - } - - return handshakeFuture.flatMapThrowing { - let syncOperations = channel.pipeline.syncOperations + return handshakeFuture.flatMapThrowing { + let syncOperations = channel.pipeline.syncOperations - // If we got here and we had a TLSEventsHandler in the pipeline, we can remove it ow. - if requiresTLS { - channel.pipeline.removeHandler(name: TLSEventsHandler.handlerName, promise: nil) - } + // If we got here and we had a TLSEventsHandler in the pipeline, we can remove it ow. + if requiresTLS { + channel.pipeline.removeHandler(name: TLSEventsHandler.handlerName, promise: nil) + } - try syncOperations.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes) + try syncOperations.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes) - #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), bootstrap.underlyingBootstrap is NIOTSConnectionBootstrap { - try syncOperations.addHandler(HTTPClient.NWErrorHandler(), position: .first) - } - #endif - - switch configuration.decompression { - case .disabled: - () - case .enabled(let limit): - let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) - try syncOperations.addHandler(decompressHandler) - } + if isNIOTS { + try syncOperations.addHandler(HTTPClient.NWErrorHandler(), position: .first) + } - return channel - } - }.flatMapError { error in - #if canImport(Network) - var error = error - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), bootstrap.underlyingBootstrap is NIOTSConnectionBootstrap { - error = HTTPClient.NWErrorHandler.translateError(error) - } - #endif - return channelEventLoop.makeFailedFuture(error) + switch configuration.decompression { + case .disabled: + () + case .enabled(let limit): + let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) + try syncOperations.addHandler(decompressHandler) } + + return channel } } @@ -230,3 +264,17 @@ extension Connection { }.recover { _ in } } } + +extension NIOClientTCPBootstrap { + var isNIOTS: Bool { + #if canImport(Network) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + return self.underlyingBootstrap is NIOTSConnectionBootstrap + } else { + return false + } + #else + return false + #endif + } +} diff --git a/Tests/AsyncHTTPClientTests/ConnectionTests.swift b/Tests/AsyncHTTPClientTests/ConnectionTests.swift index e01cefc99..c1191124c 100644 --- a/Tests/AsyncHTTPClientTests/ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/ConnectionTests.swift @@ -19,6 +19,7 @@ import XCTest class ConnectionTests: XCTestCase { var eventLoop: EmbeddedEventLoop! var http1ConnectionProvider: HTTP1ConnectionProvider! + var pool: ConnectionPool! func buildState(connection: Connection, release: Bool) { XCTAssertTrue(self.http1ConnectionProvider.state.enqueue()) @@ -131,24 +132,30 @@ class ConnectionTests: XCTestCase { } override func setUp() { + XCTAssertNil(self.pool) XCTAssertNil(self.eventLoop) XCTAssertNil(self.http1ConnectionProvider) self.eventLoop = EmbeddedEventLoop() - XCTAssertNoThrow(self.http1ConnectionProvider = try HTTP1ConnectionProvider(key: .init(.init(url: "http://some.test")), - eventLoop: self.eventLoop, - configuration: .init(), - pool: .init(configuration: .init(), - backgroundActivityLogger: HTTPClient.loggingDisabled), - backgroundActivityLogger: HTTPClient.loggingDisabled)) + self.pool = ConnectionPool(configuration: .init(), + backgroundActivityLogger: HTTPClient.loggingDisabled) + XCTAssertNoThrow(self.http1ConnectionProvider = + try HTTP1ConnectionProvider(key: .init(.init(url: "http://some.test")), + eventLoop: self.eventLoop, + configuration: .init(), + pool: self.pool, + backgroundActivityLogger: HTTPClient.loggingDisabled)) } override func tearDown() { + XCTAssertNotNil(self.pool) XCTAssertNotNil(self.eventLoop) XCTAssertNotNil(self.http1ConnectionProvider) XCTAssertNoThrow(try self.http1ConnectionProvider.close().wait()) XCTAssertNoThrow(try self.eventLoop.syncShutdownGracefully()) - self.eventLoop = nil self.http1ConnectionProvider = nil + XCTAssertTrue(try self.pool.close(on: self.eventLoop).wait()) + self.eventLoop = nil + self.pool = nil } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift index ce71c5fab..a8d2088d7 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift @@ -51,14 +51,15 @@ class HTTPClientNIOTSTests: XCTestCase { func testTLSFailError() { guard isTestingNIOTS() else { return } - #if canImport(Network) - let httpBin = HTTPBin(ssl: true) - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) - defer { - XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) - XCTAssertNoThrow(try httpBin.shutdown()) - } + let httpBin = HTTPBin(ssl: true) + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + #if canImport(Network) do { _ = try httpClient.get(url: "https://localhost:\(httpBin.port)/get").wait() XCTFail("This should have failed") @@ -68,6 +69,8 @@ class HTTPClientNIOTSTests: XCTestCase { } catch { XCTFail("Error should have been NWTLSError not \(type(of: error))") } + #else + XCTFail("wrong OS") #endif } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 7e9e2b5d6..98bfb0b54 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -37,6 +37,7 @@ extension HTTPClientTests { ("testPost", testPost), ("testGetHttps", testGetHttps), ("testGetHttpsWithIP", testGetHttpsWithIP), + ("testGetHTTPSWorksOnMTELGWithIP", testGetHTTPSWorksOnMTELGWithIP), ("testPostHttps", testPostHttps), ("testHttpRedirect", testHttpRedirect), ("testHttpHostRedirect", testHttpHostRedirect), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index ed8c1acc5..3e313088d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -313,6 +313,25 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(.ok, response.status) } + func testGetHTTPSWorksOnMTELGWithIP() throws { + // Same test as above but this one will use NIO on Sockets even on Apple platforms, just to make sure + // this works. + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + let localHTTPBin = HTTPBin(ssl: true) + let localClient = HTTPClient(eventLoopGroupProvider: .shared(group), + configuration: HTTPClient.Configuration(certificateVerification: .none)) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + let response = try localClient.get(url: "https://127.0.0.1:\(localHTTPBin.port)/get").wait() + XCTAssertEqual(.ok, response.status) + } + func testPostHttps() throws { let localHTTPBin = HTTPBin(ssl: true) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), diff --git a/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift b/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift new file mode 100644 index 000000000..a0231bf0d --- /dev/null +++ b/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// LRUCacheTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension LRUCacheTests { + static var allTests: [(String, (LRUCacheTests) -> () throws -> Void)] { + return [ + ("testBasicsWork", testBasicsWork), + ("testCachesTheRightThings", testCachesTheRightThings), + ("testAppendingTheSameDoesNotEvictButUpdates", testAppendingTheSameDoesNotEvictButUpdates), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/LRUCacheTests.swift b/Tests/AsyncHTTPClientTests/LRUCacheTests.swift new file mode 100644 index 000000000..6392bcebe --- /dev/null +++ b/Tests/AsyncHTTPClientTests/LRUCacheTests.swift @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import XCTest + +class LRUCacheTests: XCTestCase { + func testBasicsWork() { + var cache = LRUCache(capacity: 1) + var requestedValueGens = 0 + for i in 0..<10 { + let actual = cache.findOrAppend(key: i) { i in + requestedValueGens += 1 + return i + } + XCTAssertEqual(i, actual) + } + XCTAssertEqual(10, requestedValueGens) + + let nine = cache.findOrAppend(key: 9) { i in + XCTAssertEqual(9, i) + XCTFail("9 should be in the cache") + return -1 + } + XCTAssertEqual(9, nine) + } + + func testCachesTheRightThings() { + var cache = LRUCache(capacity: 3) + + for i in 0..<10 { + let actual = cache.findOrAppend(key: i) { i in + i + } + XCTAssertEqual(i, actual) + + let zero = cache.find(key: 0) + XCTAssertEqual(0, zero, "at \(i), couldn't find 0") + + cache.append(key: -1, value: -1) + XCTAssertEqual(-1, cache.find(key: -1)) + } + + XCTAssertEqual(0, cache.find(key: 0)) + XCTAssertEqual(9, cache.find(key: 9)) + + for i in 1..<9 { + XCTAssertNil(cache.find(key: i)) + } + } + + func testAppendingTheSameDoesNotEvictButUpdates() { + var cache = LRUCache(capacity: 3) + + cache.append(key: 1, value: 1) + cache.append(key: 3, value: 3) + for i in (2...100).reversed() { + cache.append(key: 2, value: i) + XCTAssertEqual(i, cache.find(key: 2)) + } + + for i in 1...3 { + XCTAssertEqual(i, cache.find(key: i)) + } + + cache.append(key: 4, value: 4) + XCTAssertNil(cache.find(key: 1)) + for i in 2...4 { + XCTAssertEqual(i, cache.find(key: i)) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift b/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift new file mode 100644 index 000000000..78a2d5e6f --- /dev/null +++ b/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// SSLContextCacheTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension SSLContextCacheTests { + static var allTests: [(String, (SSLContextCacheTests) -> () throws -> Void)] { + return [ + ("testJustStartingAndStoppingAContextCacheWorks", testJustStartingAndStoppingAContextCacheWorks), + ("testRequestingSSLContextWorks", testRequestingSSLContextWorks), + ("testRequestingSSLContextAfterShutdownThrows", testRequestingSSLContextAfterShutdownThrows), + ("testCacheWorks", testCacheWorks), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift b/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift new file mode 100644 index 000000000..980e7f94a --- /dev/null +++ b/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import NIO +import NIOSSL +import XCTest + +final class SSLContextCacheTests: XCTestCase { + func testJustStartingAndStoppingAContextCacheWorks() { + SSLContextCache().shutdown() + } + + func testRequestingSSLContextWorks() { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = group.next() + let cache = SSLContextCache() + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + cache.shutdown() + } + + XCTAssertNoThrow(try cache.sslContext(tlsConfiguration: .forClient(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled).wait()) + } + + func testRequestingSSLContextAfterShutdownThrows() { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = group.next() + let cache = SSLContextCache() + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + cache.shutdown() + XCTAssertThrowsError(try cache.sslContext(tlsConfiguration: .forClient(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled).wait()) + } + + func testCacheWorks() { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = group.next() + let cache = SSLContextCache() + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + cache.shutdown() + } + + var firstContext: NIOSSLContext? + var secondContext: NIOSSLContext? + + XCTAssertNoThrow(firstContext = try cache.sslContext(tlsConfiguration: .forClient(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow(secondContext = try cache.sslContext(tlsConfiguration: .forClient(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNotNil(firstContext) + XCTAssertNotNil(secondContext) + XCTAssert(firstContext === secondContext) + } +} diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index c6dfc8291..0db0dd9ce 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -32,6 +32,8 @@ import XCTest testCase(HTTPClientInternalTests.allTests), testCase(HTTPClientNIOTSTests.allTests), testCase(HTTPClientTests.allTests), + testCase(LRUCacheTests.allTests), testCase(RequestValidationTests.allTests), + testCase(SSLContextCacheTests.allTests), ]) #endif