From 776b8a8640fab742d7c4884327a904858c42fe5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Trev=C3=B6r=20Anne=20Denise?= Date: Mon, 2 Sep 2019 14:41:41 +0200 Subject: [PATCH] Tolerate futures from arbitrary event loops This commit fixes #95 by always hopping event loop futures received from the delegate to the right event loop. This could be a source of bugs if the library users forgot to hop(to:) futures from their delegates implementations. --- Sources/AsyncHTTPClient/HTTPHandler.swift | 42 ++++++++++++------- .../HTTPClientTestUtils.swift | 12 ++++-- .../HTTPClientTests+XCTest.swift | 2 +- .../HTTPClientTests.swift | 20 ++++----- 4 files changed, 45 insertions(+), 31 deletions(-) diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 088b7f562..b9fcb0738 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -227,7 +227,7 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate { case .error: break } - return task.eventLoop.makeSucceededFuture(()) + return task.currentEventLoop.makeSucceededFuture(()) } func didReceiveBodyPart(task: HTTPClient.Task, _ part: ByteBuffer) -> EventLoopFuture { @@ -245,7 +245,7 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate { case .error: break } - return task.eventLoop.makeSucceededFuture(()) + return task.currentEventLoop.makeSucceededFuture(()) } func didReceiveError(task: HTTPClient.Task, _ error: Error) { @@ -343,9 +343,9 @@ extension HTTPClientResponseDelegate { public func didSendRequest(task: HTTPClient.Task) {} - public func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { return task.eventLoop.makeSucceededFuture(()) } + public func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { return task.currentEventLoop.makeSucceededFuture(()) } - public func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { return task.eventLoop.makeSucceededFuture(()) } + public func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { return task.currentEventLoop.makeSucceededFuture(()) } public func didReceiveError(task: HTTPClient.Task, _: Error) {} } @@ -366,15 +366,23 @@ extension HTTPClient { /// `EventLoopFuture` of the execution or cancellation of the execution. public final class Task { /// `EventLoop` used to execute and process this request. - public let eventLoop: EventLoop - let promise: EventLoopPromise + public var currentEventLoop: EventLoop { + return self.lock.withLock { + _currentEventLoop + } + } + /// The stored property used by `currentEventLoop` in combination with the `lock` + /// + /// In most cases you should use `currentEventLoop` instead + private var _currentEventLoop: EventLoop + let promise: EventLoopPromise private var channel: Channel? private var cancelled: Bool private let lock: Lock - public init(eventLoop: EventLoop) { - self.eventLoop = eventLoop + init(eventLoop: EventLoop) { + self._currentEventLoop = eventLoop self.promise = eventLoop.makePromise() self.cancelled = false self.lock = Lock() @@ -405,8 +413,8 @@ extension HTTPClient { @discardableResult func setChannel(_ channel: Channel) -> Channel { - precondition(self.eventLoop === channel.eventLoop, "Channel must use same event loop as this task.") return self.lock.withLock { + self._currentEventLoop = channel.eventLoop self.channel = channel return channel } @@ -539,9 +547,11 @@ internal class TaskHandler: ChannelInboundHandler } else { self.state = .head self.mayRead = false - self.delegate.didReceiveHead(task: self.task, head).whenComplete { result in - self.handleBackpressureResult(context: context, result: result) - } + self.delegate.didReceiveHead(task: self.task, head) + .hop(to: context.eventLoop) + .whenComplete { result in + self.handleBackpressureResult(context: context, result: result) + } } case .body(let body): switch self.state { @@ -550,9 +560,11 @@ internal class TaskHandler: ChannelInboundHandler default: self.state = .body self.mayRead = false - self.delegate.didReceiveBodyPart(task: self.task, body).whenComplete { result in - self.handleBackpressureResult(context: context, result: result) - } + self.delegate.didReceiveBodyPart(task: self.task, body) + .hop(to: context.eventLoop) + .whenComplete { result in + self.handleBackpressureResult(context: context, result: result) + } } case .end: switch self.state { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 1546f3834..fd5b59bce 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -21,6 +21,12 @@ import NIOSSL class TestHTTPDelegate: HTTPClientResponseDelegate { typealias Response = Void + init(backpressureEventLoop: EventLoop? = nil) { + self.backpressureEventLoop = backpressureEventLoop + } + + var backpressureEventLoop: EventLoop? + enum State { case idle case head(HTTPResponseHead) @@ -33,7 +39,7 @@ class TestHTTPDelegate: HTTPClientResponseDelegate { func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { self.state = .head(head) - return task.eventLoop.makeSucceededFuture(()) + return (self.backpressureEventLoop ?? task.currentEventLoop).makeSucceededFuture(()) } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { @@ -47,7 +53,7 @@ class TestHTTPDelegate: HTTPClientResponseDelegate { default: preconditionFailure("expecting head or body") } - return task.eventLoop.makeSucceededFuture(()) + return (self.backpressureEventLoop ?? task.currentEventLoop).makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws {} @@ -63,7 +69,7 @@ class CountingDelegate: HTTPClientResponseDelegate { if str?.starts(with: "id:") ?? false { self.count += 1 } - return task.eventLoop.makeSucceededFuture(()) + return task.currentEventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Int { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 93f66905b..9a7fe1bba 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -29,7 +29,7 @@ extension HTTPClientTests { ("testBadRequestURI", testBadRequestURI), ("testSchemaCasing", testSchemaCasing), ("testGet", testGet), - ("testGetWithSharedEventLoopGroup", testGetWithSharedEventLoopGroup), + ("testGetWithDifferentEventLoopBackpressure", testGetWithDifferentEventLoopBackpressure), ("testPost", testPost), ("testGetHttps", testGetHttps), ("testPostHttps", testPostHttps), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index c3e8bf105..44485fc70 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -62,22 +62,18 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(.ok, response.status) } - func testGetWithSharedEventLoopGroup() throws { + func testGetWithDifferentEventLoopBackpressure() throws { let httpBin = HttpBin() - let elg = MultiThreadedEventLoopGroup(numberOfThreads: 8) - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(elg)) + let loopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let external = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(loopGroup)) defer { - try! elg.syncShutdownGracefully() + try! loopGroup.syncShutdownGracefully() httpBin.shutdown() } - - let delegate = TestHTTPDelegate() let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/events/10/1") + let delegate = TestHTTPDelegate(backpressureEventLoop: external.next()) let task = httpClient.execute(request: request, delegate: delegate) - let expectedEventLoop = task.eventLoop - task.futureResult.whenComplete { _ in - XCTAssertTrue(expectedEventLoop.inEventLoop) - } try task.wait() } @@ -506,8 +502,8 @@ class HTTPClientTests: XCTestCase { } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.result = task.eventLoop === self.eventLoop - return task.eventLoop.makeSucceededFuture(()) + self.result = task.currentEventLoop === self.eventLoop + return task.currentEventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Bool {