From c0e1cf27f29a6d8db59c2a302cfcf311fd69b326 Mon Sep 17 00:00:00 2001 From: vkill Date: Fri, 30 Aug 2019 10:57:17 +0800 Subject: [PATCH] Add authorization to proxy --- Sources/AsyncHTTPClient/HTTPClient.swift | 11 ++-- .../HTTPClientProxyHandler.swift | 23 +++++++- Sources/AsyncHTTPClient/HTTPHandler.swift | 35 ++++++++++++ .../HTTPClientTestUtils.swift | 54 ++++++++++++------- .../HTTPClientTests+XCTest.swift | 3 ++ .../HTTPClientTests.swift | 39 ++++++++++++++ 6 files changed, 139 insertions(+), 26 deletions(-) diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 218a70568..662b66307 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -232,8 +232,8 @@ public class HTTPClient { switch self.configuration.proxy { case .none: return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: self.configuration.tlsConfiguration) - case .some: - return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration) + case .some(let proxy): + return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration, proxy: proxy) } }.flatMap { if let timeout = self.resolve(timeout: self.configuration.timeout.read, deadline: deadline) { @@ -386,8 +386,8 @@ public class HTTPClient { } private extension ChannelPipeline { - func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture { - let handler = HTTPClientProxyHandler(host: request.host, port: request.port, onConnect: { channel in + func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?, proxy: HTTPClient.Configuration.Proxy?) -> EventLoopFuture { + let handler = HTTPClientProxyHandler(host: request.host, port: request.port, authorization: proxy?.authorization, onConnect: { channel in channel.pipeline.removeHandler(decoder).flatMap { return channel.pipeline.addHandler( ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)), @@ -431,6 +431,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case chunkedSpecifiedMultipleTimes case invalidProxyResponse case contentLengthMissing + case proxyAuthenticationRequired } private var code: Code @@ -467,4 +468,6 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let invalidProxyResponse = HTTPClientError(code: .invalidProxyResponse) /// Request does not contain `Content-Length` header. public static let contentLengthMissing = HTTPClientError(code: .contentLengthMissing) + /// Proxy Authentication Required + public static let proxyAuthenticationRequired = HTTPClientError(code: .proxyAuthenticationRequired) } diff --git a/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift b/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift index 72cdf49fe..491ca5b88 100644 --- a/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift @@ -31,6 +31,8 @@ public extension HTTPClient.Configuration { public var host: String /// Specifies Proxy server port. public var port: Int + /// Specifies Proxy server authorization. + public var authorization: HTTPClient.Authorization? /// Create proxy. /// @@ -38,7 +40,17 @@ public extension HTTPClient.Configuration { /// - host: proxy server host. /// - port: proxy server port. public static func server(host: String, port: Int) -> Proxy { - return .init(host: host, port: port) + return .init(host: host, port: port, authorization: nil) + } + + /// Create proxy. + /// + /// - parameters: + /// - host: proxy server host. + /// - port: proxy server port. + /// - authorization: proxy server authorization. + public static func server(host: String, port: Int, authorization: HTTPClient.Authorization? = nil) -> Proxy { + return .init(host: host, port: port, authorization: authorization) } } } @@ -61,14 +73,16 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan private let host: String private let port: Int + private let authorization: HTTPClient.Authorization? private var onConnect: (Channel) -> EventLoopFuture private var writeBuffer: CircularBuffer private var readBuffer: CircularBuffer private var readState: ReadState - init(host: String, port: Int, onConnect: @escaping (Channel) -> EventLoopFuture) { + init(host: String, port: Int, authorization: HTTPClient.Authorization?, onConnect: @escaping (Channel) -> EventLoopFuture) { self.host = host self.port = port + self.authorization = authorization self.onConnect = onConnect self.writeBuffer = .init() self.readBuffer = .init() @@ -87,6 +101,8 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan // inbound proxies) will switch to tunnel mode immediately after the // blank line that concludes the successful response's header section break + case 407: + context.fireErrorCaught(HTTPClientError.proxyAuthenticationRequired) default: // Any response other than a successful response // indicates that the tunnel has not yet been formed and that the @@ -150,6 +166,9 @@ internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChan uri: "\(self.host):\(self.port)" ) head.headers.add(name: "proxy-connection", value: "keep-alive") + if let authorization = authorization { + head.headers.add(name: "proxy-authorization", value: authorization.headerValue) + } context.write(self.wrapOutboundOut(.head(head)), promise: nil) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 088b7f562..f1952fd3a 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -194,6 +194,41 @@ extension HTTPClient { self.body = body } } + + /// HTTP authentication + public struct Authorization { + private enum Scheme { + case Basic(String) + case Bearer(String) + } + + private let scheme: Scheme + + private init(scheme: Scheme) { + self.scheme = scheme + } + + public static func basic(username: String, password: String) -> HTTPClient.Authorization { + return .basic(credentials: Data("\(username):\(password)".utf8).base64EncodedString()) + } + + public static func basic(credentials: String) -> HTTPClient.Authorization { + return .init(scheme: .Basic(credentials)) + } + + public static func bearer(tokens: String) -> HTTPClient.Authorization { + return .init(scheme: .Bearer(tokens)) + } + + public var headerValue: String { + switch self.scheme { + case .Basic(let credentials): + return "Basic \(credentials)" + case .Bearer(let tokens): + return "Bearer \(tokens)" + } + } + } } internal class ResponseAccumulator: HTTPClientResponseDelegate { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 1546f3834..a43b3bff4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -110,7 +110,10 @@ internal class HttpBin { .childChannelInitializer { channel in channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap { if let simulateProxy = simulateProxy { - return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first) + let responseEncoder = HTTPResponseEncoder() + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + + return channel.pipeline.addHandlers([responseEncoder, requestDecoder, HTTPProxySimulator(option: simulateProxy, encoder: responseEncoder, decoder: requestDecoder)], position: .first) } else { return channel.eventLoop.makeSucceededFuture(()) } @@ -132,9 +135,9 @@ internal class HttpBin { } final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { - typealias InboundIn = ByteBuffer - typealias InboundOut = ByteBuffer - typealias OutboundOut = ByteBuffer + typealias InboundIn = HTTPServerRequestPart + typealias InboundOut = HTTPServerResponsePart + typealias OutboundOut = HTTPServerResponsePart enum Option { case plaintext @@ -142,33 +145,44 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { } let option: Option + let encoder: HTTPResponseEncoder + let decoder: ByteToMessageHandler + var head: HTTPResponseHead - init(option: Option) { + init(option: Option, encoder: HTTPResponseEncoder, decoder: ByteToMessageHandler) { self.option = option + self.encoder = encoder + self.decoder = decoder + self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0"), ("Connection", "close")])) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let response = """ - HTTP/1.1 200 OK\r\n\ - Content-Length: 0\r\n\ - Connection: close\r\n\ - \r\n - """ - var buffer = self.unwrapInboundIn(data) - let request = buffer.readString(length: buffer.readableBytes)! - if request.hasPrefix("CONNECT") { - var buffer = context.channel.allocator.buffer(capacity: 0) - buffer.writeString(response) - context.write(self.wrapInboundOut(buffer), promise: nil) - context.flush() + let request = self.unwrapInboundIn(data) + switch request { + case .head(let head): + guard head.method == .CONNECT else { + fatalError("Expected a CONNECT request") + } + if head.headers.contains(name: "proxy-authorization") { + if head.headers["proxy-authorization"].first != "Basic YWxhZGRpbjpvcGVuc2VzYW1l" { + self.head.status = .proxyAuthenticationRequired + } + } + case .body: + () + case .end: + context.write(self.wrapOutboundOut(.head(self.head)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + context.channel.pipeline.removeHandler(self, promise: nil) + context.channel.pipeline.removeHandler(self.decoder, promise: nil) + context.channel.pipeline.removeHandler(self.encoder, promise: nil) + switch self.option { case .tls: _ = HttpBin.configureTLS(channel: context.channel) case .plaintext: break } - } else { - fatalError("Expected a CONNECT request") } } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 93f66905b..37c976806 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -42,8 +42,11 @@ extension HTTPClientTests { ("testReadTimeout", testReadTimeout), ("testDeadline", testDeadline), ("testCancel", testCancel), + ("testHTTPClientAuthorization", testHTTPClientAuthorization), ("testProxyPlaintext", testProxyPlaintext), ("testProxyTLS", testProxyTLS), + ("testProxyPlaintextWithCorrectlyAuthorization", testProxyPlaintextWithCorrectlyAuthorization), + ("testProxyPlaintextWithIncorrectlyAuthorization", testProxyPlaintextWithIncorrectlyAuthorization), ("testUploadStreaming", testUploadStreaming), ("testNoContentLengthForSSLUncleanShutdown", testNoContentLengthForSSLUncleanShutdown), ("testNoContentLengthWithIgnoreErrorForSSLUncleanShutdown", testNoContentLengthWithIgnoreErrorForSSLUncleanShutdown), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 70a776c4d..95a4b3415 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -293,6 +293,14 @@ class HTTPClientTests: XCTestCase { } } + func testHTTPClientAuthorization() { + var authorization = HTTPClient.Authorization.basic(username: "aladdin", password: "opensesame") + XCTAssertEqual(authorization.headerValue, "Basic YWxhZGRpbjpvcGVuc2VzYW1l") + + authorization = HTTPClient.Authorization.bearer(tokens: "mF_9.B5f-4.1JqM") + XCTAssertEqual(authorization.headerValue, "Bearer mF_9.B5f-4.1JqM") + } + func testProxyPlaintext() throws { let httpBin = HttpBin(simulateProxy: .plaintext) let httpClient = HTTPClient( @@ -324,6 +332,37 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(res.status, .ok) } + func testProxyPlaintextWithCorrectlyAuthorization() throws { + let httpBin = HttpBin(simulateProxy: .plaintext) + let httpClient = HTTPClient( + eventLoopGroupProvider: .createNew, + configuration: .init(proxy: .server(host: "localhost", port: httpBin.port, authorization: .basic(username: "aladdin", password: "opensesame"))) + ) + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + let res = try httpClient.get(url: "http://test/ok").wait() + XCTAssertEqual(res.status, .ok) + } + + func testProxyPlaintextWithIncorrectlyAuthorization() throws { + let httpBin = HttpBin(simulateProxy: .plaintext) + let httpClient = HTTPClient( + eventLoopGroupProvider: .createNew, + configuration: .init(proxy: .server(host: "localhost", port: httpBin.port, authorization: .basic(username: "aladdin", password: "opensesamefoo"))) + ) + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + XCTAssertThrowsError(try httpClient.get(url: "http://test/ok").wait(), "Should fail") { error in + guard case let error = error as? HTTPClientError, error == .proxyAuthenticationRequired else { + return XCTFail("Should fail with HTTPClientError.proxyAuthenticationRequired") + } + } + } + func testUploadStreaming() throws { let httpBin = HttpBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)