Skip to content

Commit 13e3d6b

Browse files
tanner0101weissi
authored andcommitted
add proxy support
1 parent 9b01f9d commit 13e3d6b

File tree

7 files changed

+265
-23
lines changed

7 files changed

+265
-23
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.build
22
Package.resolved
33
*.xcodeproj
4+
DerivedData
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import Foundation
2+
import NIO
3+
import NIOHTTP1
4+
5+
/// Specifies the remote address of an HTTP proxy.
6+
///
7+
/// Adding an `HTTPClientProxy` to your client's `HTTPClientConfiguration`
8+
/// will cause requests to be passed through the specified proxy using the
9+
/// HTTP `CONNECT` method.
10+
///
11+
/// If a `TLSConfiguration` is used in conjunction with `HTTPClientProxy`,
12+
/// TLS will be established _after_ successful proxy, between your client
13+
/// and the destination server.
14+
public struct HTTPClientProxy {
15+
internal let host: String
16+
internal let port: Int
17+
18+
public static func server(host: String, port: Int) -> HTTPClientProxy {
19+
return .init(host: host, port: port)
20+
}
21+
}
22+
23+
internal final class HTTPClientProxyHandler: ChannelDuplexHandler, RemovableChannelHandler {
24+
typealias InboundIn = HTTPClientResponsePart
25+
typealias OutboundIn = HTTPClientRequestPart
26+
typealias OutboundOut = HTTPClientRequestPart
27+
28+
enum BufferItem {
29+
case write(NIOAny, EventLoopPromise<Void>?)
30+
case flush
31+
}
32+
33+
private let host: String
34+
private let port: Int
35+
private var onConnect: (Channel) -> EventLoopFuture<Void>
36+
private var buffer: [BufferItem]
37+
38+
init(host: String, port: Int, onConnect: @escaping (Channel) -> EventLoopFuture<Void>) {
39+
self.host = host
40+
self.port = port
41+
self.onConnect = onConnect
42+
self.buffer = []
43+
}
44+
45+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
46+
let res = self.unwrapInboundIn(data)
47+
switch res {
48+
case .head(let head):
49+
switch head.status.code {
50+
case 200..<300:
51+
// Any 2xx (Successful) response indicates that the sender (and all
52+
// inbound proxies) will switch to tunnel mode immediately after the
53+
// blank line that concludes the successful response's header section
54+
break
55+
default:
56+
// Any response other than a successful response
57+
// indicates that the tunnel has not yet been formed and that the
58+
// connection remains governed by HTTP.
59+
context.fireErrorCaught(HTTPClientErrors.InvalidProxyResponseError())
60+
}
61+
case .end:
62+
_ = self.handleConnect(context: context)
63+
case .body:
64+
break
65+
}
66+
}
67+
68+
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
69+
self.buffer.append(.write(data, promise))
70+
}
71+
72+
func flush(context: ChannelHandlerContext) {
73+
self.buffer.append(.flush)
74+
}
75+
76+
func channelActive(context: ChannelHandlerContext) {
77+
self.sendConnect(context: context)
78+
context.fireChannelActive()
79+
}
80+
81+
// MARK: Private
82+
83+
private func handleConnect(context: ChannelHandlerContext) -> EventLoopFuture<Void> {
84+
return self.onConnect(context.channel).flatMap {
85+
while self.buffer.count > 0 {
86+
// make a copy of the current buffer and clear it in case any
87+
// calls to context.write cause more requests to be buffered
88+
let buffer = self.buffer
89+
self.buffer = []
90+
buffer.forEach { item in
91+
switch item {
92+
case .flush:
93+
context.flush()
94+
case .write(let data, let promise):
95+
context.write(data, promise: promise)
96+
}
97+
}
98+
}
99+
return context.pipeline.removeHandler(self)
100+
}
101+
}
102+
103+
private func sendConnect(context: ChannelHandlerContext) {
104+
var head = HTTPRequestHead(
105+
version: .init(major: 1, minor: 1),
106+
method: .CONNECT,
107+
uri: "\(self.host):\(self.port)"
108+
)
109+
head.headers.add(name: "proxy-connection", value: "keep-alive")
110+
context.write(self.wrapOutboundOut(.head(head)), promise: nil)
111+
context.write(self.wrapOutboundOut(.end(nil)), promise: nil)
112+
context.flush()
113+
}
114+
}

Sources/NIOHTTPClient/HTTPHandler.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ public struct HTTPClientErrors {
4646

4747
public struct CancelledError : HTTPClientError {
4848
}
49+
50+
public struct InvalidProxyResponseError : HTTPClientError {
51+
}
4952
}
5053

5154
public enum HTTPBody : Equatable {

Sources/NIOHTTPClient/SwiftNIOHTTP.swift

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,20 @@ public struct HTTPClientConfiguration {
3737
public var tlsConfiguration: TLSConfiguration?
3838
public var followRedirects: Bool
3939
public var timeout: Timeout
40+
public var proxy: HTTPClientProxy?
4041

41-
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout()) {
42+
public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: HTTPClientProxy? = nil) {
4243
self.tlsConfiguration = tlsConfiguration
4344
self.followRedirects = followRedirects
4445
self.timeout = timeout
46+
self.proxy = proxy
4547
}
4648

47-
public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout()) {
49+
public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: HTTPClientProxy? = nil) {
4850
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
4951
self.followRedirects = followRedirects
5052
self.timeout = timeout
53+
self.proxy = proxy
5154
}
5255
}
5356

@@ -151,16 +154,24 @@ public class HTTPClient {
151154
var bootstrap = ClientBootstrap(group: group)
152155
.channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)
153156
.channelInitializer { channel in
154-
channel.pipeline.addHTTPClientHandlers().flatMap {
155-
self.configureSSL(channel: channel, useTLS: request.useTLS, hostname: request.host)
157+
let encoder = HTTPRequestEncoder()
158+
let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes))
159+
return channel.pipeline.addHandlers([encoder, decoder], position: .first).flatMap {
160+
switch self.configuration.proxy {
161+
case .none:
162+
return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: self.configuration.tlsConfiguration)
163+
case .some:
164+
return channel.pipeline.addProxyHandler(for: request, decoder: decoder, encoder: encoder, tlsConfiguration: self.configuration.tlsConfiguration)
165+
}
156166
}.flatMap {
157167
if let readTimeout = timeout.read {
158168
return channel.pipeline.addHandler(IdleStateHandler(readTimeout: readTimeout))
159169
} else {
160170
return channel.eventLoop.makeSucceededFuture(())
161171
}
162172
}.flatMap {
163-
channel.pipeline.addHandler(HTTPTaskHandler(delegate: delegate, promise: promise, redirectHandler: redirectHandler))
173+
let taskHandler = HTTPTaskHandler(delegate: delegate, promise: promise, redirectHandler: redirectHandler)
174+
return channel.pipeline.addHandler(taskHandler)
164175
}
165176
}
166177

@@ -170,7 +181,19 @@ public class HTTPClient {
170181

171182
let task = HTTPTask(future: promise.futureResult)
172183

173-
bootstrap.connect(host: request.host, port: request.port)
184+
let host: String
185+
let port: Int
186+
187+
switch self.configuration.proxy {
188+
case .none:
189+
host = request.host
190+
port = request.port
191+
case .some(let proxy):
192+
host = proxy.host
193+
port = proxy.port
194+
}
195+
196+
bootstrap.connect(host: host, port: port)
174197
.map { channel in
175198
task.setChannel(channel)
176199
}
@@ -183,19 +206,35 @@ public class HTTPClient {
183206

184207
return task
185208
}
209+
}
186210

187-
private func configureSSL(channel: Channel, useTLS: Bool, hostname: String) -> EventLoopFuture<Void> {
188-
if useTLS {
189-
do {
190-
let tlsConfiguration = self.configuration.tlsConfiguration ?? TLSConfiguration.forClient()
191-
let context = try NIOSSLContext(configuration: tlsConfiguration)
192-
return channel.pipeline.addHandler(try NIOSSLClientHandler(context: context, serverHostname: hostname),
193-
position: .first)
194-
} catch {
195-
return channel.eventLoop.makeFailedFuture(error)
211+
private extension ChannelPipeline {
212+
func addProxyHandler(for request: HTTPRequest, decoder: ByteToMessageHandler<HTTPResponseDecoder>, encoder: HTTPRequestEncoder, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture<Void> {
213+
let handler = HTTPClientProxyHandler(host: request.host, port: request.port, onConnect: { channel in
214+
return channel.pipeline.removeHandler(decoder).flatMap {
215+
return channel.pipeline.addHandler(
216+
ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)),
217+
position: .after(encoder)
218+
)
219+
}.flatMap {
220+
return channel.pipeline.addSSLHandlerIfNeeded(for: request, tlsConfiguration: tlsConfiguration)
196221
}
197-
} else {
198-
return channel.eventLoop.makeSucceededFuture(())
222+
})
223+
return self.addHandler(handler)
224+
}
225+
226+
func addSSLHandlerIfNeeded(for request: HTTPRequest, tlsConfiguration: TLSConfiguration?) -> EventLoopFuture<Void> {
227+
guard request.useTLS else {
228+
return self.eventLoop.makeSucceededFuture(())
229+
}
230+
231+
do {
232+
let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient()
233+
let context = try NIOSSLContext(configuration: tlsConfiguration)
234+
return self.addHandler(try NIOSSLClientHandler(context: context, serverHostname: request.host),
235+
position: .first)
236+
} catch {
237+
return self.eventLoop.makeFailedFuture(error)
199238
}
200239
}
201240

Tests/NIOHTTPClientTests/HTTPClientTestUtils.swift

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,27 @@ internal class HttpBin {
8787
return self.serverChannel.localAddress!
8888
}
8989

90-
init(ssl: Bool = false) {
90+
static func configureTLS(channel: Channel) -> EventLoopFuture<Void> {
91+
let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))],
92+
privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem)))
93+
let context = try! NIOSSLContext(configuration: configuration)
94+
return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first)
95+
}
96+
97+
init(ssl: Bool = false, simulateProxy: HTTPProxySimulator.Option? = nil) {
9198
self.serverChannel = try! ServerBootstrap(group: group)
9299
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
93100
.childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1)
94101
.childChannelInitializer { channel in
95-
channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
102+
return channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, withErrorHandling: true).flatMap {
103+
if let simulateProxy = simulateProxy {
104+
return channel.pipeline.addHandler(HTTPProxySimulator(option: simulateProxy), position: .first)
105+
} else {
106+
return channel.eventLoop.makeSucceededFuture(())
107+
}
108+
}.flatMap {
96109
if ssl {
97-
let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))],
98-
privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem)))
99-
let context = try! NIOSSLContext(configuration: configuration)
100-
return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), position: .first).flatMap {
110+
return HttpBin.configureTLS(channel: channel).flatMap {
101111
channel.pipeline.addHandler(HttpBinHandler())
102112
}
103113
} else {
@@ -113,6 +123,48 @@ internal class HttpBin {
113123

114124
}
115125

126+
final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler {
127+
typealias InboundIn = ByteBuffer
128+
typealias InboundOut = ByteBuffer
129+
typealias OutboundOut = ByteBuffer
130+
131+
enum Option {
132+
case plaintext
133+
case tls
134+
}
135+
136+
let option: Option
137+
138+
init(option: Option) {
139+
self.option = option
140+
}
141+
142+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
143+
let response = """
144+
HTTP/1.1 200 OK\r\n\
145+
Content-Length: 0\r\n\
146+
Connection: close\r\n\
147+
\r\n
148+
"""
149+
var buffer = self.unwrapInboundIn(data)
150+
let request = buffer.readString(length: buffer.readableBytes)!
151+
if request.hasPrefix("CONNECT") {
152+
var buffer = context.channel.allocator.buffer(capacity: 0)
153+
buffer.writeString(response)
154+
context.write(self.wrapInboundOut(buffer), promise: nil)
155+
context.flush()
156+
context.channel.pipeline.removeHandler(self, promise: nil)
157+
switch self.option {
158+
case .tls:
159+
_ = HttpBin.configureTLS(channel: context.channel)
160+
case .plaintext: break
161+
}
162+
} else {
163+
fatalError("Expected a CONNECT request")
164+
}
165+
}
166+
}
167+
116168
internal struct HTTPResponseBuilder {
117169
let head: HTTPResponseHead
118170
var body: ByteBuffer?

Tests/NIOHTTPClientTests/SwiftNIOHTTPTests+XCTest.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ extension SwiftHTTPTests {
3939
("testRemoteClose", testRemoteClose),
4040
("testReadTimeout", testReadTimeout),
4141
("testCancel", testCancel),
42+
("testProxyPlaintext", testProxyPlaintext),
43+
("testProxyTLS", testProxyTLS),
4244
]
4345
}
4446
}

Tests/NIOHTTPClientTests/SwiftNIOHTTPTests.swift

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,4 +267,35 @@ class SwiftHTTPTests: XCTestCase {
267267
XCTFail("Unexpected error: \(error)")
268268
}
269269
}
270+
271+
func testProxyPlaintext() throws {
272+
let httpBin = HttpBin(simulateProxy: .plaintext)
273+
let httpClient = HTTPClient(
274+
eventLoopGroupProvider: .createNew,
275+
configuration: .init(proxy: .server(host: "localhost", port: httpBin.port))
276+
)
277+
defer {
278+
try! httpClient.syncShutdown()
279+
httpBin.shutdown()
280+
}
281+
let res = try httpClient.get(url: "http://test/ok").wait()
282+
XCTAssertEqual(res.status, .ok)
283+
}
284+
285+
func testProxyTLS() throws {
286+
let httpBin = HttpBin(simulateProxy: .tls)
287+
let httpClient = HTTPClient(
288+
eventLoopGroupProvider: .createNew,
289+
configuration: .init(
290+
certificateVerification: .none,
291+
proxy: .server(host: "localhost", port: httpBin.port)
292+
)
293+
)
294+
defer {
295+
try! httpClient.syncShutdown()
296+
httpBin.shutdown()
297+
}
298+
let res = try httpClient.get(url: "https://test/ok").wait()
299+
XCTAssertEqual(res.status, .ok)
300+
}
270301
}

0 commit comments

Comments
 (0)