diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift index 49e755733..5a0b2708e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift @@ -147,8 +147,6 @@ final class HTTPConnectionPool { self.unlocked = Unlocked(connection: .none, request: .none) switch stateMachineAction.request { - case .cancelRequestTimeout(let requestID): - self.locked.request = .cancelRequestTimeout(requestID) case .executeRequest(let request, let connection, cancelTimeout: let cancelTimeout): if cancelTimeout { self.locked.request = .cancelRequestTimeout(request.id) diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift index d654f5a87..2cd667bb3 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift @@ -323,9 +323,11 @@ extension HTTPConnectionPool { mutating func cancelRequest(_ requestID: Request.ID) -> Action { // 1. check requests in queue - if self.requests.remove(requestID) != nil { + if let request = self.requests.remove(requestID) { + // Use the last connection error to let the user know why the request was never scheduled + let error = self.lastConnectFailure ?? HTTPClientError.cancelled return .init( - request: .cancelRequestTimeout(requestID), + request: .failRequest(request, error, cancelTimeout: true), connection: .none ) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift index 06fc36ad0..d517d82e6 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift @@ -444,9 +444,11 @@ extension HTTPConnectionPool { mutating func cancelRequest(_ requestID: Request.ID) -> Action { // 1. check requests in queue - if self.requests.remove(requestID) != nil { + if let request = self.requests.remove(requestID) { + // Use the last connection error to let the user know why the request was never scheduled + let error = self.lastConnectFailure ?? HTTPClientError.cancelled return .init( - request: .cancelRequestTimeout(requestID), + request: .failRequest(request, error, cancelTimeout: true), connection: .none ) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift index 61e57941a..63f3e5a9a 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift @@ -61,7 +61,6 @@ extension HTTPConnectionPool { case failRequestsAndCancelTimeouts([Request], Error) case scheduleRequestTimeout(for: Request, on: EventLoop) - case cancelRequestTimeout(Request.ID) case none } diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 45b2ce0ff..1f08fb41d 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -606,7 +606,7 @@ public class HTTPClient { var deadlineSchedule: Scheduled? if let deadline = deadline { deadlineSchedule = taskEL.scheduleTask(deadline: deadline) { - requestBag.fail(HTTPClientError.deadlineExceeded) + requestBag.deadlineExceeded() } task.promise.futureResult.whenComplete { _ in diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index 63cb15758..a2a90749a 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -31,6 +31,10 @@ extension RequestBag { fileprivate enum State { case initialized case queued(HTTPRequestScheduler) + /// if the deadline was exceeded while in the `.queued(_:)` state, + /// we wait until the request pool fails the request with a potential more descriptive error message, + /// if a connection failure has occured while the request was queued. + case deadlineExceededWhileQueued case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState) case finished(error: Error?) case redirected(HTTPRequestExecutor, Int, HTTPResponseHead, URL) @@ -90,13 +94,23 @@ extension RequestBag.StateMachine { self.state = .queued(scheduler) } - mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> Bool { + enum WillExecuteRequestAction { + case cancelExecuter(HTTPRequestExecutor) + case failTaskAndCancelExecutor(Error, HTTPRequestExecutor) + case none + } + + mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> WillExecuteRequestAction { switch self.state { case .initialized, .queued: self.state = .executing(executor, .initialized, .initialized) - return true + return .none + case .deadlineExceededWhileQueued: + let error: Error = HTTPClientError.deadlineExceeded + self.state = .finished(error: error) + return .failTaskAndCancelExecutor(error, executor) case .finished(error: .some): - return false + return .cancelExecuter(executor) case .executing, .redirected, .finished(error: .none), .modifying: preconditionFailure("Invalid state: \(self.state)") } @@ -110,7 +124,7 @@ extension RequestBag.StateMachine { mutating func resumeRequestBodyStream() -> ResumeProducingAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("A request stream can only be resumed, if the request was started") case .executing(let executor, .initialized, .initialized): @@ -150,7 +164,7 @@ extension RequestBag.StateMachine { mutating func pauseRequestBodyStream() { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("A request stream can only be paused, if the request was started") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -185,7 +199,7 @@ extension RequestBag.StateMachine { mutating func writeNextRequestPart(_ part: IOData, taskEventLoop: EventLoop) -> WriteAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -231,7 +245,7 @@ extension RequestBag.StateMachine { mutating func finishRequestBodyStream(_ result: Result) -> FinishAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -282,7 +296,7 @@ extension RequestBag.StateMachine { /// - Returns: Whether the response should be forwarded to the delegate. Will be `false` if the request follows a redirect. mutating func receiveResponseHead(_ head: HTTPResponseHead) -> ReceiveResponseHeadAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response, if the request hasn't started yet.") case .executing(let executor, let requestState, let responseState): guard case .initialized = responseState else { @@ -328,7 +342,7 @@ extension RequestBag.StateMachine { mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponseBodyAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") case .executing(_, _, .initialized): preconditionFailure("If we receive a response body, we must have received a head before") @@ -385,7 +399,7 @@ extension RequestBag.StateMachine { mutating func succeedRequest(_ newChunks: CircularBuffer?) -> ReceiveResponseEndAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") case .executing(_, _, .initialized): preconditionFailure("If we receive a response body, we must have received a head before") @@ -447,7 +461,7 @@ extension RequestBag.StateMachine { private mutating func failWithConsumptionError(_ error: Error) -> ConsumeAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(_, _, .initialized): preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") @@ -482,7 +496,7 @@ extension RequestBag.StateMachine { private mutating func consumeMoreBodyData() -> ConsumeAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(_, _, .initialized): @@ -532,8 +546,33 @@ extension RequestBag.StateMachine { } } + enum DeadlineExceededAction { + case cancelScheduler(HTTPRequestScheduler?) + case fail(FailAction) + } + + mutating func deadlineExceeded() -> DeadlineExceededAction { + switch self.state { + case .queued(let queuer): + /// We do not fail the request immediately because we want to give the scheduler a chance of throwing a better error message + /// We therefore depend on the scheduler failing the request after we cancel the request. + self.state = .deadlineExceededWhileQueued + return .cancelScheduler(queuer) + + case .initialized, + .deadlineExceededWhileQueued, + .executing, + .finished, + .redirected, + .modifying: + /// if we are not in the queued state, we can fail early by just calling down to `self.fail(_:)` + /// which does the appropriate state transition for us. + return .fail(self.fail(HTTPClientError.deadlineExceeded)) + } + } + enum FailAction { - case failTask(HTTPRequestScheduler?, HTTPRequestExecutor?) + case failTask(Error, HTTPRequestScheduler?, HTTPRequestExecutor?) case cancelExecutor(HTTPRequestExecutor) case none } @@ -542,31 +581,39 @@ extension RequestBag.StateMachine { switch self.state { case .initialized: self.state = .finished(error: error) - return .failTask(nil, nil) + return .failTask(error, nil, nil) case .queued(let queuer): self.state = .finished(error: error) - return .failTask(queuer, nil) + return .failTask(error, queuer, nil) case .executing(let executor, let requestState, .buffering(_, next: .eof)): self.state = .executing(executor, requestState, .buffering(.init(), next: .error(error))) return .cancelExecutor(executor) case .executing(let executor, _, .buffering(_, next: .askExecutorForMore)): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .executing(let executor, _, .buffering(_, next: .error(_))): // this would override another error, let's keep the first one return .cancelExecutor(executor) case .executing(let executor, _, .initialized): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .executing(let executor, _, .waitingForRemote): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .redirected: self.state = .finished(error: error) - return .failTask(nil, nil) + return .failTask(error, nil, nil) case .finished(.none): // An error occurred after the request has finished. Ignore... return .none + case .deadlineExceededWhileQueued: + // if we just get a `HTTPClientError.cancelled` we can use the original cancellation reason + // to give a more descriptive error to the user. + if (error as? HTTPClientError) == .cancelled { + return .failTask(HTTPClientError.deadlineExceeded, nil, nil) + } + // otherwise we already had an intermediate connection error which we should present to the user instead + return .failTask(error, nil, nil) case .finished(.some(_)): // this might happen, if the stream consumer has failed... let's just drop the data return .none diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index dbef802e9..4ec7004c1 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -81,8 +81,16 @@ final class RequestBag { private func willExecuteRequest0(_ executor: HTTPRequestExecutor) { self.task.eventLoop.assertInEventLoop() - if !self.state.willExecuteRequest(executor) { - return executor.cancelRequest(self) + let action = self.state.willExecuteRequest(executor) + switch action { + case .cancelExecuter(let executor): + executor.cancelRequest(self) + case .failTaskAndCancelExecutor(let error, let executor): + self.delegate.didReceiveError(task: self.task, error) + self.task.fail(with: error, delegateType: Delegate.self) + executor.cancelRequest(self) + case .none: + break } } @@ -320,8 +328,12 @@ final class RequestBag { let action = self.state.fail(error) + self.executeFailAction0(action) + } + + private func executeFailAction0(_ action: RequestBag.StateMachine.FailAction) { switch action { - case .failTask(let scheduler, let executor): + case .failTask(let error, let scheduler, let executor): scheduler?.cancelRequest(self) executor?.cancelRequest(self) self.failTask0(error) @@ -331,6 +343,28 @@ final class RequestBag { break } } + + func deadlineExceeded0() { + self.task.eventLoop.assertInEventLoop() + let action = self.state.deadlineExceeded() + + switch action { + case .cancelScheduler(let scheduler): + scheduler?.cancelRequest(self) + case .fail(let failAction): + self.executeFailAction0(failAction) + } + } + + func deadlineExceeded() { + if self.task.eventLoop.inEventLoop { + self.deadlineExceeded0() + } else { + self.task.eventLoop.execute { + self.deadlineExceeded0() + } + } + } } extension RequestBag: HTTPSchedulableRequest { @@ -457,12 +491,6 @@ extension RequestBag: HTTPExecutableRequest { extension RequestBag: HTTPClientTaskDelegate { func cancel() { - if self.task.eventLoop.inEventLoop { - self.fail0(HTTPClientError.cancelled) - } else { - self.task.eventLoop.execute { - self.fail0(HTTPClientError.cancelled) - } - } + self.fail(HTTPClientError.cancelled) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 3975036ea..b3a13486c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -79,6 +79,7 @@ extension HTTPClientTests { ("testStressGetHttps", testStressGetHttps), ("testStressGetHttpsSSLError", testStressGetHttpsSSLError), ("testSelfSignedCertificateIsRejectedWithCorrectError", testSelfSignedCertificateIsRejectedWithCorrectError), + ("testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded", testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded), ("testFailingConnectionIsReleased", testFailingConnectionIsReleased), ("testResponseDelayGet", testResponseDelayGet), ("testIdleTimeoutNoReuse", testIdleTimeoutNoReuse), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index ece09c52d..02c60d177 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -1269,6 +1269,47 @@ class HTTPClientTests: XCTestCase { } } + func testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded() throws { + /// key + cert was created with the follwing command: + /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' + let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let configuration = TLSConfiguration.makeServerConfiguration( + certificateChain: try NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, + privateKey: .file(keyPath) + ) + let sslContext = try NIOSSLContext(configuration: configuration) + + let server = ServerBootstrap(group: serverGroup) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NIOSSLServerHandler(context: sslContext)) + } + let serverChannel = try server.bind(host: "localhost", port: 0).wait() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + let port = serverChannel.localAddress!.port! + + var config = HTTPClient.Configuration() + config.timeout.connect = .seconds(3) + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(port)", deadline: .now() + .seconds(2)).wait()) { error in + #if canImport(Network) + guard let nwTLSError = error as? HTTPClient.NWTLSError else { + XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") + return + } + XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") + #else + guard let sslError = error as? NIOSSLError, + case .handshakeFailed(.sslError) = sslError else { + XCTFail("unexpected error \(error)") + return + } + #endif + } + } + func testFailingConnectionIsReleased() { let localHTTPBin = HTTPBin(.refuse) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift index 49a6fb574..7f59fd4e1 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift @@ -21,6 +21,7 @@ import XCTest class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { func testCreatingAndFailingConnections() { + struct SomeError: Error, Equatable {} let elg = EmbeddedEventLoopGroup(loops: 4) defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } @@ -65,8 +66,6 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // fail all connection attempts while let randomConnectionID = connections.randomStartingConnection() { - struct SomeError: Error, Equatable {} - XCTAssertNoThrow(try connections.failConnectionCreation(randomConnectionID)) let action = state.failedToCreateNewConnection(SomeError(), connectionID: randomConnectionID) @@ -86,9 +85,9 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // cancel all queued requests while let request = queuer.timeoutRandomRequest() { - let cancelAction = state.cancelRequest(request) + let cancelAction = state.cancelRequest(request.0) XCTAssertEqual(cancelAction.connection, .none) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request)) + XCTAssertEqual(cancelAction.request, .failRequest(.init(request.1), SomeError(), cancelTimeout: true)) } // connection backoff done @@ -184,7 +183,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // 2. cancel request let cancelAction = state.cancelRequest(request.id) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request.id)) + XCTAssertEqual(cancelAction.request, .failRequest(request, HTTPClientError.cancelled, cancelTimeout: true)) XCTAssertEqual(cancelAction.connection, .none) // 3. request timeout triggers to late diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift index 2574d3da2..e42a98ac7 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift @@ -212,7 +212,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // 2. cancel request let cancelAction = state.cancelRequest(request.id) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request.id)) + XCTAssertEqual(cancelAction.request, .failRequest(request, HTTPClientError.cancelled, cancelTimeout: true)) XCTAssertEqual(cancelAction.connection, .none) // 3. request timeout triggers to late @@ -1242,9 +1242,9 @@ func XCTAssertEqualTypeAndValue( let lhs = try lhs() let rhs = try rhs() guard let lhsAsRhs = lhs as? Right else { - XCTFail("could not cast \(lhs) of type \(type(of: lhs)) to \(type(of: rhs))") + XCTFail("could not cast \(lhs) of type \(type(of: lhs)) to \(type(of: rhs))", file: file, line: line) return } - XCTAssertEqual(lhsAsRhs, rhs) + XCTAssertEqual(lhsAsRhs, rhs, file: file, line: line) }(), file: file, line: line) } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift index cb67837d7..0ffdeebd8 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift @@ -126,8 +126,6 @@ extension HTTPConnectionPool.StateMachine.RequestAction: Equatable { return lhsReqs.elementsEqual(rhsReqs, by: { $0 == $1 }) case (.scheduleRequestTimeout(for: let lhsReq, on: let lhsEL), .scheduleRequestTimeout(for: let rhsReq, on: let rhsEL)): return lhsReq == rhsReq && lhsEL === rhsEL - case (.cancelRequestTimeout(let lhsReqID), .cancelRequestTimeout(let rhsReqID)): - return lhsReqID == rhsReqID case (.none, .none): return true default: diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift index e81f1ed0a..520b51875 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift @@ -82,11 +82,11 @@ struct MockRequestQueuer { return waiter.request } - mutating func timeoutRandomRequest() -> RequestID? { - guard let waiterID = self.waiters.randomElement().map(\.0) else { + mutating func timeoutRandomRequest() -> (RequestID, HTTPSchedulableRequest)? { + guard let waiter = self.waiters.randomElement() else { return nil } - self.waiters.removeValue(forKey: waiterID) - return waiterID + self.waiters.removeValue(forKey: waiter.key) + return (waiter.key, waiter.value.request) } } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift index 74c68fd1f..19de474c2 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift @@ -28,6 +28,7 @@ extension RequestBagTests { ("testWriteBackpressureWorks", testWriteBackpressureWorks), ("testTaskIsFailedIfWritingFails", testTaskIsFailedIfWritingFails), ("testCancelFailsTaskBeforeRequestIsSent", testCancelFailsTaskBeforeRequestIsSent), + ("testDeadlineExceededFailsTaskEvenIfRaceBetweenCancelingSchedulerAndRequestStart", testDeadlineExceededFailsTaskEvenIfRaceBetweenCancelingSchedulerAndRequestStart), ("testCancelFailsTaskAfterRequestIsSent", testCancelFailsTaskAfterRequestIsSent), ("testCancelFailsTaskWhenTaskIsQueued", testCancelFailsTaskWhenTaskIsQueued), ("testFailsTaskWhenTaskIsWaitingForMoreFromServer", testFailsTaskWhenTaskIsWaitingForMoreFromServer), diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index c80f8846b..b896aca0a 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -228,6 +228,44 @@ final class RequestBagTests: XCTestCase { } } + func testDeadlineExceededFailsTaskEvenIfRaceBetweenCancelingSchedulerAndRequestStart() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + )) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + XCTAssert(bag.eventLoop === embeddedEventLoop) + + let queuer = MockTaskQueuer() + bag.requestWasQueued(queuer) + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + XCTAssertEqual(queuer.hitCancelCount, 0) + bag.deadlineExceeded() + XCTAssertEqual(queuer.hitCancelCount, 1) + + bag.willExecuteRequest(executor) + XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .deadlineExceeded) + } + } + func testCancelFailsTaskAfterRequestIsSent() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) }