diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 8dfd55ab8..259725aed 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -59,57 +59,61 @@ public class HTTPClient { } public func get(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try Request(url: url, method: .GET) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = Request(url: url, method: .GET) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func post(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try HTTPClient.Request(url: url, method: .POST, body: body) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = Request(url: url, method: .POST, body: body) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func patch(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try HTTPClient.Request(url: url, method: .PATCH, body: body) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = HTTPClient.Request(url: url, method: .PATCH, body: body) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func put(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try HTTPClient.Request(url: url, method: .PUT, body: body) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = HTTPClient.Request(url: url, method: .PUT, body: body) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func delete(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { + guard let request = Request(url: url, method: .DELETE) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) + } + return self.execute(request: request, deadline: deadline) + } + + public func execute(request: Request, deadline: NIODeadline? = nil) -> EventLoopFuture { do { - let request = try Request(url: url, method: .DELETE) - return self.execute(request: request, deadline: deadline) + let requestWithHost = try RequestWithHost(request: request) + let accumulator = ResponseAccumulator(host: requestWithHost.host) + return self.execute(request: requestWithHost, delegate: accumulator, deadline: deadline).futureResult } catch { return self.eventLoopGroup.next().makeFailedFuture(error) } } - public func execute(request: Request, deadline: NIODeadline? = nil) -> EventLoopFuture { - let accumulator = ResponseAccumulator(request: request) - return self.execute(request: request, delegate: accumulator, deadline: deadline).futureResult + public func execute(request: Request, delegate: T, deadline: NIODeadline? = nil) -> Task { + do { + return try self.execute(request: RequestWithHost(request: request), delegate: delegate, deadline: deadline) + } catch { + return Task(eventLoop: self.eventLoopGroup.next(), error: error) + } } - public func execute(request: Request, delegate: T, deadline: NIODeadline? = nil) -> Task { + private func execute(request: RequestWithHost, delegate: T, deadline: NIODeadline? = nil) -> Task { let eventLoop = self.eventLoopGroup.next() + let task = Task(eventLoop: eventLoop) let redirectHandler: RedirectHandler? if self.configuration.followRedirects { @@ -120,8 +124,6 @@ public class HTTPClient { redirectHandler = nil } - let task = Task(eventLoop: eventLoop) - var bootstrap = ClientBootstrap(group: eventLoop) .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) .channelInitializer { channel in @@ -178,7 +180,7 @@ public class HTTPClient { } } - private func resolveAddress(request: Request, proxy: Proxy?) -> (host: String, port: Int) { + private func resolveAddress(request: RequestWithHost, proxy: Proxy?) -> (host: String, port: Int) { switch self.configuration.proxy { case .none: return (request.host, request.port) @@ -225,7 +227,7 @@ public class HTTPClient { } private extension ChannelPipeline { - func addProxyHandler(for request: HTTPClient.Request, decoder: ByteToMessageHandler, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture { + func addProxyHandler(for request: HTTPClient.RequestWithHost, decoder: ByteToMessageHandler, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture { let handler = HTTPClientProxyHandler(host: request.host, port: request.port, onConnect: { channel in channel.pipeline.removeHandler(decoder).flatMap { return channel.pipeline.addHandler( @@ -239,7 +241,7 @@ private extension ChannelPipeline { return self.addHandler(handler) } - func addSSLHandlerIfNeeded(for request: HTTPClient.Request, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture { + func addSSLHandlerIfNeeded(for request: HTTPClient.RequestWithHost, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture { guard request.useTLS else { return self.eventLoop.makeSucceededFuture(()) } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 862c7a371..87a107029 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -60,49 +60,89 @@ extension HTTPClient { } public struct Request { - public var version: HTTPVersion public var method: HTTPMethod public var url: URL - public var scheme: String - public var host: String public var headers: HTTPHeaders public var body: Body? - public init(url: String, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init?(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) { guard let url = URL(string: url) else { - throw HTTPClientError.invalidURL + return nil } - try self.init(url: url, version: version, method: method, headers: headers, body: body) + self.init(url: url, method: method, headers: headers, body: body) } - public init(url: URL, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { - guard let scheme = url.scheme else { + public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) { + self.method = method + self.url = url + self.headers = headers + self.body = body + } + } + + struct RequestWithHost { + var request: Request + var host: String + + init(request: Request) throws { + guard let host = request.url.host else { + throw HTTPClientError.emptyHost + } + + guard let scheme = request.url.scheme else { throw HTTPClientError.emptyScheme } - guard Request.isSchemeSupported(scheme: scheme) else { + guard RequestWithHost.isSchemeSupported(scheme: scheme) else { throw HTTPClientError.unsupportedScheme(scheme) } - guard let host = url.host else { - throw HTTPClientError.emptyHost + self.request = request + self.host = host + } + + var method: HTTPMethod { + get { + return self.request.method + } + set { + self.request.method = newValue + } + } + + var url: URL { + get { + return self.request.url + } + set { + self.request.url = newValue } + } - self.version = version - self.method = method - self.url = url - self.scheme = scheme - self.host = host - self.headers = headers - self.body = body + var headers: HTTPHeaders { + get { + return self.request.headers + } + set { + self.request.headers = newValue + } } - public var useTLS: Bool { + var body: Body? { + get { + return self.request.body + } + set { + self.request.body = newValue + } + } + + var useTLS: Bool { return self.url.scheme == "https" } - public var port: Int { + var port: Int { return self.url.port ?? (self.useTLS ? 443 : 80) } @@ -131,10 +171,10 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate { } var state = State.idle - let request: HTTPClient.Request + let host: String - init(request: HTTPClient.Request) { - self.request = request + init(host: String) { + self.host = host } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { @@ -180,9 +220,9 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate { case .idle: preconditionFailure("no head received before end") case .head(let head): - return Response(host: self.request.host, status: head.status, headers: head.headers, body: nil) + return Response(host: self.host, status: head.status, headers: head.headers, body: nil) case .body(let head, let body): - return Response(host: self.request.host, status: head.status, headers: head.headers, body: body) + return Response(host: self.host, status: head.status, headers: head.headers, body: body) case .end: preconditionFailure("request already processed") case .error(let error): @@ -253,6 +293,11 @@ extension HTTPClient { self.lock = Lock() } + convenience init(eventLoop: EventLoop, error: Error) { + self.init(eventLoop: eventLoop) + self.fail(error) + } + public var futureResult: EventLoopFuture { return self.promise.futureResult } @@ -291,7 +336,7 @@ extension HTTPClient { internal struct TaskCancelEvent {} internal class TaskHandler: ChannelInboundHandler, ChannelOutboundHandler { - typealias OutboundIn = HTTPClient.Request + typealias OutboundIn = HTTPClient.RequestWithHost typealias InboundIn = HTTPClientResponsePart typealias OutboundOut = HTTPClientRequestPart @@ -322,10 +367,10 @@ internal class TaskHandler: ChannelInboundHandler self.state = .idle let request = unwrapOutboundIn(data) - var head = HTTPRequestHead(version: request.version, method: request.method, uri: request.url.uri) + var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: request.method, uri: request.url.uri) var headers = request.headers - if request.version.major == 1, request.version.minor == 1, !request.headers.contains(name: "Host") { + if !request.headers.contains(name: "Host") { headers.add(name: "Host", value: request.host) } @@ -345,7 +390,7 @@ internal class TaskHandler: ChannelInboundHandler self.delegate.didSendRequestHead(task: self.task, head) } - self.writeBody(request: request, context: context).whenComplete { result in + self.writeBody(request: request.request, context: context).whenComplete { result in switch result { case .success: context.write(self.wrapOutboundOut(.end(nil)), promise: promise) @@ -495,8 +540,8 @@ internal class TaskHandler: ChannelInboundHandler } internal struct RedirectHandler { - let request: HTTPClient.Request - let execute: (HTTPClient.Request) -> HTTPClient.Task + let request: HTTPClient.RequestWithHost + let execute: (HTTPClient.RequestWithHost) -> HTTPClient.Task func redirectTarget(status: HTTPResponseStatus, headers: HTTPHeaders) -> URL? { switch status { @@ -514,7 +559,7 @@ internal struct RedirectHandler { return nil } - guard HTTPClient.Request.isSchemeSupported(scheme: url.scheme) else { + guard HTTPClient.RequestWithHost.isSchemeSupported(scheme: url.scheme) else { return nil } @@ -537,9 +582,7 @@ internal struct RedirectHandler { preconditionFailure("redirectURL doesn't contain a host") } - if let redirectScheme = redirectURL.scheme { - request.scheme = redirectScheme - } else { + if redirectURL.scheme == nil { preconditionFailure("redirectURL doesn't contain a scheme") } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index 6ad239e02..260766bea 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -25,6 +25,7 @@ import XCTest extension HTTPClientInternalTests { static var allTests: [(String, (HTTPClientInternalTests) -> () throws -> Void)] { return [ + ("testRequestWithHost", testRequestWithHost), ("testHTTPPartsHandler", testHTTPPartsHandler), ("testHTTPPartsHandlerMultiBody", testHTTPPartsHandlerMultiBody), ("testProxyStreaming", testProxyStreaming), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 486a4e249..728a7d5cb 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -20,8 +20,17 @@ import XCTest class HTTPClientInternalTests: XCTestCase { typealias Request = HTTPClient.Request + typealias RequestWithHost = HTTPClient.RequestWithHost typealias Task = HTTPClient.Task + func testRequestWithHost() throws { + let request = try RequestWithHost(request: Request(url: "https://someserver.com:8888/some/path?foo=bar")!) + + XCTAssertEqual(request.host, "someserver.com") + XCTAssertEqual(request.port, 8888) + XCTAssertTrue(request.useTLS) + } + func testHTTPPartsHandler() throws { let channel = EmbeddedChannel() let recorder = RecordingHandler() @@ -30,7 +39,7 @@ class HTTPClientInternalTests: XCTestCase { try channel.pipeline.addHandler(recorder).wait() try channel.pipeline.addHandler(TaskHandler(task: task, delegate: TestHTTPDelegate(), redirectHandler: nil)).wait() - var request = try Request(url: "http://localhost/get") + var request = try RequestWithHost(request: Request(url: "http://localhost/get")!) request.headers.add(name: "X-Test-Header", value: "X-Test-Value") request.body = .string("1234") @@ -82,17 +91,13 @@ class HTTPClientInternalTests: XCTestCase { } let body: HTTPClient.Body = .stream(length: 50) { writer in - do { - var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") - request.headers.add(name: "Accept", value: "text/event-stream") + var request = Request(url: "http://localhost:\(httpBin.port)/events/10/1")! + request.headers.add(name: "Accept", value: "text/event-stream") - let delegate = HTTPClientCopyingDelegate { part in - writer.write(.byteBuffer(part)) - } - return httpClient.execute(request: request, delegate: delegate).futureResult - } catch { - return httpClient.eventLoopGroup.next().makeFailedFuture(error) + let delegate = HTTPClientCopyingDelegate { part in + writer.write(.byteBuffer(part)) } + return httpClient.execute(request: request, delegate: delegate).futureResult } let upload = try! httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait() @@ -118,17 +123,13 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertThrowsError(try httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()) body = .stream(length: 50) { _ in - do { - var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") - request.headers.add(name: "Accept", value: "text/event-stream") + var request = Request(url: "http://localhost:\(httpBin.port)/events/10/1")! + request.headers.add(name: "Accept", value: "text/event-stream") - let delegate = HTTPClientCopyingDelegate { _ in - httpClient.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidProxyResponse) - } - return httpClient.execute(request: request, delegate: delegate).futureResult - } catch { - return httpClient.eventLoopGroup.next().makeFailedFuture(error) + let delegate = HTTPClientCopyingDelegate { _ in + httpClient.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidProxyResponse) } + return httpClient.execute(request: request, delegate: delegate).futureResult } XCTAssertThrowsError(try httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()) @@ -172,7 +173,7 @@ class HTTPClientInternalTests: XCTestCase { httpBin.shutdown() } - let request = try Request(url: "http://localhost:\(httpBin.port)/custom") + let request = Request(url: "http://localhost:\(httpBin.port)/custom")! let delegate = BackpressureTestDelegate(promise: httpClient.eventLoopGroup.next().makePromise()) let future = httpClient.execute(request: request, delegate: delegate).futureResult diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 8e0a61230..8da555d94 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -22,14 +22,11 @@ class HTTPClientTests: XCTestCase { typealias Request = HTTPClient.Request func testRequestURI() throws { - let request1 = try Request(url: "https://someserver.com:8888/some/path?foo=bar") - XCTAssertEqual(request1.host, "someserver.com") + let request1 = Request(url: "https://someserver.com:8888/some/path?foo=bar")! XCTAssertEqual(request1.url.path, "/some/path") XCTAssertEqual(request1.url.query!, "foo=bar") - XCTAssertEqual(request1.port, 8888) - XCTAssertTrue(request1.useTLS) - let request2 = try Request(url: "https://someserver.com") + let request2 = Request(url: "https://someserver.com")! XCTAssertEqual(request2.url.path, "") } @@ -55,7 +52,7 @@ class HTTPClientTests: XCTestCase { } let delegate = TestHTTPDelegate() - let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/events/10/1") + let request = HTTPClient.Request(url: "http://localhost:\(httpBin.port)/events/10/1")! let task = httpClient.execute(request: request, delegate: delegate) let expectedEventLoop = task.eventLoop task.futureResult.whenComplete { (_) in @@ -102,7 +99,7 @@ class HTTPClientTests: XCTestCase { httpBin.shutdown() } - let request = try Request(url: "https://localhost:\(httpBin.port)/post", method: .POST, body: .string("1234")) + let request = Request(url: "https://localhost:\(httpBin.port)/post", method: .POST, body: .string("1234"))! let response = try httpClient.execute(request: request).wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } @@ -181,7 +178,7 @@ class HTTPClientTests: XCTestCase { var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "12") - let request = try Request(url: "http://localhost:\(httpBin.port)/post", method: .POST, headers: headers, body: .byteBuffer(body)) + let request = Request(url: "http://localhost:\(httpBin.port)/post", method: .POST, headers: headers, body: .byteBuffer(body))! let response = try httpClient.execute(request: request).wait() // if the library adds another content length header we'll get a bad request error. XCTAssertEqual(.ok, response.status) @@ -195,7 +192,7 @@ class HTTPClientTests: XCTestCase { httpBin.shutdown() } - var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") + var request = Request(url: "http://localhost:\(httpBin.port)/events/10/1")! request.headers.add(name: "Accept", value: "text/event-stream") let delegate = CountingDelegate() @@ -262,7 +259,7 @@ class HTTPClientTests: XCTestCase { } let queue = DispatchQueue(label: "nio-test") - let request = try Request(url: "http://localhost:\(httpBin.port)/wait") + let request = Request(url: "http://localhost:\(httpBin.port)/wait")! let task = httpClient.execute(request: request, delegate: TestHTTPDelegate()) queue.asyncAfter(deadline: .now() + .milliseconds(100)) {