diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 18114323b..44f237a61 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -211,7 +211,7 @@ public class HTTPClient { return channel.eventLoop.makeSucceededFuture(()) } }.flatMap { - let taskHandler = TaskHandler(task: task, delegate: delegate, redirectHandler: redirectHandler) + let taskHandler = TaskHandler(task: task, delegate: delegate, redirectHandler: redirectHandler, ignoreUncleanSSLShutdown: self.configuration.ignoreUncleanSSLShutdown) return channel.pipeline.addHandler(taskHandler) } } @@ -276,19 +276,31 @@ public class HTTPClient { public var timeout: Timeout /// Upstream proxy, defaults to no proxy. public var proxy: Proxy? + /// Ignore TLS unclean shutdown error, defaults to `false`. + public var ignoreUncleanSSLShutdown: Bool public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) { + self.init(tlsConfiguration: tlsConfiguration, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false) + } + + public init(tlsConfiguration: TLSConfiguration? = nil, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) { self.tlsConfiguration = tlsConfiguration self.followRedirects = followRedirects self.timeout = timeout self.proxy = proxy + self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown } public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil) { + self.init(certificateVerification: certificateVerification, followRedirects: followRedirects, timeout: timeout, proxy: proxy, ignoreUncleanSSLShutdown: false) + } + + public init(certificateVerification: CertificateVerification, followRedirects: Bool = false, timeout: Timeout = Timeout(), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false) { self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification) self.followRedirects = followRedirects self.timeout = timeout self.proxy = proxy + self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown } } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index ebcbd740b..c2c7b3439 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -441,15 +441,17 @@ internal class TaskHandler: ChannelInboundHandler let task: HTTPClient.Task let delegate: T let redirectHandler: RedirectHandler? + let ignoreUncleanSSLShutdown: Bool var state: State = .idle var pendingRead = false var mayRead = true - init(task: HTTPClient.Task, delegate: T, redirectHandler: RedirectHandler?) { + init(task: HTTPClient.Task, delegate: T, redirectHandler: RedirectHandler?, ignoreUncleanSSLShutdown: Bool) { self.task = task self.delegate = delegate self.redirectHandler = redirectHandler + self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown } func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { @@ -619,6 +621,10 @@ internal class TaskHandler: ChannelInboundHandler /// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection, /// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error. break + case .head where self.ignoreUncleanSSLShutdown, + .body where self.ignoreUncleanSSLShutdown: + /// We can also ignore this error like `.end`. + break default: self.state = .end self.delegate.didReceiveError(task: self.task, error) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 7c19cde2c..42d324138 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -28,7 +28,7 @@ class HTTPClientInternalTests: XCTestCase { let task = Task(eventLoop: channel.eventLoop) try channel.pipeline.addHandler(recorder).wait() - try channel.pipeline.addHandler(TaskHandler(task: task, delegate: TestHTTPDelegate(), redirectHandler: nil)).wait() + try channel.pipeline.addHandler(TaskHandler(task: task, delegate: TestHTTPDelegate(), redirectHandler: nil, ignoreUncleanSSLShutdown: false)).wait() var request = try Request(url: "http://localhost/get") request.headers.add(name: "X-Test-Header", value: "X-Test-Value") @@ -53,7 +53,7 @@ class HTTPClientInternalTests: XCTestCase { let channel = EmbeddedChannel() let delegate = TestHTTPDelegate() let task = Task(eventLoop: channel.eventLoop) - let handler = TaskHandler(task: task, delegate: delegate, redirectHandler: nil) + let handler = TaskHandler(task: task, delegate: delegate, redirectHandler: nil, ignoreUncleanSSLShutdown: false) try channel.pipeline.addHandler(handler).wait() diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 71f930921..fee3a848e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -332,6 +332,105 @@ internal final class HttpBinHandler: ChannelInboundHandler { } } +internal class HttpBinForSSLUncleanShutdown { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let serverChannel: Channel + + var port: Int { + return Int(self.serverChannel.localAddress!.port!) + } + + init(channelPromise: EventLoopPromise? = nil) { + self.serverChannel = try! ServerBootstrap(group: self.group) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) + .childChannelInitializer { channel in + let requestDecoder = HTTPRequestDecoder() + return channel.pipeline.addHandler(ByteToMessageHandler(requestDecoder)).flatMap { + let configuration = TLSConfiguration.forServer(certificateChain: [.certificate(try! NIOSSLCertificate(buffer: cert.utf8.map(Int8.init), format: .pem))], + privateKey: .privateKey(try! NIOSSLPrivateKey(buffer: key.utf8.map(Int8.init), format: .pem))) + let context = try! NIOSSLContext(configuration: configuration) + return channel.pipeline.addHandler(try! NIOSSLServerHandler(context: context), name: "NIOSSLServerHandler", position: .first).flatMap { + channel.pipeline.addHandler(HttpBinForSSLUncleanShutdownHandler(channelPromise: channelPromise)) + } + } + }.bind(host: "127.0.0.1", port: 0).wait() + } + + func shutdown() { + try! self.group.syncShutdownGracefully() + } +} + +internal final class HttpBinForSSLUncleanShutdownHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = ByteBuffer + + let channelPromise: EventLoopPromise? + + init(channelPromise: EventLoopPromise? = nil) { + self.channelPromise = channelPromise + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.unwrapInboundIn(data) { + case .head(let req): + if let promise = self.channelPromise { + promise.succeed(context.channel) + } + + let response: String? + switch req.uri { + case "/nocontentlength": + response = """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + \r\n\ + foo + """ + case "/nocontent": + response = """ + HTTP/1.1 204 OK\r\n\ + Connection: close\r\n\ + \r\n + """ + case "/noresponse": + response = nil + case "/wrongcontentlength": + response = """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Content-Length: 6\r\n\ + \r\n\ + foo + """ + default: + response = """ + HTTP/1.1 404 OK\r\n\ + Connection: close\r\n\ + Content-Length: 9\r\n\ + \r\n\ + Not Found + """ + } + + if let response = response { + var buffer = context.channel.allocator.buffer(capacity: response.count) + buffer.writeString(response) + context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) + } + + _ = context.channel.pipeline.removeHandler(name: "NIOSSLServerHandler").map { _ in + context.close(promise: nil) + } + case .body: + () + case .end: + () + } + } +} + extension ByteBuffer { public static func of(string: String) -> ByteBuffer { var buffer = ByteBufferAllocator().buffer(capacity: string.count) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 86c98bdbd..cb33ba8ed 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -45,6 +45,14 @@ extension HTTPClientTests { ("testProxyPlaintext", testProxyPlaintext), ("testProxyTLS", testProxyTLS), ("testUploadStreaming", testUploadStreaming), + ("testNoContentLengthForSSLUncleanShutdown", testNoContentLengthForSSLUncleanShutdown), + ("testNoContentLengthWithIgnoreErrorForSSLUncleanShutdown", testNoContentLengthWithIgnoreErrorForSSLUncleanShutdown), + ("testCorrectContentLengthForSSLUncleanShutdown", testCorrectContentLengthForSSLUncleanShutdown), + ("testNoContentForSSLUncleanShutdown", testNoContentForSSLUncleanShutdown), + ("testNoResponseForSSLUncleanShutdown", testNoResponseForSSLUncleanShutdown), + ("testNoResponseWithIgnoreErrorForSSLUncleanShutdown", testNoResponseWithIgnoreErrorForSSLUncleanShutdown), + ("testWrongContentLengthForSSLUncleanShutdown", testWrongContentLengthForSSLUncleanShutdown), + ("testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown", testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 78a8bdb76..ecd6c6209 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -16,6 +16,7 @@ import AsyncHTTPClient import NIO import NIOFoundationCompat import NIOHTTP1 +import NIOSSL import XCTest class HTTPClientTests: XCTestCase { @@ -346,4 +347,141 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(.ok, response.status) XCTAssertEqual("12344321", data.data) } + + func testNoContentLengthForSSLUncleanShutdown() throws { + let httpBin = HttpBinForSSLUncleanShutdown() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + + XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/nocontentlength").wait(), "Should fail") { error in + guard case let error = error as? NIOSSLError, error == .uncleanShutdown else { + return XCTFail("Should fail with NIOSSLError.uncleanShutdown") + } + } + } + + func testNoContentLengthWithIgnoreErrorForSSLUncleanShutdown() throws { + let httpBin = HttpBinForSSLUncleanShutdown() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none, ignoreUncleanSSLShutdown: true)) + + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + + let response = try httpClient.get(url: "https://localhost:\(httpBin.port)/nocontentlength").wait() + let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } + let string = String(decoding: bytes!, as: UTF8.self) + + XCTAssertEqual(.ok, response.status) + XCTAssertEqual("foo", string) + } + + func testCorrectContentLengthForSSLUncleanShutdown() throws { + let httpBin = HttpBinForSSLUncleanShutdown() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + + let response = try httpClient.get(url: "https://localhost:\(httpBin.port)/").wait() + let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } + let string = String(decoding: bytes!, as: UTF8.self) + + XCTAssertEqual(.notFound, response.status) + XCTAssertEqual("Not Found", string) + } + + func testNoContentForSSLUncleanShutdown() throws { + let httpBin = HttpBinForSSLUncleanShutdown() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + + let response = try httpClient.get(url: "https://localhost:\(httpBin.port)/nocontent").wait() + + XCTAssertEqual(.noContent, response.status) + XCTAssertEqual(response.body, nil) + } + + func testNoResponseForSSLUncleanShutdown() throws { + let httpBin = HttpBinForSSLUncleanShutdown() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + + XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/noresponse").wait(), "Should fail") { error in + guard case let error = error as? NIOSSLError, error == .uncleanShutdown else { + return XCTFail("Should fail with NIOSSLError.uncleanShutdown") + } + } + } + + func testNoResponseWithIgnoreErrorForSSLUncleanShutdown() throws { + let httpBin = HttpBinForSSLUncleanShutdown() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none, ignoreUncleanSSLShutdown: true)) + + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + + XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/noresponse").wait(), "Should fail") { error in + guard case let error = error as? NIOSSLError, error == .uncleanShutdown else { + return XCTFail("Should fail with NIOSSLError.uncleanShutdown") + } + } + } + + func testWrongContentLengthForSSLUncleanShutdown() throws { + let httpBin = HttpBinForSSLUncleanShutdown() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none)) + + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + + XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/wrongcontentlength").wait(), "Should fail") { error in + guard case let error = error as? NIOSSLError, error == .uncleanShutdown else { + return XCTFail("Should fail with NIOSSLError.uncleanShutdown") + } + } + } + + func testWrongContentLengthWithIgnoreErrorForSSLUncleanShutdown() throws { + let httpBin = HttpBinForSSLUncleanShutdown() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, + configuration: HTTPClient.Configuration(certificateVerification: .none, ignoreUncleanSSLShutdown: true)) + + defer { + try! httpClient.syncShutdown() + httpBin.shutdown() + } + + XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/wrongcontentlength").wait(), "Should fail") { error in + guard case let error = error as? HTTPParserError, error == .invalidEOFState else { + return XCTFail("Should fail with HTTPParserError.invalidEOFState") + } + } + } }