diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index cc6bd4899..8b2a50738 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -139,30 +139,13 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // MARK: Run Actions private func run(_ action: HTTPRequestStateMachine.Action, context: ChannelHandlerContext) { - // NOTE: We can bang the request in the following actions, since the `HTTPRequestStateMachine` - // ensures, that actions that require a request are only called, if the request is - // still present. The request is only nilled as a response to a state machine action - // (.failRequest or .succeedRequest). - switch action { case .sendRequestHead(let head, let startBody): - if startBody { - context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) - self.request!.requestHeadSent() - self.request!.resumeRequestBodyStream() - } else { - context.write(self.wrapOutboundOut(.head(head)), promise: nil) - context.write(self.wrapOutboundOut(.end(nil)), promise: nil) - context.flush() - - self.request!.requestHeadSent() - - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) - } - } + self.sendRequestHead(head, startBody: startBody, context: context) case .pauseRequestBodyStream: + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet self.request!.pauseRequestBodyStream() case .sendBodyPart(let data): @@ -182,18 +165,29 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { break case .resumeRequestBodyStream: + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet self.request!.resumeRequestBodyStream() case .forwardResponseHead(let head, pauseRequestBodyStream: let pauseRequestBodyStream): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet self.request!.receiveResponseHead(head) - if pauseRequestBodyStream { - self.request!.pauseRequestBodyStream() + if pauseRequestBodyStream, let request = self.request { + // The above response head forward might lead the request to mark itself as + // cancelled, which in turn might pop the request of the handler. For this reason we + // must check if the request is still present here. + request.pauseRequestBodyStream() } case .forwardResponseBodyParts(let parts): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet self.request!.receiveResponseBodyParts(parts) case .failRequest(let error, _): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request object is still present. self.request!.fail(error) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) @@ -204,6 +198,8 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.runFinalAction(.close, context: context) case .succeedRequest(let finalAction, let finalParts): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request object is still present. self.request!.succeedRequest(finalParts) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) @@ -211,6 +207,33 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } } + private func sendRequestHead(_ head: HTTPRequestHead, startBody: Bool, context: ChannelHandlerContext) { + if startBody { + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) + + // The above write might trigger an error, which may lead to a call to `errorCaught`, + // which in turn, may fail the request and pop it from the handler. For this reason + // we must check if the request is still present here. + guard let request = self.request else { return } + request.requestHeadSent() + request.resumeRequestBodyStream() + } else { + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.write(self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + + // The above write might trigger an error, which may lead to a call to `errorCaught`, + // which in turn, may fail the request and pop it from the handler. For this reason + // we must check if the request is still present here. + guard let request = self.request else { return } + request.requestHeadSent() + + if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(timeoutAction, context: context) + } + } + } + private func runFinalAction(_ action: HTTPRequestStateMachine.Action.FinalStreamAction, context: ChannelHandlerContext) { switch action { case .close: diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift index a0facdb65..8fa219838 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift @@ -29,6 +29,7 @@ extension HTTP2ClientRequestHandlerTests { ("testWriteBackpressure", testWriteBackpressure), ("testIdleReadTimeout", testIdleReadTimeout), ("testIdleReadTimeoutIsCanceledIfRequestIsCanceled", testIdleReadTimeoutIsCanceledIfRequestIsCanceled), + ("testWriteHTTPHeadFails", testWriteHTTPHeadFails), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index 7bbb30105..e67529ad8 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -285,4 +285,64 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { // therefore advancing the time should not trigger a crash embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) } + + func testWriteHTTPHeadFails() { + struct WriteError: Error, Equatable {} + + class FailWriteHandler: ChannelOutboundHandler { + typealias OutboundIn = HTTPClientRequestPart + typealias OutboundOut = HTTPClientRequestPart + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let error = WriteError() + promise?.fail(error) + context.fireErrorCaught(error) + } + } + + let bodies: [HTTPClient.Body?] = [ + .none, + .some(.byteBuffer(ByteBuffer(string: "hello world"))), + ] + + for body in bodies { + let embeddedEventLoop = EmbeddedEventLoop() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embeddedEventLoop) + let embedded = EmbeddedChannel(handlers: [FailWriteHandler(), requestHandler], loop: embeddedEventLoop) + + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: body)) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + )) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = false + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + embedded.write(requestBag, promise: nil) + + // the handler only writes once the channel is writable + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .none) + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? WriteError, WriteError()) + } + + XCTAssertEqual(embedded.isActive, false) + } + } }