diff --git a/.gitignore b/.gitignore index d64b77136..10a8e3888 100644 --- a/.gitignore +++ b/.gitignore @@ -2,16 +2,11 @@ project.xcworkspace xcuserdata .build -protoc-gen-swift -protoc-gen-swiftgrpc +/protoc-gen-swift +/protoc-gen-swiftgrpc third_party/** -Plugin/Packages/** -Plugin/Sources/protoc-gen-swiftgrpc/templates.swift -Plugin/protoc-* -Plugin/swiftgrpc.log -Plugin/echo.*.swift -Echo -test.out -echo.pid -SwiftGRPC.xcodeproj +/Echo +/test.out +/echo.pid +/SwiftGRPC.xcodeproj Package.resolved diff --git a/.travis-install.sh b/.travis-install.sh index fd7730a8c..eb201bc0b 100755 --- a/.travis-install.sh +++ b/.travis-install.sh @@ -28,7 +28,7 @@ cd mkdir -p local -if [ $TRAVIS_OS_NAME == "osx" ]; then +if [ "$TRAVIS_OS_NAME" == "osx" ]; then PROTOC_URL=https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-osx-x86_64.zip else # Install swift diff --git a/.travis.yml b/.travis.yml index 34f573a56..e59fb0573 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,6 +19,12 @@ os: - linux - osx +cache: + apt: true + directories: + - .build/checkouts + - .build/repositories + # Use Ubuntu 14.04 dist: trusty @@ -29,23 +35,25 @@ sudo: false addons: apt: packages: - - clang-3.8 - - lldb-3.8 - - libicu-dev - - libtool - - libcurl4-openssl-dev - - libbsd-dev - - build-essential - - libssl-dev - - uuid-dev - - curl - - unzip + - clang-3.8 + - lldb-3.8 + - libicu-dev + - libtool + - libcurl4-openssl-dev + - libbsd-dev + - build-essential + - libssl-dev + - uuid-dev + - curl + - unzip install: ./.travis-install.sh script: - export PATH=$HOME/local/bin:$PATH - export LD_LIBRARY_PATH=$HOME/local/lib + - swift package -v resolve + - make all - make test - make test-plugin - make test-echo diff --git a/Examples/EchoXcode/Echo/EchoViewController.swift b/Examples/EchoXcode/Echo/EchoViewController.swift index 4e3056996..8fe65e653 100644 --- a/Examples/EchoXcode/Echo/EchoViewController.swift +++ b/Examples/EchoXcode/Echo/EchoViewController.swift @@ -44,7 +44,7 @@ class EchoViewController: NSViewController, NSTextFieldDelegate { do { try callServer(address: addressField.stringValue, host: "example.com") - } catch (let error) { + } catch { print(error) } } @@ -54,7 +54,7 @@ class EchoViewController: NSViewController, NSTextFieldDelegate { if nowStreaming { do { try sendClose() - } catch (let error) { + } catch { print(error) } } @@ -66,7 +66,7 @@ class EchoViewController: NSViewController, NSTextFieldDelegate { if nowStreaming { do { try sendClose() - } catch (let error) { + } catch { print(error) } } @@ -76,7 +76,7 @@ class EchoViewController: NSViewController, NSTextFieldDelegate { if nowStreaming { do { try sendClose() - } catch (let error) { + } catch { print(error) } } @@ -151,7 +151,7 @@ class EchoViewController: NSViewController, NSTextFieldDelegate { } try receiveExpandMessages() displayMessageSent(requestMessage.text) - } catch (let error) { + } catch { self.displayMessageReceived("No message received. \(error)") } } @@ -169,7 +169,7 @@ class EchoViewController: NSViewController, NSTextFieldDelegate { } } try sendCollectMessage() - } catch (let error) { + } catch { self.displayMessageReceived("No message received. \(error)") } } else if callSelectButton.selectedSegment == 3 { @@ -187,7 +187,7 @@ class EchoViewController: NSViewController, NSTextFieldDelegate { } } try sendUpdateMessage() - } catch (let error) { + } catch { self.displayMessageReceived("No message received. \(error)") } } diff --git a/Examples/Google/NaturalLanguage/Sources/main.swift b/Examples/Google/NaturalLanguage/Sources/main.swift index 9f211d79c..a711b25da 100644 --- a/Examples/Google/NaturalLanguage/Sources/main.swift +++ b/Examples/Google/NaturalLanguage/Sources/main.swift @@ -55,7 +55,7 @@ if let provider = DefaultTokenProvider(scopes: scopes) { do { let result = try service.annotatetext(request) print("\(result)") - } catch (let error) { + } catch { print("ERROR: \(error)") } } diff --git a/Examples/SimpleXcode/Simple/Document.swift b/Examples/SimpleXcode/Simple/Document.swift index e23545a70..dce2b88c5 100644 --- a/Examples/SimpleXcode/Simple/Document.swift +++ b/Examples/SimpleXcode/Simple/Document.swift @@ -181,8 +181,8 @@ class Document: NSDocument { } } } - } catch (let callError) { - Swift.print("call error \(callError)") + } catch { + Swift.print("call error \(error)") } self.log("------------------------------") sleep(1) @@ -241,8 +241,8 @@ class Document: NSDocument { trailingMetadata: trailingMetadataToSend) self.log("------------------------------") - } catch (let callError) { - Swift.print("call error \(callError)") + } catch { + Swift.print("call error \(error)") } } diff --git a/Makefile b/Makefile index 299ab0cba..5aadcdfa6 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ all: project: swift package generate-xcodeproj +# Optional: set the generated project's indentation settings. + -ruby fix-indentation-settings.rb test: all swift test -v $(CFLAGS) diff --git a/Package.swift b/Package.swift index b4d9fa845..6d0e227c8 100644 --- a/Package.swift +++ b/Package.swift @@ -23,7 +23,7 @@ let package = Package( .library(name: "SwiftGRPC", targets: ["SwiftGRPC"]), ], dependencies: [ - .package(url: "https://github.com/Zewo/zlib.git", from: "0.4.0"), + .package(url: "https://github.com/apple/swift-nio-zlib-support.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-protobuf.git", from: "1.0.2"), .package(url: "https://github.com/kylef/Commander.git", from: "0.8.0") ], @@ -31,7 +31,7 @@ let package = Package( .target(name: "SwiftGRPC", dependencies: ["CgRPC", "SwiftProtobuf"]), .target(name: "CgRPC", - dependencies: ["BoringSSL", "zlib"]), + dependencies: ["BoringSSL"]), .target(name: "RootsEncoder"), .target(name: "protoc-gen-swiftgrpc", dependencies: [ diff --git a/Sources/Examples/Echo/EchoProvider.swift b/Sources/Examples/Echo/EchoProvider.swift index f12a1afb1..b37b8590a 100644 --- a/Sources/Examples/Echo/EchoProvider.swift +++ b/Sources/Examples/Echo/EchoProvider.swift @@ -38,6 +38,7 @@ class EchoProvider: Echo_EchoProvider { } } session.waitForSendOperationsToFinish() + try session.close(withStatus: .ok, completion: nil) } // collect collects a sequence of messages and returns them concatenated when the caller closes. @@ -45,17 +46,17 @@ class EchoProvider: Echo_EchoProvider { var parts: [String] = [] while true { do { - let request = try session.receive() + guard let request = try session.receive() + else { break } // End of stream parts.append(request.text) - } catch ServerError.endOfStream { + } catch { + print("collect error: \(error)") break - } catch (let error) { - print("\(error)") } } var response = Echo_EchoResponse() response.text = "Swift echo collect: " + parts.joined(separator: " ") - try session.sendAndClose(response) + try session.sendAndClose(response: response, status: .ok, completion: nil) } // update streams back messages as they are received in an input stream. @@ -63,7 +64,8 @@ class EchoProvider: Echo_EchoProvider { var count = 0 while true { do { - let request = try session.receive() + guard let request = try session.receive() + else { break } // End of stream var response = Echo_EchoResponse() response.text = "Swift echo update (\(count)): \(request.text)" count += 1 @@ -72,14 +74,12 @@ class EchoProvider: Echo_EchoProvider { print("update error: \(error)") } } - } catch ServerError.endOfStream { - break - } catch (let error) { - print("\(error)") + } catch { + print("update error: \(error)") break } } session.waitForSendOperationsToFinish() - try session.close() + try session.close(withStatus: .ok, completion: nil) } } diff --git a/Sources/Examples/Echo/Generated/echo.grpc.swift b/Sources/Examples/Echo/Generated/echo.grpc.swift index 844bfec6a..0534e12aa 100644 --- a/Sources/Examples/Echo/Generated/echo.grpc.swift +++ b/Sources/Examples/Echo/Generated/echo.grpc.swift @@ -33,9 +33,9 @@ fileprivate final class Echo_EchoGetCallBase: ClientCallUnaryBase Echo_EchoResponse + func receive() throws -> Echo_EchoResponse? /// Call this to wait for a result. Nonblocking. - func receive(completion: @escaping (Echo_EchoResponse?, ClientError?) -> Void) throws + func receive(completion: @escaping (ResultOrRPCError) -> Void) throws } fileprivate final class Echo_EchoExpandCallBase: ClientCallServerStreamingBase, Echo_EchoExpandCall { @@ -47,13 +47,15 @@ class Echo_EchoExpandCallTestStub: ClientCallServerStreamingTestStub Void) throws + /// Send a message to the stream and wait for the send operation to finish. Blocking. + func send(_ message: Echo_EchoRequest) throws /// Call this to close the connection and wait for a response. Blocking. func closeAndReceive() throws -> Echo_EchoResponse /// Call this to close the connection and wait for a response. Nonblocking. - func closeAndReceive(completion: @escaping (Echo_EchoResponse?, ClientError?) -> Void) throws + func closeAndReceive(completion: @escaping (ResultOrRPCError) -> Void) throws } fileprivate final class Echo_EchoCollectCallBase: ClientCallClientStreamingBase, Echo_EchoCollectCall { @@ -68,12 +70,14 @@ class Echo_EchoCollectCallTestStub: ClientCallClientStreamingTestStub Echo_EchoResponse + func receive() throws -> Echo_EchoResponse? /// Call this to wait for a result. Nonblocking. - func receive(completion: @escaping (Echo_EchoResponse?, ClientError?) -> Void) throws + func receive(completion: @escaping (ResultOrRPCError) -> Void) throws - /// Call this to send each message in the request stream. + /// Send a message to the stream. Nonblocking. func send(_ message: Echo_EchoRequest, completion: @escaping (Error?) -> Void) throws + /// Send a message to the stream and wait for the send operation to finish. Blocking. + func send(_ message: Echo_EchoRequest) throws /// Call this to close the sending connection. Blocking. func closeSend() throws @@ -201,8 +205,14 @@ fileprivate final class Echo_EchoGetSessionBase: ServerSessionUnaryBase Void)?) throws + /// Send a message to the stream. Nonblocking. + func send(_ message: Echo_EchoResponse, completion: @escaping (Error?) -> Void) throws + /// Send a message to the stream and wait for the send operation to finish. Blocking. + func send(_ message: Echo_EchoResponse) throws + + /// Close the connection and send the status. Non-blocking. + /// You MUST call this method once you are done processing the request. + func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws } fileprivate final class Echo_EchoExpandSessionBase: ServerSessionServerStreamingBase, Echo_EchoExpandSession {} @@ -210,11 +220,18 @@ fileprivate final class Echo_EchoExpandSessionBase: ServerSessionServerStreaming class Echo_EchoExpandSessionTestStub: ServerSessionServerStreamingTestStub, Echo_EchoExpandSession {} internal protocol Echo_EchoCollectSession: ServerSessionClientStreaming { - /// Receive a message. Blocks until a message is received or the client closes the connection. - func receive() throws -> Echo_EchoRequest - - /// Send a response and close the connection. - func sendAndClose(_ response: Echo_EchoResponse) throws + /// Call this to wait for a result. Blocking. + func receive() throws -> Echo_EchoRequest? + /// Call this to wait for a result. Nonblocking. + func receive(completion: @escaping (ResultOrRPCError) -> Void) throws + + /// You MUST call one of these two methods once you are done processing the request. + /// Close the connection and send a single result. Non-blocking. + func sendAndClose(response: Echo_EchoResponse, status: ServerStatus, completion: (() -> Void)?) throws + /// Close the connection and send an error. Non-blocking. + /// Use this method if you encountered an error that makes it impossible to send a response. + /// Accordingly, it does not make sense to call this method with a status of `.ok`. + func sendErrorAndClose(status: ServerStatus, completion: (() -> Void)?) throws } fileprivate final class Echo_EchoCollectSessionBase: ServerSessionClientStreamingBase, Echo_EchoCollectSession {} @@ -222,14 +239,19 @@ fileprivate final class Echo_EchoCollectSessionBase: ServerSessionClientStreamin class Echo_EchoCollectSessionTestStub: ServerSessionClientStreamingTestStub, Echo_EchoCollectSession {} internal protocol Echo_EchoUpdateSession: ServerSessionBidirectionalStreaming { - /// Receive a message. Blocks until a message is received or the client closes the connection. - func receive() throws -> Echo_EchoRequest + /// Call this to wait for a result. Blocking. + func receive() throws -> Echo_EchoRequest? + /// Call this to wait for a result. Nonblocking. + func receive(completion: @escaping (ResultOrRPCError) -> Void) throws - /// Send a message. Nonblocking. - func send(_ response: Echo_EchoResponse, completion: ((Error?) -> Void)?) throws + /// Send a message to the stream. Nonblocking. + func send(_ message: Echo_EchoResponse, completion: @escaping (Error?) -> Void) throws + /// Send a message to the stream and wait for the send operation to finish. Blocking. + func send(_ message: Echo_EchoResponse) throws - /// Close a connection. Blocks until the connection is closed. - func close() throws + /// Close the connection and send the status. Non-blocking. + /// You MUST call this method once you are done processing the request. + func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws } fileprivate final class Echo_EchoUpdateSessionBase: ServerSessionBidirectionalStreamingBase, Echo_EchoUpdateSession {} diff --git a/Sources/Examples/Echo/main.swift b/Sources/Examples/Echo/main.swift index 1df981ab7..98577c816 100644 --- a/Sources/Examples/Echo/main.swift +++ b/Sources/Examples/Echo/main.swift @@ -109,16 +109,13 @@ Group { callResult = result sem.signal() } - var running = true - while running { - do { - let responseMessage = try expandCall.receive() - print("expand received: \(responseMessage.text)") - } catch ClientError.endOfStream { - running = false - } + while true { + guard let responseMessage = try expandCall.receive() + else { break } // End of stream + print("expand received: \(responseMessage.text)") } _ = sem.wait() + if let statusCode = callResult?.statusCode { print("expand completed with code \(statusCode)") } @@ -150,6 +147,7 @@ Group { let responseMessage = try collectCall.closeAndReceive() print("collect received: \(responseMessage.text)") _ = sem.wait() + if let statusCode = callResult?.statusCode { print("collect completed with code \(statusCode)") } @@ -182,17 +180,10 @@ Group { try updateCall.closeSend() while true { - do { - let responseMessage = try updateCall.receive() - print("update received: \(responseMessage.text)") - } catch ClientError.endOfStream { - break - } catch (let error) { - print("update receive error: \(error)") - break - } + guard let responseMessage = try updateCall.receive() + else { break } // End of stream + print("update received: \(responseMessage.text)") } - _ = sem.wait() if let statusCode = callResult?.statusCode { diff --git a/Sources/Examples/Simple/main.swift b/Sources/Examples/Simple/main.swift index 3d8a43788..220526e54 100644 --- a/Sources/Examples/Simple/main.swift +++ b/Sources/Examples/Simple/main.swift @@ -105,13 +105,11 @@ func server() throws { "2": "two" ]) try requestHandler.sendResponse(message: replyMessage.data(using: .utf8)!, - statusCode: .ok, - statusMessage: "OK", - trailingMetadata: trailingMetadataToSend) + status: ServerStatus(code: .ok, message: "OK", trailingMetadata: trailingMetadataToSend)) print("------------------------------") - } catch (let callError) { - Swift.print("call error \(callError)") + } catch { + Swift.print("call error \(error)") } } diff --git a/Sources/SwiftGRPC/Core/Call.swift b/Sources/SwiftGRPC/Core/Call.swift index 05ae3ec87..b085a160f 100644 --- a/Sources/SwiftGRPC/Core/Call.swift +++ b/Sources/SwiftGRPC/Core/Call.swift @@ -30,113 +30,6 @@ public enum CallWarning: Error { case blocked } -public enum CallError: Error { - case ok - case unknown - case notOnServer - case notOnClient - case alreadyAccepted - case alreadyInvoked - case notInvoked - case alreadyFinished - case tooManyOperations - case invalidFlags - case invalidMetadata - case invalidMessage - case notServerCompletionQueue - case batchTooBig - case payloadTypeMismatch - - static func callError(grpcCallError error: grpc_call_error) -> CallError { - switch error { - case GRPC_CALL_OK: - return .ok - case GRPC_CALL_ERROR: - return .unknown - case GRPC_CALL_ERROR_NOT_ON_SERVER: - return .notOnServer - case GRPC_CALL_ERROR_NOT_ON_CLIENT: - return .notOnClient - case GRPC_CALL_ERROR_ALREADY_ACCEPTED: - return .alreadyAccepted - case GRPC_CALL_ERROR_ALREADY_INVOKED: - return .alreadyInvoked - case GRPC_CALL_ERROR_NOT_INVOKED: - return .notInvoked - case GRPC_CALL_ERROR_ALREADY_FINISHED: - return .alreadyFinished - case GRPC_CALL_ERROR_TOO_MANY_OPERATIONS: - return .tooManyOperations - case GRPC_CALL_ERROR_INVALID_FLAGS: - return .invalidFlags - case GRPC_CALL_ERROR_INVALID_METADATA: - return .invalidMetadata - case GRPC_CALL_ERROR_INVALID_MESSAGE: - return .invalidMessage - case GRPC_CALL_ERROR_NOT_SERVER_COMPLETION_QUEUE: - return .notServerCompletionQueue - case GRPC_CALL_ERROR_BATCH_TOO_BIG: - return .batchTooBig - case GRPC_CALL_ERROR_PAYLOAD_TYPE_MISMATCH: - return .payloadTypeMismatch - default: - return .unknown - } - } -} - -public struct CallResult: CustomStringConvertible { - public let statusCode: StatusCode - public let statusMessage: String? - public let resultData: Data? - public let initialMetadata: Metadata? - public let trailingMetadata: Metadata? - - fileprivate init(_ op: OperationGroup) { - if op.success { - if let statusCodeRawValue = op.receivedStatusCode() { - if let statusCode = StatusCode(rawValue: statusCodeRawValue) { - self.statusCode = statusCode - } else { - statusCode = .unknown - } - } else { - statusCode = .ok - } - statusMessage = op.receivedStatusMessage() - resultData = op.receivedMessage()?.data() - initialMetadata = op.receivedInitialMetadata() - trailingMetadata = op.receivedTrailingMetadata() - } else { - statusCode = .unknown - statusMessage = nil - resultData = nil - initialMetadata = nil - trailingMetadata = nil - } - } - - public var description: String { - var result = "status \(statusCode)" - if let statusMessage = self.statusMessage { - result += ": " + statusMessage - } - if let resultData = self.resultData { - result += "\n" - result += resultData.description - } - if let initialMetadata = self.initialMetadata { - result += "\n" - result += initialMetadata.description - } - if let trailingMetadata = self.trailingMetadata { - result += "\n" - result += trailingMetadata.description - } - return result - } -} - /// A gRPC API call public class Call { /// Shared mutex for synchronizing calls to cgrpc_call_perform() @@ -167,9 +60,6 @@ public class Call { /// Mutex for synchronizing message sending private let sendMutex: Mutex - /// Dispatch queue used for sending messages asynchronously - private let messageDispatchQueue: DispatchQueue = DispatchQueue.global() - /// Initializes a Call representation /// /// - Parameter call: the underlying C representation @@ -209,6 +99,8 @@ public class Call { /// - Parameter metadata: metadata to send with the call /// - Parameter message: data containing the message to send (.unary and .serverStreaming only) /// - Parameter completion: a block to call with call results + /// The argument to `completion` will always have `.success = true` + /// because operations containing `.receiveCloseOnClient` always succeed. /// - Throws: `CallError` if fails to call. public func start(_ style: CallStyle, metadata: Metadata, @@ -240,11 +132,18 @@ public class Call { .receiveStatusOnClient, ] case .clientStreaming, .bidiStreaming: - operations = [ - .sendInitialMetadata(metadata.copy()), - .receiveInitialMetadata, - .receiveStatusOnClient, - ] + try perform(OperationGroup(call: self, + operations: [ + .sendInitialMetadata(metadata.copy()), + .receiveInitialMetadata + ], + completion: nil)) + try perform(OperationGroup(call: self, + operations: [.receiveStatusOnClient], + completion: completion != nil + ? { op in completion?(CallResult(op)) } + : nil)) + return } try perform(OperationGroup(call: self, operations: operations, @@ -258,7 +157,6 @@ public class Call { /// Parameter data: the message data to send /// - Throws: `CallError` if fails to call. `CallWarning` if blocked. public func sendMessage(data: Data, completion: ((Error?) -> Void)? = nil) throws { - messageQueueEmpty.enter() try sendMutex.synchronize { if writing { if let messageQueueMaxLength = Call.messageQueueMaxLength, @@ -269,32 +167,30 @@ public class Call { } else { writing = true try sendWithoutBlocking(data: data, completion: completion) - } + } + messageQueueEmpty.enter() } } /// helper for sending queued messages private func sendWithoutBlocking(data: Data, completion: ((Error?) -> Void)?) throws { - try perform(OperationGroup(call: self, - operations: [.sendMessage(ByteBuffer(data: data))]) { operationGroup in - // TODO(timburks, danielalm): Is the `async` dispatch here needed, and/or should we call the completion handler - // and leave `messageQueueEmpty` in the `async` block as well? - self.messageDispatchQueue.async { - // Always enqueue the next message, even if sending this one failed. This ensures that all send completion - // handlers are called eventually. - self.sendMutex.synchronize { - // if there are messages pending, send the next one - if self.messageQueue.count > 0 { - let (nextMessage, nextCompletionHandler) = self.messageQueue.removeFirst() - do { - try self.sendWithoutBlocking(data: nextMessage, completion: nextCompletionHandler) - } catch (let callError) { - nextCompletionHandler?(callError) - } - } else { - // otherwise, we are finished writing - self.writing = false + try perform(OperationGroup( + call: self, + operations: [.sendMessage(ByteBuffer(data: data))]) { operationGroup in + // Always enqueue the next message, even if sending this one failed. This ensures that all send completion + // handlers are called eventually. + self.sendMutex.synchronize { + // if there are messages pending, send the next one + if self.messageQueue.count > 0 { + let (nextMessage, nextCompletionHandler) = self.messageQueue.removeFirst() + do { + try self.sendWithoutBlocking(data: nextMessage, completion: nextCompletionHandler) + } catch { + nextCompletionHandler?(error) } + } else { + // otherwise, we are finished writing + self.writing = false } } completion?(operationGroup.success ? nil : CallError.unknown) @@ -304,27 +200,17 @@ public class Call { // Receive a message over a streaming connection. /// - Throws: `CallError` if fails to call. - public func closeAndReceiveMessage(completion: @escaping (Data?) throws -> Void) throws { + public func closeAndReceiveMessage(completion: @escaping (CallResult) -> Void) throws { try perform(OperationGroup(call: self, operations: [.sendCloseFromClient, .receiveMessage]) { operationGroup in - if operationGroup.success { - if let messageBuffer = operationGroup.receivedMessage() { - try completion(messageBuffer.data()) - } else { - try completion(nil) // an empty response signals the end of a connection - } - } + completion(CallResult(operationGroup)) }) } // Receive a message over a streaming connection. /// - Throws: `CallError` if fails to call. - public func receiveMessage(completion: @escaping (Data?) throws -> Void) throws { + public func receiveMessage(completion: @escaping (CallResult) -> Void) throws { try perform(OperationGroup(call: self, operations: [.receiveMessage]) { operationGroup in - if operationGroup.success { - try completion(operationGroup.receivedMessage()?.data()) - } else { - try completion(nil) - } + completion(CallResult(operationGroup)) }) } diff --git a/Sources/SwiftGRPC/Core/CallError.swift b/Sources/SwiftGRPC/Core/CallError.swift new file mode 100644 index 000000000..3195aa781 --- /dev/null +++ b/Sources/SwiftGRPC/Core/CallError.swift @@ -0,0 +1,76 @@ +/* + * Copyright 2016, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if SWIFT_PACKAGE +import CgRPC +import Dispatch +#endif +import Foundation + +public enum CallError: Error { + case ok + case unknown + case notOnServer + case notOnClient + case alreadyAccepted + case alreadyInvoked + case notInvoked + case alreadyFinished + case tooManyOperations + case invalidFlags + case invalidMetadata + case invalidMessage + case notServerCompletionQueue + case batchTooBig + case payloadTypeMismatch + + static func callError(grpcCallError error: grpc_call_error) -> CallError { + switch error { + case GRPC_CALL_OK: + return .ok + case GRPC_CALL_ERROR: + return .unknown + case GRPC_CALL_ERROR_NOT_ON_SERVER: + return .notOnServer + case GRPC_CALL_ERROR_NOT_ON_CLIENT: + return .notOnClient + case GRPC_CALL_ERROR_ALREADY_ACCEPTED: + return .alreadyAccepted + case GRPC_CALL_ERROR_ALREADY_INVOKED: + return .alreadyInvoked + case GRPC_CALL_ERROR_NOT_INVOKED: + return .notInvoked + case GRPC_CALL_ERROR_ALREADY_FINISHED: + return .alreadyFinished + case GRPC_CALL_ERROR_TOO_MANY_OPERATIONS: + return .tooManyOperations + case GRPC_CALL_ERROR_INVALID_FLAGS: + return .invalidFlags + case GRPC_CALL_ERROR_INVALID_METADATA: + return .invalidMetadata + case GRPC_CALL_ERROR_INVALID_MESSAGE: + return .invalidMessage + case GRPC_CALL_ERROR_NOT_SERVER_COMPLETION_QUEUE: + return .notServerCompletionQueue + case GRPC_CALL_ERROR_BATCH_TOO_BIG: + return .batchTooBig + case GRPC_CALL_ERROR_PAYLOAD_TYPE_MISMATCH: + return .payloadTypeMismatch + default: + return .unknown + } + } +} + diff --git a/Sources/SwiftGRPC/Core/CallResult.swift b/Sources/SwiftGRPC/Core/CallResult.swift new file mode 100644 index 000000000..9e8505a03 --- /dev/null +++ b/Sources/SwiftGRPC/Core/CallResult.swift @@ -0,0 +1,76 @@ +/* + * Copyright 2016, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if SWIFT_PACKAGE +import CgRPC +import Dispatch +#endif +import Foundation + +public struct CallResult: CustomStringConvertible { + public let success: Bool + public let statusCode: StatusCode + public let statusMessage: String? + public let resultData: Data? + public let initialMetadata: Metadata? + public let trailingMetadata: Metadata? + + init(_ op: OperationGroup) { + success = op.success + if let statusCodeRawValue = op.receivedStatusCode(), + let statusCode = StatusCode(rawValue: statusCodeRawValue) { + self.statusCode = statusCode + } else { + statusCode = .unknown + } + statusMessage = op.receivedStatusMessage() + resultData = op.receivedMessage()?.data() + initialMetadata = op.receivedInitialMetadata() + trailingMetadata = op.receivedTrailingMetadata() + } + + fileprivate init(success: Bool, statusCode: StatusCode, statusMessage: String?, resultData: Data?, + initialMetadata: Metadata?, trailingMetadata: Metadata?) { + self.success = success + self.statusCode = statusCode + self.statusMessage = statusMessage + self.resultData = resultData + self.initialMetadata = initialMetadata + self.trailingMetadata = trailingMetadata + } + + public var description: String { + var result = "\(success ? "successful" : "unsuccessful"), status \(statusCode)" + if let statusMessage = self.statusMessage { + result += ": " + statusMessage + } + if let resultData = self.resultData { + result += "\nresultData: " + result += resultData.description + } + if let initialMetadata = self.initialMetadata { + result += "\ninitialMetadata: " + result += initialMetadata.description + } + if let trailingMetadata = self.trailingMetadata { + result += "\ntrailingMetadata: " + result += trailingMetadata.description + } + return result + } + + static let fakeOK = CallResult(success: true, statusCode: .ok, statusMessage: "OK", resultData: nil, + initialMetadata: nil, trailingMetadata: nil) +} diff --git a/Sources/SwiftGRPC/Core/CompletionQueue.swift b/Sources/SwiftGRPC/Core/CompletionQueue.swift index 150df4790..0a34fe636 100644 --- a/Sources/SwiftGRPC/Core/CompletionQueue.swift +++ b/Sources/SwiftGRPC/Core/CompletionQueue.swift @@ -66,6 +66,8 @@ class CompletionQueue { /// Mutex for synchronizing access to operationGroups private let operationGroupsMutex: Mutex = Mutex() + + private var hasBeenShutdown = false /// Initializes a CompletionQueue /// @@ -90,21 +92,27 @@ class CompletionQueue { /// - Parameter operationGroup: the operation group to handle func register(_ operationGroup: OperationGroup) { operationGroupsMutex.synchronize { - operationGroups[operationGroup.tag] = operationGroup + if !hasBeenShutdown { + operationGroups[operationGroup.tag] = operationGroup + } else { + // The queue has been shut down already, so there's no spinloop to call the operation group's completion handler + // on. To guarantee that the completion handler gets called, we'll enqueue it right now. + DispatchQueue.global().async { + operationGroup.success = false + operationGroup.completion?(operationGroup) + } + } } } /// Runs a completion queue and call a completion handler when finished /// - /// - Parameter callbackQueue: a DispatchQueue to use to call the completion handler /// - Parameter completion: a completion handler that is called when the queue stops running - func runToCompletion(callbackQueue: DispatchQueue = DispatchQueue.main, - completion: (() -> Void)?) { + func runToCompletion(completion: (() -> Void)?) { // run the completion queue on a new background thread DispatchQueue.global().async { - var running = true - while running { - let event = cgrpc_completion_queue_get_next_event(self.underlyingCompletionQueue, -1.0) + spinloop: while true { + let event = cgrpc_completion_queue_get_next_event(self.underlyingCompletionQueue, 600) switch event.type { case GRPC_OP_COMPLETE: let tag = cgrpc_event_tag(event) @@ -113,43 +121,35 @@ class CompletionQueue { self.operationGroupsMutex.unlock() if let operationGroup = operationGroup { // call the operation group completion handler - do { - operationGroup.success = (event.success == 1) - try operationGroup.completion?(operationGroup) - } catch (let callError) { - print("CompletionQueue runToCompletion: grpc error \(callError)") - } + operationGroup.success = (event.success == 1) + operationGroup.completion?(operationGroup) self.operationGroupsMutex.synchronize { self.operationGroups[tag] = nil } + } else { + print("CompletionQueue.runToCompletion error: operation group with tag \(tag) not found") } - break case GRPC_QUEUE_SHUTDOWN: - running = false - do { - for operationGroup in self.operationGroups.values { - operationGroup.success = false - try operationGroup.completion?(operationGroup) - } - } catch (let callError) { - print("CompletionQueue runToCompletion: grpc error \(callError)") - } - self.operationGroupsMutex.synchronize { - self.operationGroups = [:] + self.operationGroupsMutex.lock() + let currentOperationGroups = self.operationGroups + self.operationGroups = [:] + self.hasBeenShutdown = true + self.operationGroupsMutex.unlock() + + for operationGroup in currentOperationGroups.values { + operationGroup.success = false + operationGroup.completion?(operationGroup) } - break + break spinloop case GRPC_QUEUE_TIMEOUT: - break + continue spinloop default: - break - } - } - if let completion = completion { - callbackQueue.async { - // when the queue stops running, call the queue completion handler - completion() + print("CompletionQueue.runToCompletion error: unknown event type \(event.type)") + break spinloop } } + // when the queue stops running, call the queue completion handler + completion?() } } diff --git a/Sources/SwiftGRPC/Core/Handler.swift b/Sources/SwiftGRPC/Core/Handler.swift index 90b6f6d0d..55d2f7390 100644 --- a/Sources/SwiftGRPC/Core/Handler.swift +++ b/Sources/SwiftGRPC/Core/Handler.swift @@ -30,7 +30,7 @@ public class Handler { public let requestMetadata: Metadata /// A Call object that can be used to respond to the request - private(set) lazy var call: Call = { + public private(set) lazy var call: Call = { Call(underlyingCall: cgrpc_handler_get_call(self.underlyingHandler), owned: false, completionQueue: self.completionQueue) @@ -83,139 +83,72 @@ public class Handler { throw CallError.callError(grpcCallError: error) } } - - /// Receive the message sent with a call - /// - public func receiveMessage(initialMetadata: Metadata, - completion: @escaping (Data?) throws -> Void) throws { - let operations = OperationGroup(call: call, - operations: [ - .sendInitialMetadata(initialMetadata), - .receiveMessage - ]) { operationGroup in - if operationGroup.success { - try completion(operationGroup.receivedMessage()?.data()) - } else { - try completion(nil) - } - } - try call.perform(operations) - } - - /// Sends the response to a request - /// - /// - Parameter message: the message to send - /// - Parameter statusCode: status code to send - /// - Parameter statusMessage: status message to send - /// - Parameter trailingMetadata: trailing metadata to send - public func sendResponse(message: Data, - statusCode: StatusCode, - statusMessage: String, - trailingMetadata: Metadata) throws { - let messageBuffer = ByteBuffer(data: message) - let operations = OperationGroup(call: call, - operations: [ - .receiveCloseOnServer, - .sendStatusFromServer(statusCode, statusMessage, trailingMetadata), - .sendMessage(messageBuffer) - ]) { operationGroup in - self.shutdown() - } - try call.perform(operations) - } - - /// Sends the response to a request - /// - /// - Parameter statusCode: status code to send - /// - Parameter statusMessage: status message to send - /// - Parameter trailingMetadata: trailing metadata to send - public func sendResponse(statusCode: StatusCode, - statusMessage: String, - trailingMetadata: Metadata) throws { - let operations = OperationGroup(call: call, - operations: [ - .receiveCloseOnServer, - .sendStatusFromServer(statusCode, statusMessage, trailingMetadata) - ]) { operationGroup in - self.shutdown() - } - try call.perform(operations) - } - + /// Shuts down the handler's completion queue public func shutdown() { completionQueue.shutdown() } - + /// Send initial metadata in response to a connection /// /// - Parameter initialMetadata: initial metadata to send /// - Parameter completion: a completion handler to call after the metadata has been sent public func sendMetadata(initialMetadata: Metadata, - completion: ((Bool) throws -> Void)? = nil) throws { - let operations = OperationGroup(call: call, - operations: [.sendInitialMetadata(initialMetadata)], - completion: completion != nil - ? { operationGroup in try completion?(operationGroup.success) } - : nil) - try call.perform(operations) + completion: ((Bool) -> Void)? = nil) throws { + try call.perform(OperationGroup( + call: call, + operations: [.sendInitialMetadata(initialMetadata)], + completion: completion != nil + ? { operationGroup in completion?(operationGroup.success) } + : nil)) } - + /// Receive the message sent with a call /// - /// - Parameter completion: a completion handler to call after the message has been received - /// - Returns: a tuple containing status codes and a message (if available) - public func receiveMessage(completion: @escaping (Data?) throws -> Void) throws { - try call.receiveMessage(completion: completion) - } - - /// Sends the response to a request - /// - /// - Parameter message: the message to send - /// - Parameter completion: a completion handler to call after the response has been sent - public func sendResponse(message: Data, - completion: ((Error?) -> Void)? = nil) throws { - try call.sendMessage(data: message, completion: completion) - } - - /// Recognize when the client has closed a request - /// - /// - Parameter completion: a completion handler to call after request has been closed - public func receiveClose(completion: @escaping (Bool) throws -> Void) throws { - let operations = OperationGroup(call: call, - operations: [.receiveCloseOnServer]) { operationGroup in - try completion(operationGroup.success) - } - try call.perform(operations) + public func receiveMessage(initialMetadata: Metadata, + completion: @escaping (Data?) -> Void) throws { + try call.perform(OperationGroup( + call: call, + operations: [ + .sendInitialMetadata(initialMetadata), + .receiveMessage + ]) { operationGroup in + if operationGroup.success { + completion(operationGroup.receivedMessage()?.data()) + } else { + completion(nil) + } + }) } - /// Send final status to the client - /// - /// - Parameter statusCode: status code to send - /// - Parameter statusMessage: status message to send - /// - Parameter trailingMetadata: trailing metadata to send - /// - Parameter completion: a completion handler to call after the status has been sent - public func sendStatus(statusCode: StatusCode, - statusMessage: String, - trailingMetadata: Metadata, - completion: ((Bool) -> Void)? = nil) throws { - let operations = OperationGroup(call: call, - operations: [ - .sendStatusFromServer(statusCode, - statusMessage, - trailingMetadata) - ]) { operationGroup in - completion?(operationGroup.success) + /// Sends the response to a request. + /// The completion handler does not take an argument because operations containing `.receiveCloseOnServer` always succeed. + public func sendResponse(message: Data, status: ServerStatus, + completion: (() -> Void)? = nil) throws { + let messageBuffer = ByteBuffer(data: message) + try call.perform(OperationGroup( + call: call, + operations: [ + .sendMessage(messageBuffer), + .receiveCloseOnServer, + .sendStatusFromServer(status.code, status.message, status.trailingMetadata) + ]) { _ in + completion?() self.shutdown() - } - try call.perform(operations) + }) } -} -extension Handler: Hashable { - public var hashValue: Int { return underlyingHandler.hashValue } - - public static func ==(A: Handler, B: Handler) -> Bool { - return A === B + /// Send final status to the client. + /// The completion handler does not take an argument because operations containing `.receiveCloseOnServer` always succeed. + public func sendStatus(_ status: ServerStatus, completion: (() -> Void)? = nil) throws { + try call.perform(OperationGroup( + call: call, + operations: [ + .receiveCloseOnServer, + .sendStatusFromServer(status.code, status.message, status.trailingMetadata) + ]) { _ in + completion?() + self.shutdown() + }) } } diff --git a/Sources/SwiftGRPC/Core/OperationGroup.swift b/Sources/SwiftGRPC/Core/OperationGroup.swift index 1c4cfdd3d..96a4c05f2 100644 --- a/Sources/SwiftGRPC/Core/OperationGroup.swift +++ b/Sources/SwiftGRPC/Core/OperationGroup.swift @@ -29,10 +29,11 @@ class OperationGroup { let tag: Int64 /// The call associated with the operation group. Retained while the operations are running. + // FIXME(danielalm): Is this property needed? private let call: Call /// An array of operation objects that are passed into the initializer. - private let operations: [Operation] + let operations: [Operation] /// An array of observers used to watch the operation private var underlyingObservers: [UnsafeMutableRawPointer] = [] @@ -41,10 +42,13 @@ class OperationGroup { let underlyingOperations: UnsafeMutableRawPointer? /// Completion handler that is called when the group completes - let completion: ((OperationGroup) throws -> Void)? + let completion: ((OperationGroup) -> Void)? /// Indicates that the OperationGroup completed successfully var success = false + + fileprivate var cachedInitialMetadata: Metadata? + fileprivate var cachedTrailingMetadata: Metadata? /// Creates the underlying observer needed to run an operation /// @@ -81,7 +85,7 @@ class OperationGroup { /// - Parameter operations: an array of operations init(call: Call, operations: [Operation], - completion: ((OperationGroup) throws -> Void)? = nil) { + completion: ((OperationGroup) -> Void)? = nil) { self.call = call self.operations = operations self.completion = completion @@ -107,7 +111,8 @@ class OperationGroup { cgrpc_operations_destroy(underlyingOperations) } - /// WARNING: The following assumes that at most one operation of each type is in the group. + /// WARNING: The following assumes that at most one operation of each type is in the group, + /// and these methods must ONLY be called after the operation has been returned to a completion queue. /// Gets the message that was received /// @@ -131,10 +136,15 @@ class OperationGroup { /// /// - Returns: metadata func receivedInitialMetadata() -> Metadata? { + if let cachedInitialMetadata = self.cachedInitialMetadata { + return cachedInitialMetadata + } for (i, operation) in operations.enumerated() { switch operation { case .receiveInitialMetadata: - return Metadata(underlyingArray: cgrpc_observer_recv_initial_metadata_get_metadata(underlyingObservers[i])) + cachedInitialMetadata = Metadata( + underlyingArray: cgrpc_observer_recv_initial_metadata_get_metadata(underlyingObservers[i])) + return cachedInitialMetadata! default: continue } @@ -164,7 +174,7 @@ class OperationGroup { for (i, operation) in operations.enumerated() { switch operation { case .receiveStatusOnClient: - // We actually know that this method will never return nil, so we can forcibly the result. (Also below.) + // We actually know that this method will never return nil, so we can forcibly unwrap the result. (Also below.) let string = cgrpc_observer_recv_status_on_client_copy_status_details(underlyingObservers[i])! defer { cgrpc_free_copied_string(string) } return String(cString: string, encoding: String.Encoding.utf8) @@ -179,10 +189,15 @@ class OperationGroup { /// /// - Returns: metadata func receivedTrailingMetadata() -> Metadata? { + if let cachedTrailingMetadata = self.cachedTrailingMetadata { + return cachedTrailingMetadata + } for (i, operation) in operations.enumerated() { switch operation { case .receiveStatusOnClient: - return Metadata(underlyingArray: cgrpc_observer_recv_status_on_client_get_metadata(underlyingObservers[i])) + cachedTrailingMetadata = Metadata( + underlyingArray: cgrpc_observer_recv_status_on_client_get_metadata(underlyingObservers[i])) + return cachedTrailingMetadata! default: continue } diff --git a/Sources/SwiftGRPC/Core/Server.swift b/Sources/SwiftGRPC/Core/Server.swift index 511bb929f..80f78f786 100644 --- a/Sources/SwiftGRPC/Core/Server.swift +++ b/Sources/SwiftGRPC/Core/Server.swift @@ -21,18 +21,18 @@ import Foundation /// gRPC Server public class Server { + static let handlerCallTag = 101 + + // These are sent by the CgRPC shim. + static let stopTag = 0 + static let destroyTag = 1000 + /// Pointer to underlying C representation private let underlyingServer: UnsafeMutableRawPointer /// Completion queue used for server operations let completionQueue: CompletionQueue - /// Active handlers - private var handlers = Set() - - /// Mutex for synchronizing access to handlers - private let handlersMutex: Mutex = Mutex() - /// Optional callback when server stops serving public var onCompletion: (() -> Void)? @@ -70,32 +70,33 @@ public class Server { while running { do { let handler = Handler(underlyingServer: self.underlyingServer) - try handler.requestCall(tag: 101) + try handler.requestCall(tag: Server.handlerCallTag) // block while waiting for an incoming request let event = self.completionQueue.wait(timeout: 600) if event.type == .complete { - if event.tag == 101 { + if event.tag == Server.handlerCallTag { // run the handler and remove it when it finishes if event.success != 0 { // hold onto the handler while it runs - self.handlersMutex.synchronize { - self.handlers.insert(handler) - } + var strongHandlerReference: Handler? + strongHandlerReference = handler + // To prevent the "Variable 'strongHandlerReference' was written to, but never read" warning. + _ = strongHandlerReference // this will start the completion queue on a new thread - handler.completionQueue.runToCompletion(callbackQueue: dispatchQueue) { + handler.completionQueue.runToCompletion { dispatchQueue.async { - self.handlersMutex.synchronize { - // release the handler when it finishes - self.handlers.remove(handler) - } + // release the handler when it finishes + strongHandlerReference = nil } } - // call the handler function on the server thread - handlerFunction(handler) + dispatchQueue.async { + // dispatch the handler function on a separate thread + handlerFunction(handler) + } } - } else if event.tag == 0 || event.tag == 1000 { + } else if event.tag == Server.stopTag || event.tag == Server.destroyTag { running = false // exit the loop } } else if event.type == .queueTimeout { @@ -103,8 +104,8 @@ public class Server { } else if event.type == .queueShutdown { running = false } - } catch (let callError) { - print("server call error: \(callError)") + } catch { + print("server call error: \(error)") running = false } } diff --git a/Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift b/Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift index ff16f0f4e..9b2fd3d11 100644 --- a/Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift @@ -25,7 +25,10 @@ public protocol ClientCallBidirectionalStreaming: ClientCall { // as the protocol would then have an associated type requirement (and become pretty much unusable in the process). } -open class ClientCallBidirectionalStreamingBase: ClientCallBase, ClientCallBidirectionalStreaming { +open class ClientCallBidirectionalStreamingBase: ClientCallBase, ClientCallBidirectionalStreaming, StreamReceiving, StreamSending { + public typealias ReceivedType = OutputType + public typealias SentType = InputType + /// Call this to start a call. Nonblocking. public func start(metadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Self { @@ -33,45 +36,6 @@ open class ClientCallBidirectionalStreamingBase Void) throws { - do { - try call.receiveMessage { data in - if let data = data { - if let returnMessage = try? OutputType(serializedData: data) { - completion(returnMessage, nil) - } else { - completion(nil, .invalidMessageReceived) - } - } else { - completion(nil, .endOfStream) - } - } - } - } - - public func receive() throws -> OutputType { - var returnError: ClientError? - var returnMessage: OutputType! - let sem = DispatchSemaphore(value: 0) - do { - try receive { response, error in - returnMessage = response - returnError = error - sem.signal() - } - _ = sem.wait() - } - if let returnError = returnError { - throw returnError - } - return returnMessage - } - - public func send(_ message: InputType, completion: @escaping (Error?) -> Void) throws { - let messageData = try message.serializedData() - try call.sendMessage(data: messageData, completion: completion) - } - public func closeSend(completion: (() -> Void)?) throws { try call.close(completion: completion) } @@ -83,10 +47,6 @@ open class ClientCallBidirectionalStreamingBase Void) throws { - if let output = outputs.first { - outputs.removeFirst() - completion(output, nil) - } else { - completion(nil, .endOfStream) - } + open func receive() throws -> OutputType? { + defer { if !outputs.isEmpty { outputs.removeFirst() } } + return outputs.first } - - open func receive() throws -> OutputType { - if let output = outputs.first { - outputs.removeFirst() - return output - } else { - throw ClientError.endOfStream - } + + open func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { + completion(.result(try self.receive())) } open func send(_ message: InputType, completion _: @escaping (Error?) -> Void) throws { inputs.append(message) } + + open func send(_ message: InputType) throws { + inputs.append(message) + } open func closeSend(completion: (() -> Void)?) throws { completion?() } diff --git a/Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift b/Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift index ccf48807d..db515e108 100644 --- a/Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift @@ -25,55 +25,40 @@ public protocol ClientCallClientStreaming: ClientCall { // as the protocol would then have an associated type requirement (and become pretty much unusable in the process). } -open class ClientCallClientStreamingBase: ClientCallBase, ClientCallClientStreaming { +open class ClientCallClientStreamingBase: ClientCallBase, ClientCallClientStreaming, StreamSending { + public typealias SentType = InputType + /// Call this to start a call. Nonblocking. public func start(metadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Self { try call.start(.clientStreaming, metadata: metadata, completion: completion) return self } - public func send(_ message: InputType, completion: @escaping (Error?) -> Void) throws { - let messageData = try message.serializedData() - try call.sendMessage(data: messageData, completion: completion) - } - - public func closeAndReceive(completion: @escaping (OutputType?, ClientError?) -> Void) throws { - do { - try call.closeAndReceiveMessage { responseData in - if let responseData = responseData, - let response = try? OutputType(serializedData: responseData) { - completion(response, nil) - } else { - completion(nil, .invalidMessageReceived) - } + public func closeAndReceive(completion: @escaping (ResultOrRPCError) -> Void) throws { + try call.closeAndReceiveMessage { callResult in + guard let responseData = callResult.resultData else { + completion(.error(.callError(callResult))); return + } + if let response = try? OutputType(serializedData: responseData) { + completion(.result(response)) + } else { + completion(.error(.invalidMessageReceived)) } - } catch (let error) { - throw error } } public func closeAndReceive() throws -> OutputType { - var returnError: ClientError? - var returnResponse: OutputType! + var result: ResultOrRPCError? let sem = DispatchSemaphore(value: 0) - do { - try closeAndReceive { response, error in - returnResponse = response - returnError = error - sem.signal() - } - _ = sem.wait() - } catch (let error) { - throw error + try closeAndReceive { + result = $0 + sem.signal() } - if let returnError = returnError { - throw returnError + _ = sem.wait() + switch result! { + case .result(let response): return response + case .error(let error): throw error } - return returnResponse - } - - public func waitForSendOperationsToFinish() { - call.messageQueueEmpty.wait() } } @@ -90,9 +75,13 @@ open class ClientCallClientStreamingTestStub Void) throws { inputs.append(message) } + + open func send(_ message: InputType) throws { + inputs.append(message) + } - open func closeAndReceive(completion: @escaping (OutputType?, ClientError?) -> Void) throws { - completion(output!, nil) + open func closeAndReceive(completion: @escaping (ResultOrRPCError) -> Void) throws { + completion(.result(output!)) } open func closeAndReceive() throws -> OutputType { diff --git a/Sources/SwiftGRPC/Runtime/ClientCallServerStreaming.swift b/Sources/SwiftGRPC/Runtime/ClientCallServerStreaming.swift index 91fdef0ec..862adca28 100644 --- a/Sources/SwiftGRPC/Runtime/ClientCallServerStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ClientCallServerStreaming.swift @@ -23,7 +23,9 @@ public protocol ClientCallServerStreaming: ClientCall { // as the protocol would then have an associated type requirement (and become pretty much unusable in the process). } -open class ClientCallServerStreamingBase: ClientCallBase, ClientCallServerStreaming { +open class ClientCallServerStreamingBase: ClientCallBase, ClientCallServerStreaming, StreamReceiving { + public typealias ReceivedType = OutputType + /// Call this once with the message to send. Nonblocking. public func start(request: InputType, metadata: Metadata, completion: ((CallResult) -> Void)?) throws -> Self { let requestData = try request.serializedData() @@ -33,40 +35,6 @@ open class ClientCallServerStreamingBase Void) throws { - do { - try call.receiveMessage { responseData in - if let responseData = responseData { - if let response = try? OutputType(serializedData: responseData) { - completion(response, nil) - } else { - completion(nil, .invalidMessageReceived) - } - } else { - completion(nil, .endOfStream) - } - } - } - } - - public func receive() throws -> OutputType { - var returnError: ClientError? - var returnResponse: OutputType! - let sem = DispatchSemaphore(value: 0) - do { - try receive { response, error in - returnResponse = response - returnError = error - sem.signal() - } - _ = sem.wait() - } - if let returnError = returnError { - throw returnError - } - return returnResponse - } } /// Simple fake implementation of ClientCallServerStreamingBase that returns a previously-defined set of results. @@ -76,23 +44,14 @@ open class ClientCallServerStreamingTestStub: ClientCallSer open var outputs: [OutputType] = [] public init() {} - - open func receive(completion: @escaping (OutputType?, ClientError?) -> Void) throws { - if let output = outputs.first { - outputs.removeFirst() - completion(output, nil) - } else { - completion(nil, .endOfStream) - } + + open func receive() throws -> OutputType? { + defer { if !outputs.isEmpty { outputs.removeFirst() } } + return outputs.first } - - open func receive() throws -> OutputType { - if let output = outputs.first { - outputs.removeFirst() - return output - } else { - throw ClientError.endOfStream - } + + open func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { + completion(.result(try self.receive())) } open func cancel() {} diff --git a/Sources/SwiftGRPC/Runtime/ClientCallUnary.swift b/Sources/SwiftGRPC/Runtime/ClientCallUnary.swift index 8343d3845..b70441f5c 100644 --- a/Sources/SwiftGRPC/Runtime/ClientCallUnary.swift +++ b/Sources/SwiftGRPC/Runtime/ClientCallUnary.swift @@ -22,7 +22,7 @@ public protocol ClientCallUnary: ClientCall {} open class ClientCallUnaryBase: ClientCallBase, ClientCallUnary { /// Run the call. Blocks until the reply is received. - /// - Throws: `BinaryEncodingError` if encoding fails. `CallError` if fails to call. `ClientError` if receives no response. + /// - Throws: `BinaryEncodingError` if encoding fails. `CallError` if fails to call. `RPCError` if receives no response. public func run(request: InputType, metadata: Metadata) throws -> OutputType { let sem = DispatchSemaphore(value: 0) var returnCallResult: CallResult! @@ -36,7 +36,7 @@ open class ClientCallUnaryBase: ClientC if let returnResponse = returnResponse { return returnResponse } else { - throw ClientError.error(c: returnCallResult) + throw RPCError.callError(returnCallResult) } } @@ -47,9 +47,8 @@ open class ClientCallUnaryBase: ClientC completion: @escaping ((OutputType?, CallResult) -> Void)) throws -> Self { let requestData = try request.serializedData() try call.start(.unary, metadata: metadata, message: requestData) { callResult in - if let responseData = callResult.resultData, - let response = try? OutputType(serializedData: responseData) { - completion(response, callResult) + if let responseData = callResult.resultData { + completion(try? OutputType(serializedData: responseData), callResult) } else { completion(nil, callResult) } diff --git a/Sources/SwiftGRPC/Runtime/ClientError.swift b/Sources/SwiftGRPC/Runtime/RPCError.swift similarity index 53% rename from Sources/SwiftGRPC/Runtime/ClientError.swift rename to Sources/SwiftGRPC/Runtime/RPCError.swift index 4579a7e06..28ebfd42c 100644 --- a/Sources/SwiftGRPC/Runtime/ClientError.swift +++ b/Sources/SwiftGRPC/Runtime/RPCError.swift @@ -17,8 +17,39 @@ import Foundation /// Type for errors thrown from generated client code. -public enum ClientError: Error { - case endOfStream +public enum RPCError: Error { case invalidMessageReceived - case error(c: CallResult) + case callError(CallResult) } + +public extension RPCError { + var callResult: CallResult? { + switch self { + case .invalidMessageReceived: return nil + case .callError(let callResult): return callResult + } + } +} + + +public enum ResultOrRPCError { + case result(ResultType) + case error(RPCError) +} + +public extension ResultOrRPCError { + var result: ResultType? { + switch self { + case .result(let result): return result + case .error: return nil + } + } + + var error: RPCError? { + switch self { + case .result: return nil + case .error(let error): return error + } + } +} + diff --git a/Sources/SwiftGRPC/Runtime/ServerError.swift b/Sources/SwiftGRPC/Runtime/ServerError.swift deleted file mode 100644 index cb9513703..000000000 --- a/Sources/SwiftGRPC/Runtime/ServerError.swift +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2018, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import Foundation - -/// Type for errors thrown from generated server code. -public enum ServerError: Error { - case endOfStream -} diff --git a/Sources/SwiftGRPC/Runtime/ServerSession.swift b/Sources/SwiftGRPC/Runtime/ServerSession.swift index 15c26c68c..02cb8fa1d 100644 --- a/Sources/SwiftGRPC/Runtime/ServerSession.swift +++ b/Sources/SwiftGRPC/Runtime/ServerSession.swift @@ -18,34 +18,54 @@ import Dispatch import Foundation import SwiftProtobuf +public struct ServerStatus: Error { + public let code: StatusCode + public let message: String + public let trailingMetadata: Metadata + + public init(code: StatusCode, message: String, trailingMetadata: Metadata = Metadata()) { + self.code = code + self.message = message + self.trailingMetadata = trailingMetadata + } + + public static let ok = ServerStatus(code: .ok, message: "OK") + public static let processingError = ServerStatus(code: .internalError, message: "unknown error processing request") + public static let noRequestData = ServerStatus(code: .invalidArgument, message: "no request data received") + public static let sendingInitialMetadataFailed = ServerStatus(code: .internalError, message: "sending initial metadata failed") +} + public protocol ServerSession: class { var requestMetadata: Metadata { get } - var statusCode: StatusCode { get set } - var statusMessage: String { get set } var initialMetadata: Metadata { get set } - var trailingMetadata: Metadata { get set } + + func cancel() } open class ServerSessionBase: ServerSession { public var handler: Handler public var requestMetadata: Metadata { return handler.requestMetadata } - public var statusCode: StatusCode = .ok - public var statusMessage: String = "OK" public var initialMetadata: Metadata = Metadata() - public var trailingMetadata: Metadata = Metadata() + + public var call: Call { return handler.call } public init(handler: Handler) { self.handler = handler } + + public func cancel() { + call.cancel() + } } open class ServerSessionTestStub: ServerSession { open var requestMetadata = Metadata() - open var statusCode = StatusCode.ok - open var statusMessage = "OK" open var initialMetadata = Metadata() - open var trailingMetadata = Metadata() + + public init() {} + + open func cancel() {} } diff --git a/Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift b/Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift index 3c4226fa7..6c655f419 100644 --- a/Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift @@ -22,7 +22,10 @@ public protocol ServerSessionBidirectionalStreaming: ServerSession { func waitForSendOperationsToFinish() } -open class ServerSessionBidirectionalStreamingBase: ServerSessionBase, ServerSessionBidirectionalStreaming { +open class ServerSessionBidirectionalStreamingBase: ServerSessionBase, ServerSessionBidirectionalStreaming, StreamReceiving, StreamSending { + public typealias ReceivedType = InputType + public typealias SentType = OutputType + public typealias ProviderBlock = (ServerSessionBidirectionalStreamingBase) throws -> Void private var providerBlock: ProviderBlock @@ -30,55 +33,33 @@ open class ServerSessionBidirectionalStreamingBase InputType { - let sem = DispatchSemaphore(value: 0) - var requestMessage: InputType? - try handler.receiveMessage { requestData in - if let requestData = requestData { - do { - requestMessage = try InputType(serializedData: requestData) - } catch (let error) { - print("error \(error)") - } - } - sem.signal() - } - _ = sem.wait() - if let requestMessage = requestMessage { - return requestMessage - } else { - throw ServerError.endOfStream - } - } - - public func send(_ response: OutputType, completion: ((Error?) -> Void)?) throws { - try handler.sendResponse(message: response.serializedData(), completion: completion) - } - - public func close() throws { - let sem = DispatchSemaphore(value: 0) - try handler.sendStatus(statusCode: statusCode, - statusMessage: statusMessage, - trailingMetadata: trailingMetadata) { _ in sem.signal() } - _ = sem.wait() - } - + public func run(queue: DispatchQueue) throws { - try handler.sendMetadata(initialMetadata: initialMetadata) { _ in + try handler.sendMetadata(initialMetadata: initialMetadata) { success in queue.async { - do { - try self.providerBlock(self) - } catch (let error) { - print("error \(error)") + var responseStatus: ServerStatus? + if success { + do { + try self.providerBlock(self) + } catch { + responseStatus = (error as? ServerStatus) ?? .processingError + } + } else { + print("ServerSessionBidirectionalStreamingBase.run sending initial metadata failed") + responseStatus = .sendingInitialMetadataFailed + } + + if let responseStatus = responseStatus { + // Error encountered, notify the client. + do { + try self.handler.sendStatus(responseStatus) + } catch { + print("ServerSessionBidirectionalStreamingBase.run error sending status: \(error)") + } } } } } - - public func waitForSendOperationsToFinish() { - handler.call.messageQueueEmpty.wait() - } } /// Simple fake implementation of ServerSessionBidirectionalStreaming that returns a previously-defined set of results @@ -86,21 +67,29 @@ open class ServerSessionBidirectionalStreamingBase: ServerSessionTestStub, ServerSessionBidirectionalStreaming { open var inputs: [InputType] = [] open var outputs: [OutputType] = [] + open var status: ServerStatus? - open func receive() throws -> InputType { - if let input = inputs.first { - inputs.removeFirst() - return input - } else { - throw ServerError.endOfStream - } + open func receive() throws -> InputType? { + defer { if !inputs.isEmpty { inputs.removeFirst() } } + return inputs.first + } + + open func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { + completion(.result(try self.receive())) + } + + open func send(_ message: OutputType, completion _: @escaping (Error?) -> Void) throws { + outputs.append(message) } - open func send(_ response: OutputType, completion _: ((Error?) -> Void)?) throws { - outputs.append(response) + open func send(_ message: OutputType) throws { + outputs.append(message) } - open func close() throws {} + open func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws { + self.status = status + completion?() + } open func waitForSendOperationsToFinish() {} } diff --git a/Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift b/Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift index f0a3a89e1..b1a943502 100644 --- a/Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift @@ -20,7 +20,9 @@ import SwiftProtobuf public protocol ServerSessionClientStreaming: ServerSession {} -open class ServerSessionClientStreamingBase: ServerSessionBase, ServerSessionClientStreaming { +open class ServerSessionClientStreamingBase: ServerSessionBase, ServerSessionClientStreaming, StreamReceiving { + public typealias ReceivedType = InputType + public typealias ProviderBlock = (ServerSessionClientStreamingBase) throws -> Void private var providerBlock: ProviderBlock @@ -28,37 +30,38 @@ open class ServerSessionClientStreamingBase InputType { - let sem = DispatchSemaphore(value: 0) - var requestMessage: InputType? - try handler.receiveMessage { requestData in - if let requestData = requestData { - requestMessage = try? InputType(serializedData: requestData) - } - sem.signal() - } - _ = sem.wait() - if requestMessage == nil { - throw ServerError.endOfStream - } - return requestMessage! + + public func sendAndClose(response: OutputType, status: ServerStatus = .ok, + completion: (() -> Void)? = nil) throws { + try handler.sendResponse(message: response.serializedData(), status: status, completion: completion) } - public func sendAndClose(_ response: OutputType) throws { - try handler.sendResponse(message: response.serializedData(), - statusCode: statusCode, - statusMessage: statusMessage, - trailingMetadata: trailingMetadata) + public func sendErrorAndClose(status: ServerStatus, completion: (() -> Void)? = nil) throws { + try handler.sendStatus(status, completion: completion) } - + public func run(queue: DispatchQueue) throws { - try handler.sendMetadata(initialMetadata: initialMetadata) { _ in + try handler.sendMetadata(initialMetadata: initialMetadata) { success in queue.async { - do { - try self.providerBlock(self) - } catch (let error) { - print("error \(error)") + var responseStatus: ServerStatus? + if success { + do { + try self.providerBlock(self) + } catch { + responseStatus = (error as? ServerStatus) ?? .processingError + } + } else { + print("ServerSessionClientStreamingBase.run sending initial metadata failed") + responseStatus = .sendingInitialMetadataFailed + } + + if let responseStatus = responseStatus { + // Error encountered, notify the client. + do { + try self.handler.sendStatus(responseStatus) + } catch { + print("ServerSessionClientStreamingBase.run error sending status: \(error)") + } } } } @@ -70,19 +73,27 @@ open class ServerSessionClientStreamingBase: ServerSessionTestStub, ServerSessionClientStreaming { open var inputs: [InputType] = [] open var output: OutputType? + open var status: ServerStatus? - open func receive() throws -> InputType { - if let input = inputs.first { - inputs.removeFirst() - return input - } else { - throw ServerError.endOfStream - } + open func receive() throws -> InputType? { + defer { if !inputs.isEmpty { inputs.removeFirst() } } + return inputs.first + } + + open func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { + completion(.result(try self.receive())) } - open func sendAndClose(_ response: OutputType) throws { - output = response + open func sendAndClose(response: OutputType, status: ServerStatus, completion: (() -> Void)?) throws { + self.output = response + self.status = status + completion?() } + open func sendErrorAndClose(status: ServerStatus, completion: (() -> Void)? = nil) throws { + self.status = status + completion?() + } + open func close() throws {} } diff --git a/Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift b/Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift index a9e4ec1c0..a02e3ff92 100644 --- a/Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift @@ -22,7 +22,9 @@ public protocol ServerSessionServerStreaming: ServerSession { func waitForSendOperationsToFinish() } -open class ServerSessionServerStreamingBase: ServerSessionBase, ServerSessionServerStreaming { +open class ServerSessionServerStreamingBase: ServerSessionBase, ServerSessionServerStreaming, StreamSending { + public typealias SentType = OutputType + public typealias ProviderBlock = (InputType, ServerSessionServerStreamingBase) throws -> Void private var providerBlock: ProviderBlock @@ -30,52 +32,54 @@ open class ServerSessionServerStreamingBase Void)?) throws { - try handler.sendResponse(message: response.serializedData(), completion: completion) - } - + public func run(queue: DispatchQueue) throws { try handler.receiveMessage(initialMetadata: initialMetadata) { requestData in - // TODO(danielalm): Unify this behavior with `ServerSessionBidirectionalStreamingBase.run()`. - if let requestData = requestData { - do { - let requestMessage = try InputType(serializedData: requestData) - // to keep providers from blocking the server thread, - // we dispatch them to another queue. - queue.async { - do { - try self.providerBlock(requestMessage, self) - try self.handler.sendStatus(statusCode: self.statusCode, - statusMessage: self.statusMessage, - trailingMetadata: self.trailingMetadata, - completion: nil) - } catch (let error) { - print("error: \(error)") - } + queue.async { + var responseStatus: ServerStatus? + if let requestData = requestData { + do { + let requestMessage = try InputType(serializedData: requestData) + try self.providerBlock(requestMessage, self) + } catch { + responseStatus = (error as? ServerStatus) ?? .processingError + } + } else { + print("ServerSessionServerStreamingBase.run empty request data") + responseStatus = .noRequestData + } + + if let responseStatus = responseStatus { + // Error encountered, notify the client. + do { + try self.handler.sendStatus(responseStatus) + } catch { + print("ServerSessionServerStreamingBase.run error sending status: \(error)") } - } catch (let error) { - print("error: \(error)") } } } } - - public func waitForSendOperationsToFinish() { - handler.call.messageQueueEmpty.wait() - } } /// Simple fake implementation of ServerSessionServerStreaming that returns a previously-defined set of results /// and stores sent values for later verification. open class ServerSessionServerStreamingTestStub: ServerSessionTestStub, ServerSessionServerStreaming { open var outputs: [OutputType] = [] + open var status: ServerStatus? - open func send(_ response: OutputType, completion _: ((Error?) -> Void)?) throws { - outputs.append(response) + open func send(_ message: OutputType, completion _: @escaping (Error?) -> Void) throws { + outputs.append(message) } - open func close() throws {} + open func send(_ message: OutputType) throws { + outputs.append(message) + } + + open func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws { + self.status = status + completion?() + } open func waitForSendOperationsToFinish() {} } diff --git a/Sources/SwiftGRPC/Runtime/ServerSessionUnary.swift b/Sources/SwiftGRPC/Runtime/ServerSessionUnary.swift index 2d9895929..8eb9ed317 100644 --- a/Sources/SwiftGRPC/Runtime/ServerSessionUnary.swift +++ b/Sources/SwiftGRPC/Runtime/ServerSessionUnary.swift @@ -21,6 +21,8 @@ import SwiftProtobuf public protocol ServerSessionUnary: ServerSession {} open class ServerSessionUnaryBase: ServerSessionBase, ServerSessionUnary { + public typealias SentType = OutputType + public typealias ProviderBlock = (InputType, ServerSessionUnaryBase) throws -> OutputType private var providerBlock: ProviderBlock @@ -28,16 +30,34 @@ open class ServerSessionUnaryBase: Serv self.providerBlock = providerBlock super.init(handler: handler) } - - public func run(queue _: DispatchQueue) throws { + + public func run(queue: DispatchQueue) throws { try handler.receiveMessage(initialMetadata: initialMetadata) { requestData in - if let requestData = requestData { - let requestMessage = try InputType(serializedData: requestData) - let replyMessage = try self.providerBlock(requestMessage, self) - try self.handler.sendResponse(message: replyMessage.serializedData(), - statusCode: self.statusCode, - statusMessage: self.statusMessage, - trailingMetadata: self.trailingMetadata) + queue.async { + let responseStatus: ServerStatus + if let requestData = requestData { + do { + let requestMessage = try InputType(serializedData: requestData) + let responseMessage = try self.providerBlock(requestMessage, self) + try self.handler.call.sendMessage(data: responseMessage.serializedData()) { + guard let error = $0 + else { return } + print("ServerSessionUnaryBase.run error sending response: \(error)") + } + responseStatus = .ok + } catch { + responseStatus = (error as? ServerStatus) ?? .processingError + } + } else { + print("ServerSessionUnaryBase.run empty request data") + responseStatus = .noRequestData + } + + do { + try self.handler.sendStatus(responseStatus) + } catch { + print("ServerSessionUnaryBase.run error sending status: \(error)") + } } } } diff --git a/Sources/SwiftGRPC/Runtime/ServiceServer.swift b/Sources/SwiftGRPC/Runtime/ServiceServer.swift index 992967e73..ba22dd48e 100644 --- a/Sources/SwiftGRPC/Runtime/ServiceServer.swift +++ b/Sources/SwiftGRPC/Runtime/ServiceServer.swift @@ -65,17 +65,24 @@ open class ServiceServer { + " calling " + unwrappedMethod + " from " + unwrappedCaller + " with " + handler.requestMetadata.description) - + do { - if try !strongSelf.handleMethod(unwrappedMethod, handler: handler, queue: queue) { - // handle unknown requests - try handler.receiveMessage(initialMetadata: Metadata()) { _ in - try handler.sendResponse(statusCode: .unimplemented, - statusMessage: "unknown method " + unwrappedMethod, - trailingMetadata: Metadata()) + if !(try strongSelf.handleMethod(unwrappedMethod, handler: handler, queue: queue)) { + do { + try handler.call.perform(OperationGroup( + call: handler.call, + operations: [ + .sendInitialMetadata(Metadata()), + .receiveCloseOnServer, + .sendStatusFromServer(.unimplemented, "unknown method " + unwrappedMethod, Metadata()) + ]) { _ in + handler.shutdown() + }) + } catch { + print("ServiceServer.start error sending status for unknown method: \(error)") } } - } catch (let error) { + } catch { print("Server error: \(error)") } } diff --git a/Sources/SwiftGRPC/Runtime/StreamReceiving.swift b/Sources/SwiftGRPC/Runtime/StreamReceiving.swift new file mode 100644 index 000000000..15710bf7c --- /dev/null +++ b/Sources/SwiftGRPC/Runtime/StreamReceiving.swift @@ -0,0 +1,59 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Dispatch +import Foundation +import SwiftProtobuf + +public protocol StreamReceiving { + associatedtype ReceivedType: Message + + var call: Call { get } +} + +extension StreamReceiving { + public func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { + try call.receiveMessage { callResult in + guard let responseData = callResult.resultData else { + if callResult.success { + completion(.result(nil)) + } else { + completion(.error(.callError(callResult))) + } + return + } + if let response = try? ReceivedType(serializedData: responseData) { + completion(.result(response)) + } else { + completion(.error(.invalidMessageReceived)) + } + } + } + + public func receive() throws -> ReceivedType? { + var result: ResultOrRPCError? + let sem = DispatchSemaphore(value: 0) + try receive { + result = $0 + sem.signal() + } + _ = sem.wait() + switch result! { + case .result(let response): return response + case .error(let error): throw error + } + } +} diff --git a/Sources/SwiftGRPC/Runtime/StreamSending.swift b/Sources/SwiftGRPC/Runtime/StreamSending.swift new file mode 100644 index 000000000..e0ca459cc --- /dev/null +++ b/Sources/SwiftGRPC/Runtime/StreamSending.swift @@ -0,0 +1,54 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Dispatch +import Foundation +import SwiftProtobuf + +public protocol StreamSending { + associatedtype SentType: Message + + var call: Call { get } +} + +extension StreamSending { + public func send(_ message: SentType, completion: @escaping (Error?) -> Void) throws { + try call.sendMessage(data: message.serializedData(), completion: completion) + } + + public func send(_ message: SentType) throws { + var resultError: Error? + let sem = DispatchSemaphore(value: 0) + try send(message) { + resultError = $0 + sem.signal() + } + _ = sem.wait() + if let resultError = resultError { + throw resultError + } + } + + public func waitForSendOperationsToFinish() { + call.messageQueueEmpty.wait() + } +} + +extension StreamSending where Self: ServerSessionBase { + public func close(withStatus status: ServerStatus = .ok, completion: (() -> Void)? = nil) throws { + try handler.sendStatus(status, completion: completion) + } +} diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Client.swift b/Sources/protoc-gen-swiftgrpc/Generator-Client.swift index fb59b4c63..b283fe1cd 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Client.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Client.swift @@ -18,7 +18,6 @@ import SwiftProtobuf import SwiftProtobufPluginLibrary extension Generator { - internal func printClient() { for method in service.methods { self.method = method @@ -57,10 +56,7 @@ extension Generator { private func printServiceClientMethodCallServerStreaming() { println("\(access) protocol \(callName): ClientCallServerStreaming {") indent() - println("/// Call this to wait for a result. Blocking.") - println("func receive() throws -> \(methodOutputName)") - println("/// Call this to wait for a result. Nonblocking.") - println("func receive(completion: @escaping (\(methodOutputName)?, ClientError?) -> Void) throws") + printStreamReceiveMethods(receivedType: methodOutputName) outdent() println("}") println() @@ -83,13 +79,12 @@ extension Generator { private func printServiceClientMethodCallClientStreaming() { println("\(options.visibility.sourceSnippet) protocol \(callName): ClientCallClientStreaming {") indent() - println("/// Call this to send each message in the request stream. Nonblocking.") - println("func send(_ message: \(methodInputName), completion: @escaping (Error?) -> Void) throws") + printStreamSendMethods(sentType: methodInputName) println() println("/// Call this to close the connection and wait for a response. Blocking.") println("func closeAndReceive() throws -> \(methodOutputName)") println("/// Call this to close the connection and wait for a response. Nonblocking.") - println("func closeAndReceive(completion: @escaping (\(methodOutputName)?, ClientError?) -> Void) throws") + println("func closeAndReceive(completion: @escaping (ResultOrRPCError<\(methodOutputName)>) -> Void) throws") outdent() println("}") println() @@ -114,13 +109,9 @@ extension Generator { private func printServiceClientMethodCallBidiStreaming() { println("\(access) protocol \(callName): ClientCallBidirectionalStreaming {") indent() - println("/// Call this to wait for a result. Blocking.") - println("func receive() throws -> \(methodOutputName)") - println("/// Call this to wait for a result. Nonblocking.") - println("func receive(completion: @escaping (\(methodOutputName)?, ClientError?) -> Void) throws") + printStreamReceiveMethods(receivedType: methodOutputName) println() - println("/// Call this to send each message in the request stream.") - println("func send(_ message: \(methodInputName), completion: @escaping (Error?) -> Void) throws") + printStreamSendMethods(sentType: methodInputName) println() println("/// Call this to close the sending connection. Blocking.") println("func closeSend() throws") diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Methods.swift b/Sources/protoc-gen-swiftgrpc/Generator-Methods.swift new file mode 100644 index 000000000..f32f66157 --- /dev/null +++ b/Sources/protoc-gen-swiftgrpc/Generator-Methods.swift @@ -0,0 +1,34 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Foundation +import SwiftProtobuf +import SwiftProtobufPluginLibrary + +extension Generator { + func printStreamReceiveMethods(receivedType: String) { + println("/// Call this to wait for a result. Blocking.") + println("func receive() throws -> \(receivedType)?") + println("/// Call this to wait for a result. Nonblocking.") + println("func receive(completion: @escaping (ResultOrRPCError<\(receivedType)?>) -> Void) throws") + } + + func printStreamSendMethods(sentType: String) { + println("/// Send a message to the stream. Nonblocking.") + println("func send(_ message: \(sentType), completion: @escaping (Error?) -> Void) throws") + println("/// Send a message to the stream and wait for the send operation to finish. Blocking.") + println("func send(_ message: \(sentType)) throws") + } +} diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift index d5409867f..34171ef41 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift @@ -141,15 +141,23 @@ extension Generator { println("class \(methodSessionName)TestStub: ServerSessionUnaryTestStub, \(methodSessionName) {}") } } + + private func printServerMethodSendAndClose(sentType: String) { + println("/// You MUST call one of these two methods once you are done processing the request.") + println("/// Close the connection and send a single result. Non-blocking.") + println("func sendAndClose(response: \(sentType), status: ServerStatus, completion: (() -> Void)?) throws") + println("/// Close the connection and send an error. Non-blocking.") + println("/// Use this method if you encountered an error that makes it impossible to send a response.") + println("/// Accordingly, it does not make sense to call this method with a status of `.ok`.") + println("func sendErrorAndClose(status: ServerStatus, completion: (() -> Void)?) throws") + } private func printServerMethodClientStreaming() { println("\(access) protocol \(methodSessionName): ServerSessionClientStreaming {") indent() - println("/// Receive a message. Blocks until a message is received or the client closes the connection.") - println("func receive() throws -> \(methodInputName)") + printStreamReceiveMethods(receivedType: methodInputName) println() - println("/// Send a response and close the connection.") - println("func sendAndClose(_ response: \(methodOutputName)) throws") + printServerMethodSendAndClose(sentType: methodOutputName) outdent() println("}") println() @@ -160,11 +168,18 @@ extension Generator { } } + private func printServerMethodClose() { + println("/// Close the connection and send the status. Non-blocking.") + println("/// You MUST call this method once you are done processing the request.") + println("func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws") + } + private func printServerMethodServerStreaming() { println("\(access) protocol \(methodSessionName): ServerSessionServerStreaming {") indent() - println("/// Send a message. Nonblocking.") - println("func send(_ response: \(methodOutputName), completion: ((Error?) -> Void)?) throws") + printStreamSendMethods(sentType: methodOutputName) + println() + printServerMethodClose() outdent() println("}") println() @@ -178,14 +193,11 @@ extension Generator { private func printServerMethodBidirectional() { println("\(access) protocol \(methodSessionName): ServerSessionBidirectionalStreaming {") indent() - println("/// Receive a message. Blocks until a message is received or the client closes the connection.") - println("func receive() throws -> \(methodInputName)") + printStreamReceiveMethods(receivedType: methodInputName) println() - println("/// Send a message. Nonblocking.") - println("func send(_ response: \(methodOutputName), completion: ((Error?) -> Void)?) throws") + printStreamSendMethods(sentType: methodOutputName) println() - println("/// Close a connection. Blocks until the connection is closed.") - println("func close() throws") + printServerMethodClose() outdent() println("}") println() diff --git a/Sources/protoc-gen-swiftgrpc/main.swift b/Sources/protoc-gen-swiftgrpc/main.swift index 120135efb..404610f81 100644 --- a/Sources/protoc-gen-swiftgrpc/main.swift +++ b/Sources/protoc-gen-swiftgrpc/main.swift @@ -124,6 +124,6 @@ func main() throws { do { try main() -} catch (let error) { +} catch { Log("ERROR: \(error)") } diff --git a/Sources/protoc-gen-swiftgrpc/options.swift b/Sources/protoc-gen-swiftgrpc/options.swift index c4ccc0792..083e42103 100644 --- a/Sources/protoc-gen-swiftgrpc/options.swift +++ b/Sources/protoc-gen-swiftgrpc/options.swift @@ -99,9 +99,9 @@ class GeneratorOptions { } // Creates key/value pair and trims whitespace - let key = string.substring(to: index) + let key = string[.. Echo_EchoProvider { return EchoProvider() } + + var defaultTimeout: TimeInterval { return 1.0 } + + var provider: Echo_EchoProvider! + var server: Echo_EchoServer! + var client: Echo_EchoServiceClient! + + var secure: Bool { return false } + + override func setUp() { + super.setUp() + + provider = makeProvider() + + let address = "localhost:5050" + if secure { + let certificateString = String(data: certificateForTests, encoding: .utf8)! + server = Echo_EchoServer(address: address, + certificateString: certificateString, + keyString: String(data: keyForTests, encoding: .utf8)!, + provider: provider) + server.start(queue: DispatchQueue.global()) + client = Echo_EchoServiceClient(address: address, certificates: certificateString, host: "example.com") + client.host = "example.com" + } else { + server = Echo_EchoServer(address: address, provider: provider) + server.start(queue: DispatchQueue.global()) + client = Echo_EchoServiceClient(address: address, secure: false) + } + + client.timeout = defaultTimeout + } + + override func tearDown() { + client = nil + + server.server.stop() + server = nil + + super.tearDown() + } +} diff --git a/Tests/SwiftGRPCTests/ClientCancellingTests.swift b/Tests/SwiftGRPCTests/ClientCancellingTests.swift new file mode 100644 index 000000000..68591f432 --- /dev/null +++ b/Tests/SwiftGRPCTests/ClientCancellingTests.swift @@ -0,0 +1,117 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Dispatch +import Foundation +@testable import SwiftGRPC +import XCTest + +class ClientCancellingTests: BasicEchoTestCase { + static var allTests: [(String, (ClientCancellingTests) -> () throws -> Void)] { + return [ + ("testUnary", testUnary), + ("testClientStreaming", testClientStreaming), + ("testServerStreaming", testServerStreaming), + ("testBidirectionalStreaming", testBidirectionalStreaming), + ] + } +} + +extension ClientCancellingTests { + func testUnary() { + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.get(Echo_EchoRequest(text: "foo bar baz")) { response, callResult in + XCTAssertNil(response) + XCTAssertEqual(.cancelled, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + call.cancel() + + waitForExpectations(timeout: defaultTimeout) + } + + func testClientStreaming() { + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.collect { callResult in + XCTAssertEqual(.cancelled, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + call.cancel() + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in XCTAssertEqual(.unknown, $0 as! CallError); sendExpectation.fulfill() } + call.waitForSendOperationsToFinish() + + do { + let result = try call.closeAndReceive() + XCTFail("should have thrown, received \(result) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + func testServerStreaming() { + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.expand(Echo_EchoRequest(text: "foo bar baz")) { callResult in + XCTAssertEqual(.cancelled, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + XCTAssertEqual("Swift echo expand (0): foo", try! call.receive()!.text) + + call.cancel() + + do { + let result = try call.receive() + XCTFail("should have thrown, received \(String(describing: result)) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + func testBidirectionalStreaming() { + let finalCompletionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.update { callResult in + XCTAssertEqual(.cancelled, callResult.statusCode) + finalCompletionHandlerExpectation.fulfill() + } + + var sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in XCTAssertNil($0); sendExpectation.fulfill() } + XCTAssertEqual("Swift echo update (0): foo", try! call.receive()!.text) + + call.cancel() + + sendExpectation = expectation(description: "send completion handler 2 called") + try! call.send(Echo_EchoRequest(text: "bar")) { [sendExpectation] in XCTAssertEqual(.unknown, $0 as! CallError); sendExpectation.fulfill() } + do { + let result = try call.receive() + XCTFail("should have thrown, received \(String(describing: result)) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + let closeCompletionHandlerExpectation = expectation(description: "close completion handler called") + try! call.closeSend { closeCompletionHandlerExpectation.fulfill() } + + waitForExpectations(timeout: defaultTimeout) + } +} diff --git a/Tests/SwiftGRPCTests/ClientTestExample.swift b/Tests/SwiftGRPCTests/ClientTestExample.swift new file mode 100644 index 000000000..c4deeffea --- /dev/null +++ b/Tests/SwiftGRPCTests/ClientTestExample.swift @@ -0,0 +1,141 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Dispatch +import Foundation +@testable import SwiftGRPC +import XCTest + +// Sample test suite to demonstrate how one would test client code that +// uses an object that implements the `...Service` protocol. +// These tests don't really test the logic of the SwiftGRPC library, but are meant +// as an example of how one would go about testing their own client/server code that +// relies on SwiftGRPC. +fileprivate class ClientUnderTest { + let service: Echo_EchoService + + init(service: Echo_EchoService) { + self.service = service + } + + func getWord(_ input: String) throws -> String { + return try service.get(Echo_EchoRequest(text: input)).text + } + + func collectWords(_ input: [String]) throws -> String { + let call = try service.collect(completion: nil) + for text in input { + try call.send(Echo_EchoRequest(text: text), completion: { _ in }) + } + call.waitForSendOperationsToFinish() + return try call.closeAndReceive().text + } + + func expandWords(_ input: String) throws -> [String] { + let call = try service.expand(Echo_EchoRequest(text: input), completion: nil) + var results: [String] = [] + while let response = try call.receive() { + results.append(response.text) + } + return results + } + + func updateWords(_ input: [String]) throws -> [String] { + let call = try service.update(completion: nil) + for text in input { + try call.send(Echo_EchoRequest(text: text), completion: { _ in }) + } + call.waitForSendOperationsToFinish() + + var results: [String] = [] + while let response = try call.receive() { + results.append(response.text) + } + return results + } +} + +class ClientTestExample: XCTestCase { + static var allTests: [(String, (ClientTestExample) -> () throws -> Void)] { + return [ + ("testUnary", testUnary), + ("testClientStreaming", testClientStreaming), + ("testServerStreaming", testServerStreaming), + ("testBidirectionalStreaming", testBidirectionalStreaming) + ] + } +} + +extension ClientTestExample { + func testUnary() { + let fakeService = Echo_EchoServiceTestStub() + fakeService.getResponses.append(Echo_EchoResponse(text: "bar")) + + let client = ClientUnderTest(service: fakeService) + XCTAssertEqual("bar", try client.getWord("foo")) + + // Ensure that all responses have been consumed. + XCTAssertEqual(0, fakeService.getResponses.count) + // Ensure that the expected requests have been sent. + XCTAssertEqual([Echo_EchoRequest(text: "foo")], fakeService.getRequests) + } + + func testClientStreaming() { + let inputStrings = ["foo", "bar", "baz"] + let fakeService = Echo_EchoServiceTestStub() + let fakeCall = Echo_EchoCollectCallTestStub() + fakeCall.output = Echo_EchoResponse(text: "response") + fakeService.collectCalls.append(fakeCall) + + let client = ClientUnderTest(service: fakeService) + XCTAssertEqual("response", try client.collectWords(inputStrings)) + + // Ensure that the expected requests have been sent. + XCTAssertEqual(inputStrings.map { Echo_EchoRequest(text: $0) }, fakeCall.inputs) + } + + func testServerStreaming() { + let outputStrings = ["foo", "bar", "baz"] + let fakeService = Echo_EchoServiceTestStub() + let fakeCall = Echo_EchoExpandCallTestStub() + fakeCall.outputs = outputStrings.map { Echo_EchoResponse(text: $0) } + fakeService.expandCalls.append(fakeCall) + + let client = ClientUnderTest(service: fakeService) + XCTAssertEqual(outputStrings, try client.expandWords("inputWord")) + + // Ensure that all responses have been consumed. + XCTAssertEqual(0, fakeCall.outputs.count) + // Ensure that the expected requests have been sent. + XCTAssertEqual([Echo_EchoRequest(text: "inputWord")], fakeService.expandRequests) + } + + func testBidirectionalStreaming() { + let inputStrings = ["foo", "bar", "baz"] + let outputStrings = ["foo2", "bar2", "baz2"] + let fakeService = Echo_EchoServiceTestStub() + let fakeCall = Echo_EchoUpdateCallTestStub() + fakeCall.outputs = outputStrings.map { Echo_EchoResponse(text: $0) } + fakeService.updateCalls.append(fakeCall) + + let client = ClientUnderTest(service: fakeService) + XCTAssertEqual(outputStrings, try client.updateWords(inputStrings)) + + // Ensure that all responses have been consumed. + XCTAssertEqual(0, fakeCall.outputs.count) + // Ensure that the expected requests have been sent. + XCTAssertEqual(inputStrings.map { Echo_EchoRequest(text: $0) }, fakeCall.inputs) + } +} diff --git a/Tests/SwiftGRPCTests/ClientTimeoutTests.swift b/Tests/SwiftGRPCTests/ClientTimeoutTests.swift new file mode 100644 index 000000000..2ea98109d --- /dev/null +++ b/Tests/SwiftGRPCTests/ClientTimeoutTests.swift @@ -0,0 +1,84 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Dispatch +import Foundation +@testable import SwiftGRPC +import XCTest + +class ClientTimeoutTests: BasicEchoTestCase { + static var allTests: [(String, (ClientTimeoutTests) -> () throws -> Void)] { + return [ + ("testClientStreamingTimeoutBeforeSending", testClientStreamingTimeoutBeforeSending), + ("testClientStreamingTimeoutAfterSending", testClientStreamingTimeoutAfterSending) + ] + } + + override var defaultTimeout: TimeInterval { return 0.1 } +} + +extension ClientTimeoutTests { + func testClientStreamingTimeoutBeforeSending() { + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.collect { callResult in + XCTAssertEqual(.deadlineExceeded, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + Thread.sleep(forTimeInterval: 0.2) + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + XCTAssertEqual(.unknown, $0 as! CallError) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + do { + let result = try call.closeAndReceive() + XCTFail("should have thrown, received \(result) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + func testClientStreamingTimeoutAfterSending() { + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.collect { callResult in + XCTAssertEqual(.deadlineExceeded, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in XCTAssertNil($0); sendExpectation.fulfill() } + call.waitForSendOperationsToFinish() + + Thread.sleep(forTimeInterval: 0.2) + + do { + let result = try call.closeAndReceive() + XCTFail("should have thrown, received \(result) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + // FIXME(danielalm): Add support for setting a maximum timeout on the server, to prevent DoS attacks where clients + // start a ton of calls, but never finish them (i.e. essentially leaking a connection on the server side). +} diff --git a/Tests/SwiftGRPCTests/ConnectionFailureTests.swift b/Tests/SwiftGRPCTests/ConnectionFailureTests.swift new file mode 100644 index 000000000..eb806e113 --- /dev/null +++ b/Tests/SwiftGRPCTests/ConnectionFailureTests.swift @@ -0,0 +1,126 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Dispatch +import Foundation +@testable import SwiftGRPC +import XCTest + +// TODO(danielalm): Also test connection failure with regards to SSL issues. +class ConnectionFailureTests: XCTestCase { + static var allTests: [(String, (ConnectionFailureTests) -> () throws -> Void)] { + return [ + ("testConnectionFailureUnary", testConnectionFailureUnary), + ("testConnectionFailureClientStreaming", testConnectionFailureClientStreaming), + ("testConnectionFailureServerStreaming", testConnectionFailureServerStreaming), + ("testConnectionFailureBidirectionalStreaming", testConnectionFailureBidirectionalStreaming) + ] + } + + let address = "localhost:5050" + + let defaultTimeout: TimeInterval = 0.5 +} + +extension ConnectionFailureTests { + func testConnectionFailureUnary() { + let client = Echo_EchoServiceClient(address: "localhost:1234", secure: false) + client.timeout = defaultTimeout + + do { + let result = try client.get(Echo_EchoRequest(text: "foo")).text + XCTFail("should have thrown, received \(result) instead") + } catch { + guard case let .callError(callResult) = error as! RPCError + else { XCTFail("unexpected error \(error)"); return } + XCTAssertEqual(.unavailable, callResult.statusCode) + XCTAssertEqual("Connect Failed", callResult.statusMessage) + } + } + + func testConnectionFailureClientStreaming() { + let client = Echo_EchoServiceClient(address: "localhost:1234", secure: false) + client.timeout = defaultTimeout + + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.collect { callResult in + XCTAssertEqual(.unavailable, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + XCTAssertEqual(.unknown, $0 as! CallError) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + do { + let result = try call.closeAndReceive() + XCTFail("should have thrown, received \(result) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + func testConnectionFailureServerStreaming() { + let client = Echo_EchoServiceClient(address: "localhost:1234", secure: false) + client.timeout = defaultTimeout + + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.expand(Echo_EchoRequest(text: "foo bar baz")) { callResult in + XCTAssertEqual(.unavailable, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + do { + let result = try call.receive() + XCTFail("should have thrown, received \(String(describing: result)) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + func testConnectionFailureBidirectionalStreaming() { + let client = Echo_EchoServiceClient(address: "localhost:1234", secure: false) + client.timeout = defaultTimeout + + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.update { callResult in + XCTAssertEqual(.unavailable, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + XCTAssertEqual(.unknown, $0 as! CallError) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + do { + let result = try call.receive() + XCTFail("should have thrown, received \(String(describing: result)) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } +} diff --git a/Tests/SwiftGRPCTests/EchoTests.swift b/Tests/SwiftGRPCTests/EchoTests.swift index f1dcf758a..a1a871dea 100644 --- a/Tests/SwiftGRPCTests/EchoTests.swift +++ b/Tests/SwiftGRPCTests/EchoTests.swift @@ -18,13 +18,7 @@ import Foundation @testable import SwiftGRPC import XCTest -extension Echo_EchoRequest { - init(text: String) { - self.text = text - } -} - -class EchoTests: XCTestCase { +class EchoTests: BasicEchoTestCase { static var allTests: [(String, (EchoTests) -> () throws -> Void)] { return [ ("testUnary", testUnary), @@ -40,51 +34,11 @@ class EchoTests: XCTestCase { } static let lotsOfStrings = (0..<1000).map { String(describing: $0) } - - let defaultTimeout: TimeInterval = 5.0 - - let provider = EchoProvider() - var server: Echo_EchoServer! - var client: Echo_EchoServiceClient! - - var secure: Bool { return false } - - override func setUp() { - super.setUp() - - let address = "localhost:5050" - if secure { - let certificateString = String(data: certificateForTests, encoding: .utf8)! - server = Echo_EchoServer(address: address, - certificateString: certificateString, - keyString: String(data: keyForTests, encoding: .utf8)!, - provider: provider) - server.start(queue: DispatchQueue.global()) - client = Echo_EchoServiceClient(address: address, certificates: certificateString, host: "example.com") - } else { - server = Echo_EchoServer(address: address, provider: provider) - server.start(queue: DispatchQueue.global()) - client = Echo_EchoServiceClient(address: address, secure: false) - } - - client.timeout = defaultTimeout - } - - override func tearDown() { - client = nil - - server.server.stop() - server = nil - - super.tearDown() - } } -// Currently broken and thus commented out. -// TODO(danielalm): Fix these. -//class EchoTestsSecure: EchoTests { -// override var secure: Bool { return true } -//} +class EchoTestsSecure: EchoTests { + override var secure: Bool { return true } +} extension EchoTests { func testUnary() { @@ -146,9 +100,10 @@ extension EchoTests { completionHandlerExpectation.fulfill() } - XCTAssertEqual("Swift echo expand (0): foo", try! call.receive().text) - XCTAssertEqual("Swift echo expand (1): bar", try! call.receive().text) - XCTAssertEqual("Swift echo expand (2): baz", try! call.receive().text) + XCTAssertEqual("Swift echo expand (0): foo", try! call.receive()!.text) + XCTAssertEqual("Swift echo expand (1): bar", try! call.receive()!.text) + XCTAssertEqual("Swift echo expand (2): baz", try! call.receive()!.text) + XCTAssertNil(try! call.receive()) waitForExpectations(timeout: defaultTimeout) } @@ -161,8 +116,9 @@ extension EchoTests { } for string in EchoTests.lotsOfStrings { - XCTAssertEqual("Swift echo expand (\(string)): \(string)", try! call.receive().text) + XCTAssertEqual("Swift echo expand (\(string)): \(string)", try! call.receive()!.text) } + XCTAssertNil(try! call.receive()) waitForExpectations(timeout: defaultTimeout) } @@ -187,9 +143,10 @@ extension EchoTests { let closeCompletionHandlerExpectation = expectation(description: "close completion handler called") try! call.closeSend { closeCompletionHandlerExpectation.fulfill() } - XCTAssertEqual("Swift echo update (0): foo", try! call.receive().text) - XCTAssertEqual("Swift echo update (1): bar", try! call.receive().text) - XCTAssertEqual("Swift echo update (2): baz", try! call.receive().text) + XCTAssertEqual("Swift echo update (0): foo", try! call.receive()!.text) + XCTAssertEqual("Swift echo update (1): bar", try! call.receive()!.text) + XCTAssertEqual("Swift echo update (2): baz", try! call.receive()!.text) + XCTAssertNil(try! call.receive()) waitForExpectations(timeout: defaultTimeout) } @@ -203,19 +160,21 @@ extension EchoTests { var sendExpectation = expectation(description: "send completion handler 1 called") try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in XCTAssertNil($0); sendExpectation.fulfill() } - XCTAssertEqual("Swift echo update (0): foo", try! call.receive().text) + XCTAssertEqual("Swift echo update (0): foo", try! call.receive()!.text) sendExpectation = expectation(description: "send completion handler 2 called") try! call.send(Echo_EchoRequest(text: "bar")) { [sendExpectation] in XCTAssertNil($0); sendExpectation.fulfill() } - XCTAssertEqual("Swift echo update (1): bar", try! call.receive().text) + XCTAssertEqual("Swift echo update (1): bar", try! call.receive()!.text) sendExpectation = expectation(description: "send completion handler 3 called") try! call.send(Echo_EchoRequest(text: "baz")) { [sendExpectation] in XCTAssertNil($0); sendExpectation.fulfill() } - XCTAssertEqual("Swift echo update (2): baz", try! call.receive().text) + XCTAssertEqual("Swift echo update (2): baz", try! call.receive()!.text) let closeCompletionHandlerExpectation = expectation(description: "close completion handler called") try! call.closeSend { closeCompletionHandlerExpectation.fulfill() } + XCTAssertNil(try! call.receive()) + waitForExpectations(timeout: defaultTimeout) } @@ -236,8 +195,9 @@ extension EchoTests { try! call.closeSend { closeCompletionHandlerExpectation.fulfill() } for string in EchoTests.lotsOfStrings { - XCTAssertEqual("Swift echo update (\(string)): \(string)", try! call.receive().text) + XCTAssertEqual("Swift echo update (\(string)): \(string)", try! call.receive()!.text) } + XCTAssertNil(try! call.receive()) waitForExpectations(timeout: defaultTimeout) } @@ -252,11 +212,13 @@ extension EchoTests { for string in EchoTests.lotsOfStrings { let sendExpectation = expectation(description: "send completion handler \(string) called") try! call.send(Echo_EchoRequest(text: string)) { [sendExpectation] in XCTAssertNil($0); sendExpectation.fulfill() } - XCTAssertEqual("Swift echo update (\(string)): \(string)", try! call.receive().text) + XCTAssertEqual("Swift echo update (\(string)): \(string)", try! call.receive()!.text) } let closeCompletionHandlerExpectation = expectation(description: "close completion handler called") try! call.closeSend { closeCompletionHandlerExpectation.fulfill() } + + XCTAssertNil(try! call.receive()) waitForExpectations(timeout: defaultTimeout) } diff --git a/Tests/SwiftGRPCTests/GRPCTests.swift b/Tests/SwiftGRPCTests/GRPCTests.swift index d46554f30..bbc43380a 100644 --- a/Tests/SwiftGRPCTests/GRPCTests.swift +++ b/Tests/SwiftGRPCTests/GRPCTests.swift @@ -69,7 +69,7 @@ let trailingServerMetadata = "11": "eleven", "12": "twelve" ] -let steps = 10 +let steps = 100 let hello = "/hello.unary" let helloServerStream = "/hello.server-stream" let helloBiDiStream = "/hello.bidi-stream" @@ -84,41 +84,29 @@ let eventStatusMessage = "Not Found" func runTest(useSSL: Bool) { gRPC.initialize() - let serverRunningSemaphore = DispatchSemaphore(value: 0) + var serverRunningSemaphore: DispatchSemaphore? // create the server let server: Server if useSSL { - let certificateURL = URL(fileURLWithPath: "Tests/ssl.crt") - let keyURL = URL(fileURLWithPath: "Tests/ssl.key") - guard - let certificate = try? String(contentsOf: certificateURL, encoding: .utf8), - let key = try? String(contentsOf: keyURL, encoding: .utf8) - else { - // FIXME: We don't want tests to silently pass just because the certificates can't be loaded. - return - } server = Server(address: address, - key: key, - certs: certificate) + key: String(data: keyForTests, encoding: .utf8)!, + certs: String(data: certificateForTests, encoding: .utf8)!) } else { server = Server(address: address) } // start the server - DispatchQueue.global().async { - do { - try runServer(server: server) - } catch (let error) { - XCTFail("server error \(error)") - } - serverRunningSemaphore.signal() // when the server exits, the test is finished + do { + serverRunningSemaphore = try runServer(server: server) + } catch { + XCTFail("server error \(error)") } // run the client do { try runClient(useSSL: useSSL) - } catch (let error) { + } catch { XCTFail("client error \(error)") } @@ -126,7 +114,7 @@ func runTest(useSSL: Bool) { server.stop() // wait until the server has shut down - _ = serverRunningSemaphore.wait() + _ = serverRunningSemaphore!.wait() } func verify_metadata(_ metadata: Metadata, expected: [String: String], file: StaticString = #file, line: UInt = #line) { @@ -145,14 +133,9 @@ func runClient(useSSL: Bool) throws { let channel: Channel if useSSL { - let certificateURL = URL(fileURLWithPath: "Tests/ssl.crt") - guard - let certificates = try? String(contentsOf: certificateURL, encoding: .utf8) - else { - return - } - let host = "example.com" - channel = Channel(address: address, certificates: certificates, host: host) + channel = Channel(address: address, + certificates: String(data: certificateForTests, encoding: .utf8)!, + host: host) } else { channel = Channel(address: address, secure: false) } @@ -210,8 +193,8 @@ func callServerStream(channel: Channel) throws { try call.start(.serverStreaming, metadata: metadata, message: message) { response in - XCTAssertEqual(response.statusCode, StatusCode.outOfRange) - XCTAssertEqual(response.statusMessage, "Out of range") + XCTAssertEqual(response.statusCode, .ok) + XCTAssertEqual(response.statusMessage, "Custom Status Message ServerStreaming") // verify the trailing metadata from the server let trailingMetadata = response.trailingMetadata! @@ -222,14 +205,15 @@ func callServerStream(channel: Channel) throws { for _ in 0.. DispatchSemaphore { var requestCount = 0 let sem = DispatchSemaphore(value: 0) server.run { requestHandler in @@ -305,7 +295,7 @@ func runServer(server: Server) throws { } requestCount += 1 - } catch (let error) { + } catch { XCTFail("error \(error)") } } @@ -314,7 +304,7 @@ func runServer(server: Server) throws { sem.signal() } // wait for the server to exit - _ = sem.wait() + return sem } func handleUnary(requestHandler: Handler, requestCount: Int) throws { @@ -332,14 +322,14 @@ func handleUnary(requestHandler: Handler, requestCount: Int) throws { let replyMessage = serverText let trailingMetadataToSend = Metadata(trailingServerMetadata) try requestHandler.sendResponse(message: replyMessage.data(using: .utf8)!, - statusCode: evenStatusCode, - statusMessage: eventStatusMessage, - trailingMetadata: trailingMetadataToSend) + status: ServerStatus(code: evenStatusCode, + message: eventStatusMessage, + trailingMetadata: trailingMetadataToSend)) } else { let trailingMetadataToSend = Metadata(trailingServerMetadata) - try requestHandler.sendResponse(statusCode: oddStatusCode, - statusMessage: oddStatusMessage, - trailingMetadata: trailingMetadataToSend) + try requestHandler.sendStatus(ServerStatus(code: oddStatusCode, + message: oddStatusMessage, + trailingMetadata: trailingMetadataToSend)) } } @@ -357,18 +347,20 @@ func handleServerStream(requestHandler: Handler) throws { let replyMessage = serverText for _ in 0.. Echo_EchoResponse { + session.cancel() + return Echo_EchoResponse() + } + + func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws { + session.cancel() + XCTAssertThrowsError(try session.send(Echo_EchoResponse())) + } + + func collect(session: Echo_EchoCollectSession) throws { + session.cancel() + try! session.sendAndClose(response: Echo_EchoResponse(), status: .ok, completion: nil) + } + + func update(session: Echo_EchoUpdateSession) throws { + session.cancel() + XCTAssertThrowsError(try session.send(Echo_EchoResponse())) + } +} + +class ServerCancellingTests: BasicEchoTestCase { + static var allTests: [(String, (ServerCancellingTests) -> () throws -> Void)] { + return [ + ("testServerThrowsUnary", testServerThrowsUnary), + ("testServerThrowsClientStreaming", testServerThrowsClientStreaming), + ("testServerThrowsServerStreaming", testServerThrowsServerStreaming), + ("testServerThrowsBidirectionalStreaming", testServerThrowsBidirectionalStreaming) + ] + } + + override func makeProvider() -> Echo_EchoProvider { return CancellingProvider() } +} + +extension ServerCancellingTests { + func testServerThrowsUnary() { + do { + let result = try client.get(Echo_EchoRequest(text: "foo")).text + XCTFail("should have thrown, received \(result) instead") + } catch { + guard case let .callError(callResult) = error as! RPCError + else { XCTFail("unexpected error \(error)"); return } + XCTAssertEqual(.cancelled, callResult.statusCode) + XCTAssertEqual("Cancelled", callResult.statusMessage) + } + } + + func testServerThrowsClientStreaming() { + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.collect { callResult in + XCTAssertEqual(.cancelled, callResult.statusCode) + XCTAssertEqual("Cancelled", callResult.statusMessage) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + // The server only times out later in its lifecycle, so we shouldn't get an error when trying to send a message. + XCTAssertNil($0) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + do { + let result = try call.closeAndReceive() + XCTFail("should have thrown, received \(result) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + func testServerThrowsServerStreaming() { + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.expand(Echo_EchoRequest(text: "foo bar baz")) { callResult in + XCTAssertEqual(.cancelled, callResult.statusCode) + XCTAssertEqual("Cancelled", callResult.statusMessage) + completionHandlerExpectation.fulfill() + } + + // FIXME(danielalm): Why does `call.receive()` essentially return "end of stream", rather than returning an error? + XCTAssertNil(try! call.receive()) + + waitForExpectations(timeout: defaultTimeout) + } + + func testServerThrowsBidirectionalStreaming() { + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.update { callResult in + XCTAssertEqual(.cancelled, callResult.statusCode) + XCTAssertEqual("Cancelled", callResult.statusMessage) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + // The server only times out later in its lifecycle, so we shouldn't get an error when trying to send a message. + XCTAssertNil($0) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + // FIXME(danielalm): Why does `call.receive()` essentially return "end of stream", rather than returning an error? + XCTAssertNil(try! call.receive()) + + waitForExpectations(timeout: defaultTimeout) + } +} diff --git a/Tests/SwiftGRPCTests/ServerTestExample.swift b/Tests/SwiftGRPCTests/ServerTestExample.swift new file mode 100644 index 000000000..de0124398 --- /dev/null +++ b/Tests/SwiftGRPCTests/ServerTestExample.swift @@ -0,0 +1,92 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Dispatch +import Foundation +@testable import SwiftGRPC +import XCTest + +// Sample test suite to demonstrate how one would test a `Provider` implementation. +// These tests don't really test the logic of the SwiftGRPC library, but are meant +// as an example of how one would go about testing their own client/server code that +// relies on SwiftGRPC. +class ServerTestExample: XCTestCase { + static var allTests: [(String, (ServerTestExample) -> () throws -> Void)] { + return [ + ("testUnary", testUnary), + ("testClientStreaming", testClientStreaming), + ("testServerStreaming", testServerStreaming), + ("testBidirectionalStreaming", testBidirectionalStreaming) + ] + } + + var provider: Echo_EchoProvider! + + override func setUp() { + super.setUp() + + provider = EchoProvider() + } + + override func tearDown() { + provider = nil + + super.tearDown() + } +} + +extension ServerTestExample { + func testUnary() { + XCTAssertEqual(Echo_EchoResponse(text: "Swift echo get: "), + try provider.get(request: Echo_EchoRequest(text: ""), session: Echo_EchoGetSessionTestStub())) + XCTAssertEqual(Echo_EchoResponse(text: "Swift echo get: foo"), + try provider.get(request: Echo_EchoRequest(text: "foo"), session: Echo_EchoGetSessionTestStub())) + XCTAssertEqual(Echo_EchoResponse(text: "Swift echo get: foo bar"), + try provider.get(request: Echo_EchoRequest(text: "foo bar"), session: Echo_EchoGetSessionTestStub())) + } + + func testClientStreaming() { + let session = Echo_EchoCollectSessionTestStub() + session.inputs = ["foo", "bar", "baz"].map { Echo_EchoRequest(text: $0) } + + XCTAssertNoThrow(try provider.collect(session: session)) + + XCTAssertEqual(.ok, session.status!.code) + XCTAssertEqual(Echo_EchoResponse(text: "Swift echo collect: foo bar baz"), + session.output) + } + + func testServerStreaming() { + let session = Echo_EchoExpandSessionTestStub() + XCTAssertNoThrow(try provider.expand(request: Echo_EchoRequest(text: "foo bar baz"), session: session)) + + XCTAssertEqual(.ok, session.status!.code) + XCTAssertEqual(["foo", "bar", "baz"].enumerated() + .map { Echo_EchoResponse(text: "Swift echo expand (\($0)): \($1)") }, + session.outputs) + } + + func testBidirectionalStreaming() { + let inputStrings = ["foo", "bar", "baz"] + let session = Echo_EchoUpdateSessionTestStub() + session.inputs = inputStrings.map { Echo_EchoRequest(text: $0) } + XCTAssertNoThrow(try provider.update(session: session)) + + XCTAssertEqual(.ok, session.status!.code) + XCTAssertEqual(inputStrings.enumerated() + .map { Echo_EchoResponse(text: "Swift echo update (\($0)): \($1)") }, + session.outputs) + } +} diff --git a/Tests/SwiftGRPCTests/ServerThrowingTests.swift b/Tests/SwiftGRPCTests/ServerThrowingTests.swift new file mode 100644 index 000000000..dc6e7c172 --- /dev/null +++ b/Tests/SwiftGRPCTests/ServerThrowingTests.swift @@ -0,0 +1,128 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Dispatch +import Foundation +@testable import SwiftGRPC +import XCTest + +fileprivate let testStatus = ServerStatus(code: .permissionDenied, message: "custom status message") + +fileprivate class StatusThrowingProvider: Echo_EchoProvider { + func get(request: Echo_EchoRequest, session _: Echo_EchoGetSession) throws -> Echo_EchoResponse { + throw testStatus + } + + func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws { + throw testStatus + } + + func collect(session: Echo_EchoCollectSession) throws { + throw testStatus + } + + func update(session: Echo_EchoUpdateSession) throws { + throw testStatus + } +} + +class ServerThrowingTests: BasicEchoTestCase { + static var allTests: [(String, (ServerThrowingTests) -> () throws -> Void)] { + return [ + ("testServerThrowsUnary", testServerThrowsUnary), + ("testServerThrowsClientStreaming", testServerThrowsClientStreaming), + ("testServerThrowsServerStreaming", testServerThrowsServerStreaming), + ("testServerThrowsBidirectionalStreaming", testServerThrowsBidirectionalStreaming) + ] + } + + override func makeProvider() -> Echo_EchoProvider { return StatusThrowingProvider() } +} + +extension ServerThrowingTests { + func testServerThrowsUnary() { + do { + let result = try client.get(Echo_EchoRequest(text: "foo")).text + XCTFail("should have thrown, received \(result) instead") + } catch { + guard case let .callError(callResult) = error as! RPCError + else { XCTFail("unexpected error \(error)"); return } + XCTAssertEqual(.permissionDenied, callResult.statusCode) + XCTAssertEqual("custom status message", callResult.statusMessage) + } + } + + func testServerThrowsClientStreaming() { + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.collect { callResult in + XCTAssertEqual(.permissionDenied, callResult.statusCode) + XCTAssertEqual("custom status message", callResult.statusMessage) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + // The server only times out later in its lifecycle, so we shouldn't get an error when trying to send a message. + XCTAssertNil($0) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + do { + let result = try call.closeAndReceive() + XCTFail("should have thrown, received \(result) instead") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + func testServerThrowsServerStreaming() { + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.expand(Echo_EchoRequest(text: "foo bar baz")) { callResult in + XCTAssertEqual(.permissionDenied, callResult.statusCode) + XCTAssertEqual("custom status message", callResult.statusMessage) + completionHandlerExpectation.fulfill() + } + + // FIXME(danielalm): Why does `call.receive()` essentially return "end of stream", rather than returning an error? + XCTAssertNil(try! call.receive()) + + waitForExpectations(timeout: defaultTimeout) + } + + func testServerThrowsBidirectionalStreaming() { + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.update { callResult in + XCTAssertEqual(.permissionDenied, callResult.statusCode) + XCTAssertEqual("custom status message", callResult.statusMessage) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + // The server only times out later in its lifecycle, so we shouldn't get an error when trying to send a message. + XCTAssertNil($0) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + // FIXME(danielalm): Why does `call.receive()` essentially return "end of stream", rather than returning an error? + XCTAssertNil(try! call.receive()) + + waitForExpectations(timeout: defaultTimeout) + } +} diff --git a/Tests/SwiftGRPCTests/ServerTimeoutTests.swift b/Tests/SwiftGRPCTests/ServerTimeoutTests.swift new file mode 100644 index 000000000..5bf20670b --- /dev/null +++ b/Tests/SwiftGRPCTests/ServerTimeoutTests.swift @@ -0,0 +1,128 @@ +/* + * Copyright 2018, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Dispatch +import Foundation +@testable import SwiftGRPC +import XCTest + +fileprivate class TimingOutEchoProvider: Echo_EchoProvider { + func get(request: Echo_EchoRequest, session _: Echo_EchoGetSession) throws -> Echo_EchoResponse { + Thread.sleep(forTimeInterval: 0.2) + return Echo_EchoResponse() + } + + func expand(request: Echo_EchoRequest, session: Echo_EchoExpandSession) throws { + Thread.sleep(forTimeInterval: 0.2) + } + + func collect(session: Echo_EchoCollectSession) throws { + Thread.sleep(forTimeInterval: 0.2) + } + + func update(session: Echo_EchoUpdateSession) throws { + Thread.sleep(forTimeInterval: 0.2) + } +} + +class ServerTimeoutTests: BasicEchoTestCase { + static var allTests: [(String, (ServerTimeoutTests) -> () throws -> Void)] { + return [ + ("testTimeoutUnary", testTimeoutUnary), + ("testTimeoutClientStreaming", testTimeoutClientStreaming), + ("testTimeoutServerStreaming", testTimeoutServerStreaming), + ("testTimeoutBidirectionalStreaming", testTimeoutBidirectionalStreaming) + ] + } + + override func makeProvider() -> Echo_EchoProvider { return TimingOutEchoProvider() } + + override var defaultTimeout: TimeInterval { return 0.1 } +} + +extension ServerTimeoutTests { + func testTimeoutUnary() { + do { + let result = try client.get(Echo_EchoRequest(text: "foo")).text + XCTFail("should have thrown, received \(result) instead") + } catch { + guard case let .callError(callResult) = error as! RPCError + else { XCTFail("unexpected error \(error)"); return } + XCTAssertEqual(.deadlineExceeded, callResult.statusCode) + XCTAssertEqual("Deadline Exceeded", callResult.statusMessage) + } + } + + func testTimeoutClientStreaming() { + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.collect { callResult in + XCTAssertEqual(.deadlineExceeded, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + // The server only times out later in its lifecycle, so we shouldn't get an error when trying to send a message. + XCTAssertNil($0) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + do { + let result = try call.closeAndReceive() + XCTFail("should have thrown, instead received \(result)") + } catch let receiveError { + XCTAssertEqual(.unknown, (receiveError as! RPCError).callResult!.statusCode) + } + + waitForExpectations(timeout: defaultTimeout) + } + + func testTimeoutServerStreaming() { + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.expand(Echo_EchoRequest(text: "foo bar baz")) { callResult in + XCTAssertEqual(.deadlineExceeded, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + // FIXME(danielalm): Why does `call.receive()` essentially return "end of stream" once the call times out, + // rather than returning an error? + XCTAssertNil(try! call.receive()) + + waitForExpectations(timeout: defaultTimeout) + } + + func testTimeoutBidirectionalStreaming() { + let completionHandlerExpectation = expectation(description: "completion handler called") + let call = try! client.update { callResult in + XCTAssertEqual(.deadlineExceeded, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + let sendExpectation = expectation(description: "send completion handler 1 called") + try! call.send(Echo_EchoRequest(text: "foo")) { [sendExpectation] in + // The server only times out later in its lifecycle, so we shouldn't get an error when trying to send a message. + XCTAssertNil($0) + sendExpectation.fulfill() + } + call.waitForSendOperationsToFinish() + + // FIXME(danielalm): Why does `call.receive()` essentially return "end of stream" once the call times out, + // rather than returning an error? + XCTAssertNil(try! call.receive()) + + waitForExpectations(timeout: defaultTimeout) + } +} diff --git a/fix-indentation-settings.rb b/fix-indentation-settings.rb new file mode 100644 index 000000000..43c6f355f --- /dev/null +++ b/fix-indentation-settings.rb @@ -0,0 +1,7 @@ +require 'xcodeproj' +project_path = './SwiftGRPC.xcodeproj' +project = Xcodeproj::Project.open(project_path) +project.main_group.uses_tabs = '0' +project.main_group.tab_width = '2' +project.main_group.indent_width = '2' +project.save