diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index ebede078b..dfb7d5f4f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -93,20 +93,23 @@ final class ConnectionPool { } } - func prepareForClose() { - let connectionProviders = self.connectionProvidersLock.withLock { self.connectionProviders.values } - for connectionProvider in connectionProviders { - connectionProvider.prepareForClose() + func prepareForClose(on eventLoop: EventLoop) -> EventLoopFuture { + let connectionProviders = self.connectionProvidersLock.withLock { + self.connectionProviders.values } + + return EventLoopFuture.andAllComplete(connectionProviders.map { $0.prepareForClose() }, on: eventLoop) } - func syncClose() { - let connectionProviders = self.connectionProvidersLock.withLock { self.connectionProviders.values } - for connectionProvider in connectionProviders { - connectionProvider.syncClose() + func close(on eventLoop: EventLoop) -> EventLoopFuture { + let connectionProviders = self.connectionProvidersLock.withLock { + self.connectionProviders.values } - self.connectionProvidersLock.withLock { - assert(self.connectionProviders.count == 0, "left-overs: \(self.connectionProviders)") + + return EventLoopFuture.andAllComplete(connectionProviders.map { $0.close() }, on: eventLoop).map { + self.connectionProvidersLock.withLock { + assert(self.connectionProviders.count == 0, "left-overs: \(self.connectionProviders)") + } } } @@ -448,9 +451,7 @@ final class ConnectionPool { } /// Removes and fails all `waiters`, remove existing `availableConnections` and sets `state.activity` to `.closing` - func prepareForClose() { - assert(MultiThreadedEventLoopGroup.currentEventLoop == nil, - "HTTPClient shutdown on EventLoop unsupported") // calls .wait() so it would crash later anyway + func prepareForClose() -> EventLoopFuture { let (waitersFutures, closeFutures) = self.stateLock.withLock { () -> ([EventLoopFuture], [EventLoopFuture]) in // Fail waiters let waitersCopy = self.state.waiters @@ -461,26 +462,29 @@ final class ConnectionPool { let closeFutures = self.state.availableConnections.map { $0.close() } return (waitersFutures, closeFutures) } - try? EventLoopFuture.andAllComplete(waitersFutures, on: self.eventLoop).wait() - try? EventLoopFuture.andAllComplete(closeFutures, on: self.eventLoop).wait() - self.stateLock.withLock { - if self.state.leased == 0, self.state.availableConnections.isEmpty { - self.state.activity = .closed - } else { - self.state.activity = .closing + return EventLoopFuture.andAllComplete(waitersFutures, on: self.eventLoop) + .flatMap { + EventLoopFuture.andAllComplete(closeFutures, on: self.eventLoop) + } + .map { _ in + self.stateLock.withLock { + if self.state.leased == 0, self.state.availableConnections.isEmpty { + self.state.activity = .closed + } else { + self.state.activity = .closing + } + } } - } } - func syncClose() { - assert(MultiThreadedEventLoopGroup.currentEventLoop == nil, - "HTTPClient shutdown on EventLoop unsupported") // calls .wait() so it would crash later anyway + func close() -> EventLoopFuture { let availableConnections = self.stateLock.withLock { () -> CircularBuffer in assert(self.state.activity == .closing) return self.state.availableConnections } - try? EventLoopFuture.andAllComplete(availableConnections.map { $0.close() }, on: self.eventLoop).wait() + + return EventLoopFuture.andAllComplete(availableConnections.map { $0.close() }, on: self.eventLoop) } private func activityPrecondition(expected: Set) { diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 0ecd49a05..548f038ad 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -93,54 +93,101 @@ public class HTTPClient { /// this indicate shutdown was called too early before tasks were completed or explicitly canceled. /// In general, setting this parameter to `true` should make it easier and faster to catch related programming errors. internal func syncShutdown(requiresCleanClose: Bool) throws { - var closeError: Error? - - let tasks = try self.stateLock.withLock { () -> Dictionary.Values in - if self.state != .upAndRunning { - throw HTTPClientError.alreadyShutdown + if let eventLoop = MultiThreadedEventLoopGroup.currentEventLoop { + preconditionFailure(""" + BUG DETECTED: syncShutdown() must not be called when on an EventLoop. + Calling syncShutdown() on any EventLoop can lead to deadlocks. + Current eventLoop: \(eventLoop) + """) + } + let errorStorageLock = Lock() + var errorStorage: Error? + let continuation = DispatchWorkItem {} + self.shutdown(requiresCleanClose: requiresCleanClose, queue: DispatchQueue(label: "async-http-client.shutdown")) { error in + if let error = error { + errorStorageLock.withLock { + errorStorage = error + } } - self.state = .shuttingDown - return self.tasks.values + continuation.perform() } - - self.pool.prepareForClose() - - if !tasks.isEmpty, requiresCleanClose { - closeError = HTTPClientError.uncleanShutdown + continuation.wait() + try errorStorageLock.withLock { + if let error = errorStorage { + throw error + } } + } + + /// Shuts down the client and event loop gracefully. This function is clearly an outlier in that it uses a completion + /// callback instead of an EventLoopFuture. The reason for that is that NIO's EventLoopFutures will call back on an event loop. + /// The virtue of this function is to shut the event loop down. To work around that we call back on a DispatchQueue + /// instead. + public func shutdown(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { + self.shutdown(requiresCleanClose: false, queue: queue, callback) + } + private func cancelTasks(_ tasks: Dictionary.Values) -> EventLoopFuture { for task in tasks { task.cancel() } - try? EventLoopFuture.andAllComplete((tasks.map { $0.completion }), on: self.eventLoopGroup.next()).wait() - - self.pool.syncClose() + return EventLoopFuture.andAllComplete(tasks.map { $0.completion }, on: self.eventLoopGroup.next()) + } - do { - try self.stateLock.withLock { - switch self.eventLoopGroupProvider { - case .shared: + private func shutdownEventLoop(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { + self.stateLock.withLock { + switch self.eventLoopGroupProvider { + case .shared: + self.state = .shutDown + callback(nil) + case .createNew: + switch self.state { + case .shuttingDown: self.state = .shutDown - return - case .createNew: - switch self.state { - case .shuttingDown: - self.state = .shutDown - try self.eventLoopGroup.syncShutdownGracefully() - case .shutDown, .upAndRunning: - assertionFailure("The only valid state at this point is \(State.shutDown)") - } + self.eventLoopGroup.shutdownGracefully(queue: queue, callback) + case .shutDown, .upAndRunning: + assertionFailure("The only valid state at this point is \(State.shutDown)") } } - } catch { - if closeError == nil { - closeError = error + } + } + + private func shutdown(requiresCleanClose: Bool, queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { + let result: Result.Values, Error> = self.stateLock.withLock { + if self.state != .upAndRunning { + return .failure(HTTPClientError.alreadyShutdown) + } else { + self.state = .shuttingDown + return .success(self.tasks.values) } } - if let closeError = closeError { - throw closeError + switch result { + case .failure(let error): + callback(error) + case .success(let tasks): + self.pool.prepareForClose(on: self.eventLoopGroup.next()).whenComplete { _ in + var closeError: Error? + if !tasks.isEmpty, requiresCleanClose { + closeError = HTTPClientError.uncleanShutdown + } + + // we ignore errors here + self.cancelTasks(tasks).whenComplete { _ in + // we ignore errors here + self.pool.close(on: self.eventLoopGroup.next()).whenComplete { _ in + self.shutdownEventLoop(queue: queue) { eventLoopError in + // we prioritise .uncleanShutdown here + if let error = closeError { + callback(error) + } else { + callback(eventLoopError) + } + } + } + } + } } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 9a53e36ce..0ffeb42a6 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -96,6 +96,7 @@ extension HTTPClientTests { ("testPoolClosesIdleConnections", testPoolClosesIdleConnections), ("testRacePoolIdleConnectionsAndGet", testRacePoolIdleConnectionsAndGet), ("testAvoidLeakingTLSHandshakeCompletionPromise", testAvoidLeakingTLSHandshakeCompletionPromise), + ("testAsyncShutdown", testAsyncShutdown), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index bba05a63a..731586db9 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -1673,4 +1673,17 @@ class HTTPClientTests: XCTestCase { } } } + + func testAsyncShutdown() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + let promise = eventLoopGroup.next().makePromise(of: Void.self) + eventLoopGroup.next().execute { + httpClient.shutdown(queue: DispatchQueue(label: "testAsyncShutdown")) { error in + XCTAssertNil(error) + promise.succeed(()) + } + } + XCTAssertNoThrow(try promise.futureResult.wait()) + } }