diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 45810daf2..3de8aaa68 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -87,16 +87,14 @@ extension HTTPClient { /// Represent HTTP request. public struct Request { - /// Request HTTP version, defaults to `HTTP/1.1`. - public var version: HTTPVersion /// Request HTTP method, defaults to `GET`. - public var method: HTTPMethod + public let method: HTTPMethod /// Remote URL. - public var url: URL + public let url: URL /// Remote HTTP scheme, resolved from `URL`. - public var scheme: String + public let scheme: String /// Remote host, resolved from `URL`. - public var host: String + public let host: String /// Request custom HTTP Headers, defaults to no headers. public var headers: HTTPHeaders /// Request body, defaults to no body. @@ -115,12 +113,12 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. - public init(url: String, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { guard let url = URL(string: url) else { throw HTTPClientError.invalidURL } - try self.init(url: url, version: version, method: method, headers: headers, body: body) + try self.init(url: url, method: method, headers: headers, body: body) } /// Create an HTTP `Request`. @@ -135,8 +133,8 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. - public init(url: URL, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { - guard let scheme = url.scheme else { + public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + guard let scheme = url.scheme?.lowercased() else { throw HTTPClientError.emptyScheme } @@ -148,7 +146,6 @@ extension HTTPClient { throw HTTPClientError.emptyHost } - self.version = version self.method = method self.url = url self.scheme = scheme @@ -159,7 +156,7 @@ extension HTTPClient { /// Whether request will be executed using secure socket. public var useTLS: Bool { - return self.url.scheme == "https" + return self.scheme == "https" } /// Resolved port. @@ -167,7 +164,7 @@ extension HTTPClient { return self.url.port ?? (self.useTLS ? 443 : 80) } - static func isSchemeSupported(scheme: String?) -> Bool { + static func isSchemeSupported(scheme: String) -> Bool { return scheme == "http" || scheme == "https" } } @@ -444,10 +441,10 @@ internal class TaskHandler: ChannelInboundHandler self.state = .idle let request = unwrapOutboundIn(data) - var head = HTTPRequestHead(version: request.version, method: request.method, uri: request.url.uri) + var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: request.method, uri: request.url.uri) var headers = request.headers - if request.version.major == 1, request.version.minor == 1, !request.headers.contains(name: "Host") { + if !request.headers.contains(name: "Host") { headers.add(name: "Host", value: request.host) } @@ -640,7 +637,7 @@ internal struct RedirectHandler { return nil } - guard HTTPClient.Request.isSchemeSupported(scheme: url.scheme) else { + guard HTTPClient.Request.isSchemeSupported(scheme: self.request.scheme) else { return nil } @@ -652,44 +649,38 @@ internal struct RedirectHandler { } func redirect(status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise) { - let originalURL = self.request.url - - var request = self.request - request.url = redirectURL - - if let redirectHost = redirectURL.host { - request.host = redirectHost - } else { - preconditionFailure("redirectURL doesn't contain a host") - } - - if let redirectScheme = redirectURL.scheme { - request.scheme = redirectScheme - } else { - preconditionFailure("redirectURL doesn't contain a scheme") - } + let originalRequest = self.request var convertToGet = false - if status == .seeOther, request.method != .HEAD { + if status == .seeOther, self.request.method != .HEAD { convertToGet = true - } else if status == .movedPermanently || status == .found, request.method == .POST { + } else if status == .movedPermanently || status == .found, self.request.method == .POST { convertToGet = true } + var method = originalRequest.method + var headers = originalRequest.headers + var body = originalRequest.body + if convertToGet { - request.method = .GET - request.body = nil - request.headers.remove(name: "Content-Length") - request.headers.remove(name: "Content-Type") + method = .GET + body = nil + headers.remove(name: "Content-Length") + headers.remove(name: "Content-Type") } - if !originalURL.hasTheSameOrigin(as: redirectURL) { - request.headers.remove(name: "Origin") - request.headers.remove(name: "Cookie") - request.headers.remove(name: "Authorization") - request.headers.remove(name: "Proxy-Authorization") + if !originalRequest.url.hasTheSameOrigin(as: redirectURL) { + headers.remove(name: "Origin") + headers.remove(name: "Cookie") + headers.remove(name: "Authorization") + headers.remove(name: "Proxy-Authorization") } - return self.execute(request).futureResult.cascade(to: promise) + do { + let newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body) + return self.execute(newRequest).futureResult.cascade(to: promise) + } catch { + return promise.fail(error) + } } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 0313576e9..86c98bdbd 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -26,6 +26,8 @@ extension HTTPClientTests { static var allTests: [(String, (HTTPClientTests) -> () throws -> Void)] { return [ ("testRequestURI", testRequestURI), + ("testBadRequestURI", testBadRequestURI), + ("testSchemaCasing", testSchemaCasing), ("testGet", testGet), ("testGetWithSharedEventLoopGroup", testGetWithSharedEventLoopGroup), ("testPost", testPost), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index c6a7f664b..78a8bdb76 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -23,7 +23,7 @@ class HTTPClientTests: XCTestCase { func testRequestURI() throws { let request1 = try Request(url: "https://someserver.com:8888/some/path?foo=bar") - XCTAssertEqual(request1.host, "someserver.com") + XCTAssertEqual(request1.url.host, "someserver.com") XCTAssertEqual(request1.url.path, "/some/path") XCTAssertEqual(request1.url.query!, "foo=bar") XCTAssertEqual(request1.port, 8888) @@ -33,6 +33,22 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(request2.url.path, "") } + func testBadRequestURI() throws { + XCTAssertThrowsError(try Request(url: "some/path"), "should throw") { error in + XCTAssertEqual(error as! HTTPClientError, HTTPClientError.emptyScheme) + } + XCTAssertThrowsError(try Request(url: "file://somewhere/some/path?foo=bar"), "should throw") { error in + XCTAssertEqual(error as! HTTPClientError, HTTPClientError.unsupportedScheme("file")) + } + XCTAssertThrowsError(try Request(url: "https:/foo"), "should throw") { error in + XCTAssertEqual(error as! HTTPClientError, HTTPClientError.emptyHost) + } + } + + func testSchemaCasing() throws { + XCTAssertNoThrow(try Request(url: "hTTpS://someserver.com:8888/some/path?foo=bar")) + } + func testGet() throws { let httpBin = HttpBin() let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)