diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 2bae67adb..12e6a4fc4 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -814,9 +814,8 @@ extension TaskHandler: ChannelDuplexHandler { do { try headers.validate(method: request.method, body: request.body) } catch { + self.errorCaught(context: context, error: error) promise?.fail(error) - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) - self.state = .endOrError return } @@ -843,9 +842,8 @@ extension TaskHandler: ChannelDuplexHandler { self.state = .bodySent context.eventLoop.assertInEventLoop() if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength { - self.state = .endOrError let error = HTTPClientError.bodyLengthMismatch - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + self.errorCaught(context: context, error: error) return context.eventLoop.makeFailedFuture(error) } return context.writeAndFlush(self.wrapOutboundOut(.end(nil))) @@ -855,13 +853,7 @@ extension TaskHandler: ChannelDuplexHandler { self.callOutToDelegateFireAndForget(self.delegate.didSendRequest) }.flatMapErrorThrowing { error in context.eventLoop.assertInEventLoop() - switch self.state { - case .endOrError: - break - default: - self.state = .endOrError - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) - } + self.errorCaught(context: context, error: error) throw error }.cascade(to: promise) } @@ -906,8 +898,7 @@ extension TaskHandler: ChannelDuplexHandler { case .idle: if let limit = self.expectedBodyLength, self.actualBodyLength + part.readableBytes > limit { let error = HTTPClientError.bodyLengthMismatch - self.state = .endOrError - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + self.errorCaught(context: context, error: error) promise.fail(error) return } @@ -915,8 +906,7 @@ extension TaskHandler: ChannelDuplexHandler { context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) default: let error = HTTPClientError.writeAfterRequestSent - self.state = .endOrError - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + self.errorCaught(context: context, error: error) promise.fail(error) } } @@ -983,16 +973,13 @@ extension TaskHandler: ChannelDuplexHandler { context.read() } case .failure(let error): - self.state = .endOrError - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + self.errorCaught(context: context, error: error) } } func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { if (event as? IdleStateHandler.IdleStateEvent) == .read { - self.state = .endOrError - let error = HTTPClientError.readTimeout - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + self.errorCaught(context: context, error: HTTPClientError.readTimeout) } else { context.fireUserInboundEventTriggered(event) } @@ -1000,9 +987,7 @@ extension TaskHandler: ChannelDuplexHandler { func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { if (event as? TaskCancelEvent) != nil { - self.state = .endOrError - let error = HTTPClientError.cancelled - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + self.errorCaught(context: context, error: HTTPClientError.cancelled) promise?.succeed(()) } else { context.triggerUserOutboundEvent(event, promise: promise) @@ -1014,9 +999,7 @@ extension TaskHandler: ChannelDuplexHandler { case .endOrError: break case .body, .head, .idle, .redirected, .sent, .bodySent: - self.state = .endOrError - let error = HTTPClientError.remoteConnectionClosed - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed) } context.fireChannelInactive() } @@ -1038,14 +1021,20 @@ extension TaskHandler: ChannelDuplexHandler { self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } default: - self.state = .endOrError - self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + switch self.state { + case .idle, .bodySent, .sent, .head, .redirected, .body: + self.state = .endOrError + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + case .endOrError: + // error was already handled + break + } } } func handlerAdded(context: ChannelHandlerContext) { guard context.channel.isActive else { - self.failTaskAndNotifyDelegate(error: HTTPClientError.remoteConnectionClosed, self.delegate.didReceiveError) + self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed) return } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index 6177127d0..648eb8078 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -46,6 +46,7 @@ extension HTTPClientInternalTests { ("testConnectErrorCalloutOnCorrectEL", testConnectErrorCalloutOnCorrectEL), ("testInternalRequestURI", testInternalRequestURI), ("testBodyPartStreamStateChangedBeforeNotification", testBodyPartStreamStateChangedBeforeNotification), + ("testHandlerDoubleError", testHandlerDoubleError), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 52348ba94..706a3bbd7 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -1080,4 +1080,43 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try channel.readOutbound(as: HTTPClientRequestPart.self)) // .head XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) } + + func testHandlerDoubleError() throws { + class ErrorCountingDelegate: HTTPClientResponseDelegate { + typealias Response = Void + + var count = 0 + + func didReceiveError(task: HTTPClient.Task, _: Error) { + self.count += 1 + } + + func didFinishRequest(task: HTTPClient.Task) throws { + return () + } + } + + class SendTwoErrorsHandler: ChannelInboundHandler { + typealias InboundIn = Any + + func handlerAdded(context: ChannelHandlerContext) { + context.fireErrorCaught(HTTPClientError.cancelled) + context.fireErrorCaught(HTTPClientError.cancelled) + } + } + + let channel = EmbeddedChannel() + let task = Task(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled) + let delegate = ErrorCountingDelegate() + try channel.pipeline.addHandler(TaskHandler(task: task, + kind: .host, + delegate: delegate, + redirectHandler: nil, + ignoreUncleanSSLShutdown: false, + logger: HTTPClient.loggingDisabled)).wait() + + try channel.pipeline.addHandler(SendTwoErrorsHandler()).wait() + + XCTAssertEqual(delegate.count, 1) + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index eb4dd7cb5..918da82ac 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -122,6 +122,7 @@ extension HTTPClientTests { ("testContentLengthTooShortFails", testContentLengthTooShortFails), ("testBodyUploadAfterEndFails", testBodyUploadAfterEndFails), ("testNoBytesSentOverBodyLimit", testNoBytesSentOverBodyLimit), + ("testDoubleError", testDoubleError), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 04287ef0d..970ab7dbe 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -2602,4 +2602,29 @@ class HTTPClientTests: XCTestCase { XCTAssertThrowsError(try future.wait()) } + + func testDoubleError() throws { + // This is needed to that connection pool will not get into closed state when we release + // second connection. + _ = self.defaultClient.get(url: "http://localhost:\(self.defaultHTTPBin.port)/events/10/1") + + var request = try HTTPClient.Request(url: "http://localhost:\(self.defaultHTTPBin.port)/wait", method: .POST) + request.body = .stream { writer in + // Start writing chunks so tha we will try to write after read timeout is thrown + for _ in 1...10 { + _ = writer.write(.byteBuffer(ByteBuffer(string: "1234"))) + } + + let promise = self.clientGroup.next().makePromise(of: Void.self) + self.clientGroup.next().scheduleTask(in: .milliseconds(3)) { + writer.write(.byteBuffer(ByteBuffer(string: "1234"))).cascade(to: promise) + } + + return promise.futureResult + } + + // We specify a deadline of 2 ms co that request will be timed out before all chunks are writtent, + // we need to verify that second error on write after timeout does not lead to double-release. + XCTAssertThrowsError(try self.defaultClient.execute(request: request, deadline: .now() + .milliseconds(2)).wait()) + } }