Skip to content

Commit e64448e

Browse files
authored
improve request validation (#67)
motivation: safer handling of request validation and mutation changes: * drop request version * made request method and url immutable * made request scheme and host internal * fix scheme logic to be non-case sensitive * adjusted redirect handler implementation to stricter request immutabllity * adjust and add tests
1 parent bab22d0 commit e64448e

File tree

3 files changed

+54
-45
lines changed

3 files changed

+54
-45
lines changed

Sources/AsyncHTTPClient/HTTPHandler.swift

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,14 @@ extension HTTPClient {
8787

8888
/// Represent HTTP request.
8989
public struct Request {
90-
/// Request HTTP version, defaults to `HTTP/1.1`.
91-
public var version: HTTPVersion
9290
/// Request HTTP method, defaults to `GET`.
93-
public var method: HTTPMethod
91+
public let method: HTTPMethod
9492
/// Remote URL.
95-
public var url: URL
93+
public let url: URL
9694
/// Remote HTTP scheme, resolved from `URL`.
97-
public var scheme: String
95+
public let scheme: String
9896
/// Remote host, resolved from `URL`.
99-
public var host: String
97+
public let host: String
10098
/// Request custom HTTP Headers, defaults to no headers.
10199
public var headers: HTTPHeaders
102100
/// Request body, defaults to no body.
@@ -115,12 +113,12 @@ extension HTTPClient {
115113
/// - `emptyScheme` if URL does not contain HTTP scheme.
116114
/// - `unsupportedScheme` if URL does contains unsupported HTTP scheme.
117115
/// - `emptyHost` if URL does not contains a host.
118-
public init(url: String, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws {
116+
public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws {
119117
guard let url = URL(string: url) else {
120118
throw HTTPClientError.invalidURL
121119
}
122120

123-
try self.init(url: url, version: version, method: method, headers: headers, body: body)
121+
try self.init(url: url, method: method, headers: headers, body: body)
124122
}
125123

126124
/// Create an HTTP `Request`.
@@ -135,8 +133,8 @@ extension HTTPClient {
135133
/// - `emptyScheme` if URL does not contain HTTP scheme.
136134
/// - `unsupportedScheme` if URL does contains unsupported HTTP scheme.
137135
/// - `emptyHost` if URL does not contains a host.
138-
public init(url: URL, version: HTTPVersion = HTTPVersion(major: 1, minor: 1), method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws {
139-
guard let scheme = url.scheme else {
136+
public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws {
137+
guard let scheme = url.scheme?.lowercased() else {
140138
throw HTTPClientError.emptyScheme
141139
}
142140

@@ -148,7 +146,6 @@ extension HTTPClient {
148146
throw HTTPClientError.emptyHost
149147
}
150148

151-
self.version = version
152149
self.method = method
153150
self.url = url
154151
self.scheme = scheme
@@ -159,15 +156,15 @@ extension HTTPClient {
159156

160157
/// Whether request will be executed using secure socket.
161158
public var useTLS: Bool {
162-
return self.url.scheme == "https"
159+
return self.scheme == "https"
163160
}
164161

165162
/// Resolved port.
166163
public var port: Int {
167164
return self.url.port ?? (self.useTLS ? 443 : 80)
168165
}
169166

170-
static func isSchemeSupported(scheme: String?) -> Bool {
167+
static func isSchemeSupported(scheme: String) -> Bool {
171168
return scheme == "http" || scheme == "https"
172169
}
173170
}
@@ -444,10 +441,10 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
444441
self.state = .idle
445442
let request = unwrapOutboundIn(data)
446443

447-
var head = HTTPRequestHead(version: request.version, method: request.method, uri: request.url.uri)
444+
var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: request.method, uri: request.url.uri)
448445
var headers = request.headers
449446

450-
if request.version.major == 1, request.version.minor == 1, !request.headers.contains(name: "Host") {
447+
if !request.headers.contains(name: "Host") {
451448
headers.add(name: "Host", value: request.host)
452449
}
453450

@@ -640,7 +637,7 @@ internal struct RedirectHandler<T> {
640637
return nil
641638
}
642639

643-
guard HTTPClient.Request.isSchemeSupported(scheme: url.scheme) else {
640+
guard HTTPClient.Request.isSchemeSupported(scheme: self.request.scheme) else {
644641
return nil
645642
}
646643

@@ -652,44 +649,38 @@ internal struct RedirectHandler<T> {
652649
}
653650

654651
func redirect(status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise<T>) {
655-
let originalURL = self.request.url
656-
657-
var request = self.request
658-
request.url = redirectURL
659-
660-
if let redirectHost = redirectURL.host {
661-
request.host = redirectHost
662-
} else {
663-
preconditionFailure("redirectURL doesn't contain a host")
664-
}
665-
666-
if let redirectScheme = redirectURL.scheme {
667-
request.scheme = redirectScheme
668-
} else {
669-
preconditionFailure("redirectURL doesn't contain a scheme")
670-
}
652+
let originalRequest = self.request
671653

672654
var convertToGet = false
673-
if status == .seeOther, request.method != .HEAD {
655+
if status == .seeOther, self.request.method != .HEAD {
674656
convertToGet = true
675-
} else if status == .movedPermanently || status == .found, request.method == .POST {
657+
} else if status == .movedPermanently || status == .found, self.request.method == .POST {
676658
convertToGet = true
677659
}
678660

661+
var method = originalRequest.method
662+
var headers = originalRequest.headers
663+
var body = originalRequest.body
664+
679665
if convertToGet {
680-
request.method = .GET
681-
request.body = nil
682-
request.headers.remove(name: "Content-Length")
683-
request.headers.remove(name: "Content-Type")
666+
method = .GET
667+
body = nil
668+
headers.remove(name: "Content-Length")
669+
headers.remove(name: "Content-Type")
684670
}
685671

686-
if !originalURL.hasTheSameOrigin(as: redirectURL) {
687-
request.headers.remove(name: "Origin")
688-
request.headers.remove(name: "Cookie")
689-
request.headers.remove(name: "Authorization")
690-
request.headers.remove(name: "Proxy-Authorization")
672+
if !originalRequest.url.hasTheSameOrigin(as: redirectURL) {
673+
headers.remove(name: "Origin")
674+
headers.remove(name: "Cookie")
675+
headers.remove(name: "Authorization")
676+
headers.remove(name: "Proxy-Authorization")
691677
}
692678

693-
return self.execute(request).futureResult.cascade(to: promise)
679+
do {
680+
let newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body)
681+
return self.execute(newRequest).futureResult.cascade(to: promise)
682+
} catch {
683+
return promise.fail(error)
684+
}
694685
}
695686
}

Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ extension HTTPClientTests {
2626
static var allTests: [(String, (HTTPClientTests) -> () throws -> Void)] {
2727
return [
2828
("testRequestURI", testRequestURI),
29+
("testBadRequestURI", testBadRequestURI),
30+
("testSchemaCasing", testSchemaCasing),
2931
("testGet", testGet),
3032
("testGetWithSharedEventLoopGroup", testGetWithSharedEventLoopGroup),
3133
("testPost", testPost),

Tests/AsyncHTTPClientTests/HTTPClientTests.swift

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class HTTPClientTests: XCTestCase {
2323

2424
func testRequestURI() throws {
2525
let request1 = try Request(url: "https://someserver.com:8888/some/path?foo=bar")
26-
XCTAssertEqual(request1.host, "someserver.com")
26+
XCTAssertEqual(request1.url.host, "someserver.com")
2727
XCTAssertEqual(request1.url.path, "/some/path")
2828
XCTAssertEqual(request1.url.query!, "foo=bar")
2929
XCTAssertEqual(request1.port, 8888)
@@ -33,6 +33,22 @@ class HTTPClientTests: XCTestCase {
3333
XCTAssertEqual(request2.url.path, "")
3434
}
3535

36+
func testBadRequestURI() throws {
37+
XCTAssertThrowsError(try Request(url: "some/path"), "should throw") { error in
38+
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.emptyScheme)
39+
}
40+
XCTAssertThrowsError(try Request(url: "file://somewhere/some/path?foo=bar"), "should throw") { error in
41+
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.unsupportedScheme("file"))
42+
}
43+
XCTAssertThrowsError(try Request(url: "https:/foo"), "should throw") { error in
44+
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.emptyHost)
45+
}
46+
}
47+
48+
func testSchemaCasing() throws {
49+
XCTAssertNoThrow(try Request(url: "hTTpS://someserver.com:8888/some/path?foo=bar"))
50+
}
51+
3652
func testGet() throws {
3753
let httpBin = HttpBin()
3854
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew)

0 commit comments

Comments
 (0)