diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 8dfd55ab8..7c3333c49 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -59,48 +59,38 @@ public class HTTPClient { } public func get(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try Request(url: url, method: .GET) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = Request(url: url, method: .GET) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func post(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try HTTPClient.Request(url: url, method: .POST, body: body) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = Request(url: url, method: .POST, body: body) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func patch(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try HTTPClient.Request(url: url, method: .PATCH, body: body) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = Request(url: url, method: .PATCH, body: body) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func put(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try HTTPClient.Request(url: url, method: .PUT, body: body) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = Request(url: url, method: .PUT, body: body) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func delete(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - do { - let request = try Request(url: url, method: .DELETE) - return self.execute(request: request, deadline: deadline) - } catch { - return self.eventLoopGroup.next().makeFailedFuture(error) + guard let request = Request(url: url, method: .DELETE) else { + return self.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidURL) } + return self.execute(request: request, deadline: deadline) } public func execute(request: Request, deadline: NIODeadline? = nil) -> EventLoopFuture { diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 862c7a371..83519e116 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -60,33 +60,34 @@ extension HTTPClient { } public struct Request { - public var version: HTTPVersion - public var method: HTTPMethod - public var url: URL - public var scheme: String - public var host: String + public let version: HTTPVersion + public let method: HTTPMethod + public let url: URL public var headers: HTTPHeaders public var body: Body? + + internal let scheme: String + internal let host: String - public init(url: String, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init?(url: String, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) { guard let url = URL(string: url) else { - throw HTTPClientError.invalidURL + return nil } - try self.init(url: url, version: version, method: method, headers: headers, body: body) + self.init(url: url, version: version, method: method, headers: headers, body: body) } - public init(url: URL, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { - guard let scheme = url.scheme else { - throw HTTPClientError.emptyScheme + public init?(url: URL, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) { + guard let scheme = url.scheme?.lowercased() else { + return nil } guard Request.isSchemeSupported(scheme: scheme) else { - throw HTTPClientError.unsupportedScheme(scheme) + return nil } guard let host = url.host else { - throw HTTPClientError.emptyHost + return nil } self.version = version @@ -99,14 +100,14 @@ extension HTTPClient { } public var useTLS: Bool { - return self.url.scheme == "https" + return self.scheme == "https" } public var port: Int { return self.url.port ?? (self.useTLS ? 443 : 80) } - static func isSchemeSupported(scheme: String?) -> Bool { + static func isSchemeSupported(scheme: String) -> Bool { return scheme == "http" || scheme == "https" } } @@ -514,7 +515,7 @@ internal struct RedirectHandler { return nil } - guard HTTPClient.Request.isSchemeSupported(scheme: url.scheme) else { + guard HTTPClient.Request.isSchemeSupported(scheme: request.scheme) else { return nil } @@ -526,22 +527,7 @@ internal struct RedirectHandler { } func redirect(status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise) { - let originalURL = self.request.url - - var request = self.request - request.url = redirectURL - - if let redirectHost = redirectURL.host { - request.host = redirectHost - } else { - preconditionFailure("redirectURL doesn't contain a host") - } - - if let redirectScheme = redirectURL.scheme { - request.scheme = redirectScheme - } else { - preconditionFailure("redirectURL doesn't contain a scheme") - } + let originalRequest = self.request var convertToGet = false if status == .seeOther, request.method != .HEAD { @@ -549,21 +535,28 @@ internal struct RedirectHandler { } else if status == .movedPermanently || status == .found, request.method == .POST { convertToGet = true } - + + var method = originalRequest.method + var headers = originalRequest.headers + var body = originalRequest.body + if convertToGet { - request.method = .GET - request.body = nil - request.headers.remove(name: "Content-Length") - request.headers.remove(name: "Content-Type") + method = .GET + body = nil + headers.remove(name: "Content-Length") + headers.remove(name: "Content-Type") } - if !originalURL.hasTheSameOrigin(as: redirectURL) { - request.headers.remove(name: "Origin") - request.headers.remove(name: "Cookie") - request.headers.remove(name: "Authorization") - request.headers.remove(name: "Proxy-Authorization") + if !originalRequest.url.hasTheSameOrigin(as: redirectURL) { + headers.remove(name: "Origin") + headers.remove(name: "Cookie") + headers.remove(name: "Authorization") + headers.remove(name: "Proxy-Authorization") + } + + guard let request = HTTPClient.Request(url: redirectURL, version: originalRequest.version, method: method, headers: headers, body: body) else { + return promise.fail(HTTPClientError.invalidURL) } - return self.execute(request).futureResult.cascade(to: promise) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 486a4e249..e59f49c93 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -30,7 +30,7 @@ class HTTPClientInternalTests: XCTestCase { try channel.pipeline.addHandler(recorder).wait() try channel.pipeline.addHandler(TaskHandler(task: task, delegate: TestHTTPDelegate(), redirectHandler: nil)).wait() - var request = try Request(url: "http://localhost/get") + var request = Request(url: "http://localhost/get")! request.headers.add(name: "X-Test-Header", value: "X-Test-Value") request.body = .string("1234") @@ -82,17 +82,13 @@ class HTTPClientInternalTests: XCTestCase { } let body: HTTPClient.Body = .stream(length: 50) { writer in - do { - var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") - request.headers.add(name: "Accept", value: "text/event-stream") + var request = Request(url: "http://localhost:\(httpBin.port)/events/10/1")! + request.headers.add(name: "Accept", value: "text/event-stream") - let delegate = HTTPClientCopyingDelegate { part in - writer.write(.byteBuffer(part)) - } - return httpClient.execute(request: request, delegate: delegate).futureResult - } catch { - return httpClient.eventLoopGroup.next().makeFailedFuture(error) + let delegate = HTTPClientCopyingDelegate { part in + writer.write(.byteBuffer(part)) } + return httpClient.execute(request: request, delegate: delegate).futureResult } let upload = try! httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait() @@ -118,17 +114,13 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertThrowsError(try httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()) body = .stream(length: 50) { _ in - do { - var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") - request.headers.add(name: "Accept", value: "text/event-stream") + var request = Request(url: "http://localhost:\(httpBin.port)/events/10/1")! + request.headers.add(name: "Accept", value: "text/event-stream") - let delegate = HTTPClientCopyingDelegate { _ in - httpClient.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidProxyResponse) - } - return httpClient.execute(request: request, delegate: delegate).futureResult - } catch { - return httpClient.eventLoopGroup.next().makeFailedFuture(error) + let delegate = HTTPClientCopyingDelegate { _ in + httpClient.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidProxyResponse) } + return httpClient.execute(request: request, delegate: delegate).futureResult } XCTAssertThrowsError(try httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()) @@ -172,7 +164,7 @@ class HTTPClientInternalTests: XCTestCase { httpBin.shutdown() } - let request = try Request(url: "http://localhost:\(httpBin.port)/custom") + let request = Request(url: "http://localhost:\(httpBin.port)/custom")! let delegate = BackpressureTestDelegate(promise: httpClient.eventLoopGroup.next().makePromise()) let future = httpClient.execute(request: request, delegate: delegate).futureResult diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 0313576e9..86c98bdbd 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -26,6 +26,8 @@ extension HTTPClientTests { static var allTests: [(String, (HTTPClientTests) -> () throws -> Void)] { return [ ("testRequestURI", testRequestURI), + ("testBadRequestURI", testBadRequestURI), + ("testSchemaCasing", testSchemaCasing), ("testGet", testGet), ("testGetWithSharedEventLoopGroup", testGetWithSharedEventLoopGroup), ("testPost", testPost), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 8e0a61230..8080ce8af 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -import AsyncHTTPClient +@testable import AsyncHTTPClient import NIO import NIOFoundationCompat import NIOHTTP1 @@ -22,17 +22,34 @@ class HTTPClientTests: XCTestCase { typealias Request = HTTPClient.Request func testRequestURI() throws { - let request1 = try Request(url: "https://someserver.com:8888/some/path?foo=bar") + let request1 = Request(url: "https://someserver.com:8888/some/path?foo=bar")! XCTAssertEqual(request1.host, "someserver.com") XCTAssertEqual(request1.url.path, "/some/path") XCTAssertEqual(request1.url.query!, "foo=bar") XCTAssertEqual(request1.port, 8888) XCTAssertTrue(request1.useTLS) - let request2 = try Request(url: "https://someserver.com") + let request2 = Request(url: "https://someserver.com")! XCTAssertEqual(request2.url.path, "") } + func testBadRequestURI() throws { + let request1 = Request(url: "file://somewhere/some/path?foo=bar") + XCTAssertNil(request1) + + let url = URL(string: "file://somewhere/some/path?foo=bar")! + let request2 = Request(url: url) + XCTAssertNil(request2) + } + + func testSchemaCasing() throws { + let request1 = Request(url: "https://someserver.com:8888/some/path?foo=bar")! + XCTAssertNotNil(request1) + + let request2 = Request(url: "hTTpS://someserver.com:8888/some/path?foo=bar")! + XCTAssertNotNil(request2) + } + func testGet() throws { let httpBin = HttpBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) @@ -55,7 +72,7 @@ class HTTPClientTests: XCTestCase { } let delegate = TestHTTPDelegate() - let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/events/10/1") + let request = HTTPClient.Request(url: "http://localhost:\(httpBin.port)/events/10/1")! let task = httpClient.execute(request: request, delegate: delegate) let expectedEventLoop = task.eventLoop task.futureResult.whenComplete { (_) in @@ -102,8 +119,7 @@ class HTTPClientTests: XCTestCase { httpBin.shutdown() } - let request = try Request(url: "https://localhost:\(httpBin.port)/post", method: .POST, body: .string("1234")) - + let request = Request(url: "https://localhost:\(httpBin.port)/post", method: .POST, body: .string("1234"))! let response = try httpClient.execute(request: request).wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -181,7 +197,7 @@ class HTTPClientTests: XCTestCase { var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "12") - let request = try Request(url: "http://localhost:\(httpBin.port)/post", method: .POST, headers: headers, body: .byteBuffer(body)) + let request = Request(url: "http://localhost:\(httpBin.port)/post", method: .POST, headers: headers, body: .byteBuffer(body))! let response = try httpClient.execute(request: request).wait() // if the library adds another content length header we'll get a bad request error. XCTAssertEqual(.ok, response.status) @@ -195,7 +211,7 @@ class HTTPClientTests: XCTestCase { httpBin.shutdown() } - var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") + var request = Request(url: "http://localhost:\(httpBin.port)/events/10/1")! request.headers.add(name: "Accept", value: "text/event-stream") let delegate = CountingDelegate() @@ -262,7 +278,7 @@ class HTTPClientTests: XCTestCase { } let queue = DispatchQueue(label: "nio-test") - let request = try Request(url: "http://localhost:\(httpBin.port)/wait") + let request = Request(url: "http://localhost:\(httpBin.port)/wait")! let task = httpClient.execute(request: request, delegate: TestHTTPDelegate()) queue.asyncAfter(deadline: .now() + .milliseconds(100)) {