diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 10ff09a8f..40b07f67b 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -900,27 +900,25 @@ extension ChannelPipeline { try sync.addHandler(handler) } - func syncAddSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, addSSLClient: Bool, handshakePromise: EventLoopPromise) { - guard key.scheme.requiresTLS else { - handshakePromise.succeed(()) - return - } + func syncAddLateSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, handshakePromise: EventLoopPromise) { + precondition(key.scheme.requiresTLS) do { let synchronousPipelineView = self.syncOperations // We add the TLSEventsHandler first so that it's always in the pipeline before any other TLS handler we add. + // If we're here, we must not have one in the channel already. + assert((try? synchronousPipelineView.context(name: TLSEventsHandler.handlerName)) == nil) let eventsHandler = TLSEventsHandler(completionPromise: handshakePromise) - try synchronousPipelineView.addHandler(eventsHandler) - - if addSSLClient { - let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient() - let context = try NIOSSLContext(configuration: tlsConfiguration) - try synchronousPipelineView.addHandler( - try NIOSSLClientHandler(context: context, serverHostname: (key.host.isIPAddress || key.host.isEmpty) ? nil : key.host), - position: .before(eventsHandler) - ) - } + try synchronousPipelineView.addHandler(eventsHandler, name: TLSEventsHandler.handlerName) + + // Then we add the SSL handler. + let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient() + let context = try NIOSSLContext(configuration: tlsConfiguration) + try synchronousPipelineView.addHandler( + try NIOSSLClientHandler(context: context, serverHostname: (key.host.isIPAddress || key.host.isEmpty) ? nil : key.host), + position: .before(eventsHandler) + ) } catch { handshakePromise.fail(error) } @@ -930,7 +928,9 @@ extension ChannelPipeline { class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = NIOAny - var completionPromise: EventLoopPromise? + static let handlerName: String = "AsyncHTTPClient.HTTPClient.TLSEventsHandler" + + var completionPromise: EventLoopPromise init(completionPromise: EventLoopPromise) { self.completionPromise = completionPromise @@ -940,9 +940,7 @@ class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { if let tlsEvent = event as? TLSUserEvent { switch tlsEvent { case .handshakeCompleted: - self.completionPromise?.succeed(()) - self.completionPromise = nil - context.pipeline.removeHandler(self, promise: nil) + self.completionPromise.succeed(()) case .shutdownCompleted: break } @@ -951,15 +949,13 @@ class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { } func errorCaught(context: ChannelHandlerContext, error: Error) { - self.completionPromise?.fail(error) - self.completionPromise = nil - context.pipeline.removeHandler(self, promise: nil) + self.completionPromise.fail(error) context.fireErrorCaught(error) } func handlerRemoved(context: ChannelHandlerContext) { struct NoResult: Error {} - self.completionPromise?.fail(NoResult()) + self.completionPromise.fail(NoResult()) } } diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index 2f1bf0b40..24ec338ab 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -131,6 +131,11 @@ extension NIOClientTCPBootstrap { do { if let proxy = configuration.proxy { try channel.pipeline.syncAddProxyHandler(host: host, port: port, authorization: proxy.authorization) + } else if requiresTLS { + // We only add the handshake verifier if we need TLS and we're not going through a proxy. If we're going + // through a proxy we add it later. + let completionPromise = channel.eventLoop.makePromise(of: Void.self) + try channel.pipeline.syncOperations.addHandler(TLSEventsHandler(completionPromise: completionPromise), name: TLSEventsHandler.handlerName) } return channel.eventLoop.makeSucceededVoidFuture() } catch { @@ -162,14 +167,32 @@ extension NIOClientTCPBootstrap { } return channel.flatMap { channel in - let requiresSSLHandler = configuration.proxy != nil && key.scheme.requiresTLS - let handshakePromise = channel.eventLoop.makePromise(of: Void.self) - - channel.pipeline.syncAddSSLHandlerIfNeeded(for: key, tlsConfiguration: configuration.tlsConfiguration, addSSLClient: requiresSSLHandler, handshakePromise: handshakePromise) + let requiresTLS = key.scheme.requiresTLS + let requiresLateSSLHandler = configuration.proxy != nil && requiresTLS + let handshakeFuture: EventLoopFuture + + if requiresLateSSLHandler { + let handshakePromise = channel.eventLoop.makePromise(of: Void.self) + channel.pipeline.syncAddLateSSLHandlerIfNeeded(for: key, tlsConfiguration: configuration.tlsConfiguration, handshakePromise: handshakePromise) + handshakeFuture = handshakePromise.futureResult + } else if requiresTLS { + do { + handshakeFuture = try channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self).completionPromise.futureResult + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + } else { + handshakeFuture = channel.eventLoop.makeSucceededVoidFuture() + } - return handshakePromise.futureResult.flatMapThrowing { + return handshakeFuture.flatMapThrowing { let syncOperations = channel.pipeline.syncOperations + // If we got here and we had a TLSEventsHandler in the pipeline, we can remove it ow. + if requiresTLS { + channel.pipeline.removeHandler(name: TLSEventsHandler.handlerName, promise: nil) + } + try syncOperations.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes) #if canImport(Network) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 263aed5a9..46b6d2bbf 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -129,6 +129,7 @@ extension HTTPClientTests { ("testSSLHandshakeErrorPropagationDelayedClose", testSSLHandshakeErrorPropagationDelayedClose), ("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer), ("testBiDirectionalStreaming", testBiDirectionalStreaming), + ("testSynchronousHandshakeErrorReporting", testSynchronousHandshakeErrorReporting), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index b3e28744c..ec9a798a9 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -2821,4 +2821,27 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try future.wait()) } + + func testSynchronousHandshakeErrorReporting() throws { + // This only affects cases where we use NIOSSL. + guard !isTestingNIOTS() else { return } + + // We use a specially crafted client that has no cipher suites to offer. To do this we ask + // only for cipher suites incompatible with our TLS version. + let tlsConfig = TLSConfiguration.forClient(minimumTLSVersion: .tlsv13, maximumTLSVersion: .tlsv12, certificateVerification: .none) + let localHTTPBin = HTTPBin(ssl: true) + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig)) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/").wait()) { error in + guard let clientError = error as? NIOSSLError, case NIOSSLError.handshakeFailed = clientError else { + XCTFail("Unexpected error: \(error)") + return + } + } + } }