diff --git a/.gitignore b/.gitignore index f1b4020f3..cca8e27d0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .build Package.resolved *.xcodeproj +DerivedData diff --git a/Sources/NIOHTTPClient/HTTPClientProxyHandler.swift b/Sources/NIOHTTPClient/HTTPClientProxyHandler.swift new file mode 100644 index 000000000..fffa794c5 --- /dev/null +++ b/Sources/NIOHTTPClient/HTTPClientProxyHandler.swift @@ -0,0 +1,149 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIOHTTPClient open source project +// +// Copyright (c) 2018-2019 Swift Server Working Group and the SwiftNIOHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIOHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO +import NIOHTTP1 + +/// Specifies the remote address of an HTTP proxy. +/// +/// Adding an `HTTPClientProxy` to your client's `HTTPClientConfiguration` +/// will cause requests to be passed through the specified proxy using the +/// HTTP `CONNECT` method. +/// +/// If a `TLSConfiguration` is used in conjunction with `HTTPClientProxy`, +/// TLS will be established _after_ successful proxy, between your client +/// and the destination server. +public extension HTTPClient { + struct Proxy { + internal let host: String + internal let port: Int + + public static func server(host: String, port: Int) -> Proxy { + return .init(host: host, port: port) + } + } +} + +internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChannelHandler { + typealias InboundIn = HTTPClientResponsePart + typealias OutboundIn = HTTPClientRequestPart + typealias OutboundOut = HTTPClientRequestPart + + enum WriteItem { + case write(NIOAny, EventLoopPromise?) + case flush + } + + enum ReadState { + case awaitingResponse + case connecting + case connected + } + + private let host: String + private let port: Int + private var onConnect: (Channel) -> EventLoopFuture + private var writeBuffer: CircularBuffer + private var readBuffer: CircularBuffer + private var readState: ReadState + + init(host: String, port: Int, onConnect: @escaping (Channel) -> EventLoopFuture) { + self.host = host + self.port = port + self.onConnect = onConnect + self.writeBuffer = .init() + self.readBuffer = .init() + self.readState = .awaitingResponse + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.readState { + case .awaitingResponse: + let res = self.unwrapInboundIn(data) + switch res { + case .head(let head): + switch head.status.code { + case 200..<300: + // Any 2xx (Successful) response indicates that the sender (and all + // inbound proxies) will switch to tunnel mode immediately after the + // blank line that concludes the successful response's header section + break + default: + // Any response other than a successful response + // indicates that the tunnel has not yet been formed and that the + // connection remains governed by HTTP. + context.fireErrorCaught(HTTPClientError.invalidProxyResponse) + } + case .end: + self.readState = .connecting + _ = self.handleConnect(context: context) + case .body: + break + } + case .connecting: + self.readBuffer.append(data) + case .connected: + context.fireChannelRead(data) + } + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + self.writeBuffer.append(.write(data, promise)) + } + + func flush(context: ChannelHandlerContext) { + self.writeBuffer.append(.flush) + } + + func channelActive(context: ChannelHandlerContext) { + self.sendConnect(context: context) + context.fireChannelActive() + } + + // MARK: Private + + private func handleConnect(context: ChannelHandlerContext) -> EventLoopFuture { + return self.onConnect(context.channel).flatMap { + self.readState = .connected + + // forward any buffered reads + while !self.readBuffer.isEmpty { + context.fireChannelRead(self.readBuffer.removeFirst()) + } + + // calls to context.write may be re-entrant + while !self.writeBuffer.isEmpty { + switch self.writeBuffer.removeFirst() { + case .flush: + context.flush() + case .write(let data, let promise): + context.write(data, promise: promise) + } + } + return context.pipeline.removeHandler(self) + } + } + + private func sendConnect(context: ChannelHandlerContext) { + var head = HTTPRequestHead( + version: .init(major: 1, minor: 1), + method: .CONNECT, + uri: "\(self.host):\(self.port)" + ) + head.headers.add(name: "proxy-connection", value: "keep-alive") + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } +} diff --git a/Sources/NIOHTTPClient/SwiftNIOHTTP.swift b/Sources/NIOHTTPClient/SwiftNIOHTTP.swift index 188c8b15d..f7fc1130f 100644 --- a/Sources/NIOHTTPClient/SwiftNIOHTTP.swift +++ b/Sources/NIOHTTPClient/SwiftNIOHTTP.swift @@ -124,24 +124,33 @@ public class HTTPClient { var bootstrap = ClientBootstrap(group: group) .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) .channelInitializer { channel in - channel.pipeline.addHTTPClientHandlers().flatMap { - self.configureSSL(channel: channel, useTLS: request.useTLS, hostname: request.host) - }.flatMap { - if let readTimeout = timeout.read { - return channel.pipeline.addHandler(IdleStateHandler(readTimeout: readTimeout)) - } else { - return channel.eventLoop.makeSucceededFuture(()) + let encoder = HTTPRequestEncoder() + let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)) + return channel.pipeline.addHandlers([encoder, decoder], position: .first).flatMap { + switch self.configuration.proxy { + case .none: + return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: self.configuration.tlsConfiguration) + case .some: + return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration) + } + }.flatMap { + if let readTimeout = timeout.read { + return channel.pipeline.addHandler(IdleStateHandler(readTimeout: readTimeout)) + } else { + return channel.eventLoop.makeSucceededFuture(()) + } + }.flatMap { + let taskHandler = TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: redirectHandler) + return channel.pipeline.addHandler(taskHandler) } - }.flatMap { - channel.pipeline.addHandler(TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: redirectHandler)) - } } if let connectTimeout = timeout.connect { bootstrap = bootstrap.connectTimeout(connectTimeout) } - - bootstrap.connect(host: request.host, port: request.port) + + let address = self.resolveAddress(request: request, proxy: self.configuration.proxy) + bootstrap.connect(host: address.host, port: address.port) .map { channel in task.setChannel(channel) } @@ -155,18 +164,12 @@ public class HTTPClient { return task } - private func configureSSL(channel: Channel, useTLS: Bool, hostname: String) -> EventLoopFuture { - if useTLS { - do { - let tlsConfiguration = self.configuration.tlsConfiguration ?? TLSConfiguration.forClient() - let context = try NIOSSLContext(configuration: tlsConfiguration) - return channel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: hostname), - position: .first) - } catch { - return channel.eventLoop.makeFailedFuture(error) - } - } else { - return channel.eventLoop.makeSucceededFuture(()) + private func resolveAddress(request: Request, proxy: Proxy?) -> (host: String, port: Int) { + switch self.configuration.proxy { + case .none: + return (request.host, request.port) + case .some(let proxy): + return (proxy.host, proxy.port) } } @@ -174,17 +177,20 @@ public class HTTPClient { public var tlsConfiguration: TLSConfiguration? public var followRedirects: Bool public var timeout: Timeout + public var proxy: Proxy? - public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout()) { + public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) { self.tlsConfiguration = tlsConfiguration self.followRedirects = followRedirects self.timeout = timeout + self.proxy = proxy } - public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout()) { + public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) { self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification) self.followRedirects = followRedirects self.timeout = timeout + self.proxy = proxy } } @@ -199,6 +205,37 @@ public class HTTPClient { } } +private extension ChannelPipeline { + func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture { + let handler = HTTPClientProxyHandler(host: request.host, port: request.port, onConnect: { channel in + return channel.pipeline.removeHandler(decoder).flatMap { + return channel.pipeline.addHandler( + ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)), + position: .after(encoder) + ) + }.flatMap { + return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: tlsConfiguration) + } + }) + return self.addHandler(handler) + } + + func addSSLHandlerIfNeeded(for request: HTTPClient.Request, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture { + guard request.useTLS else { + return self.eventLoop.makeSucceededFuture(()) + } + + do { + let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient() + let context = try NIOSSLContext(configuration: tlsConfiguration) + return self.addHandler(try NIOSSLClientHandler(context: context, serverHostname: request.host), + position: .first) + } catch { + return self.eventLoop.makeFailedFuture(error) + } + } +} + public struct HTTPClientError: Error, Equatable, CustomStringConvertible { private enum Code: Equatable { case invalidURL @@ -211,6 +248,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case cancelled case identityCodingIncorrectlyPresent case chunkedSpecifiedMultipleTimes + case invalidProxyResponse } private var code: Code @@ -233,4 +271,5 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let cancelled = HTTPClientError(code: .cancelled) public static let identityCodingIncorrectlyPresent = HTTPClientError(code: .identityCodingIncorrectlyPresent) public static let chunkedSpecifiedMultipleTimes = HTTPClientError(code: .chunkedSpecifiedMultipleTimes) + public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse) } diff --git a/Tests/NIOHTTPClientTests/HTTPClientTestUtils.swift b/Tests/NIOHTTPClientTests/HTTPClientTestUtils.swift index d93fb6a14..bb684d915 100644 --- a/Tests/NIOHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/NIOHTTPClientTests/HTTPClientTestUtils.swift @@ -86,17 +86,27 @@ internal class HttpBin { return self.serverChannel.localAddress! } - init(ssl: Bool = false) { - self.serverChannel = try! ServerBootstrap(group: self.group) + static func configureTLS(channel: Channel) -> EventLoopFuture { + let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))], + privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem))) + let context = try! NIOSSLContext(configuration: configuration) + return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first) + } + + init(ssl: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil) { + self.serverChannel = try! ServerBootstrap(group: group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap { + return channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap { + if let simulateProxy = simulateProxy { + return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first) + } else { + return channel.eventLoop.makeSucceededFuture(()) + } + }.flatMap { if ssl { - let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))], - privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem))) - let context = try! NIOSSLContext(configuration: configuration) - return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first).flatMap { + return HttpBin.configureTLS(channel: channel).flatMap { channel.pipeline.addHandler(HttpBinHandler()) } } else { @@ -111,6 +121,48 @@ internal class HttpBin { } } +final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + typealias OutboundOut = ByteBuffer + + enum Option { + case plaintext + case tls + } + + let option: Option + + init(option: Option) { + self.option = option + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let response = """ + HTTP/1.1 200 OK\r\n\ + Content-Length: 0\r\n\ + Connection: close\r\n\ + \r\n + """ + var buffer = self.unwrapInboundIn(data) + let request = buffer.readString(length: buffer.readableBytes)! + if request.hasPrefix("CONNECT") { + var buffer = context.channel.allocator.buffer(capacity: 0) + buffer.writeString(response) + context.write(self.wrapInboundOut(buffer), promise: nil) + context.flush() + context.channel.pipeline.removeHandler(self, promise: nil) + switch self.option { + case .tls: + _ = HttpBin.configureTLS(channel: context.channel) + case .plaintext: break + } + } else { + fatalError("Expected a CONNECT request") + } + } +} + internal struct HTTPResponseBuilder { let head: HTTPResponseHead var body: ByteBuffer? diff --git a/Tests/NIOHTTPClientTests/SwiftNIOHTTPTests+XCTest.swift b/Tests/NIOHTTPClientTests/SwiftNIOHTTPTests+XCTest.swift index 81367f387..760e21cf9 100644 --- a/Tests/NIOHTTPClientTests/SwiftNIOHTTPTests+XCTest.swift +++ b/Tests/NIOHTTPClientTests/SwiftNIOHTTPTests+XCTest.swift @@ -38,6 +38,8 @@ extension SwiftHTTPTests { ("testRemoteClose", testRemoteClose), ("testReadTimeout", testReadTimeout), ("testCancel", testCancel), + ("testProxyPlaintext", testProxyPlaintext), + ("testProxyTLS", testProxyTLS), ] } } diff --git a/Tests/NIOHTTPClientTests/SwiftNIOHTTPTests.swift b/Tests/NIOHTTPClientTests/SwiftNIOHTTPTests.swift index d2d6fd520..38e99dcc7 100644 --- a/Tests/NIOHTTPClientTests/SwiftNIOHTTPTests.swift +++ b/Tests/NIOHTTPClientTests/SwiftNIOHTTPTests.swift @@ -265,4 +265,35 @@ class SwiftHTTPTests: XCTestCase { } } } + + func testProxyPlaintext() throws { + let httpBin = HttpBin(simulateProxy: .plaintext) + let httpClient = HTTPClient( + eventLoopGroupProvider: .createNew, + configuration: .init(proxy: .server(host: "localhost", port: httpBin.port)) + ) + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + let res = try httpClient.get(url: "http://test/ok").wait() + XCTAssertEqual(res.status, .ok) + } + + func testProxyTLS() throws { + let httpBin = HttpBin(simulateProxy: .tls) + let httpClient = HTTPClient( + eventLoopGroupProvider: .createNew, + configuration: .init( + certificateVerification: .none, + proxy: .server(host: "localhost", port: httpBin.port) + ) + ) + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + let res = try httpClient.get(url: "https://test/ok").wait() + XCTAssertEqual(res.status, .ok) + } }