diff --git a/Sources/Examples/Echo/Model/echo.grpc.swift b/Sources/Examples/Echo/Model/echo.grpc.swift index 9fda26dca..08218dfab 100644 --- a/Sources/Examples/Echo/Model/echo.grpc.swift +++ b/Sources/Examples/Echo/Model/echo.grpc.swift @@ -248,6 +248,75 @@ extension Echo_EchoAsyncClientProtocol { } } +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension Echo_EchoAsyncClientProtocol { + public func get( + _ request: Echo_EchoRequest, + callOptions: CallOptions? = nil + ) async throws -> Echo_EchoResponse { + return try await self.performAsyncUnaryCall( + path: "/echo.Echo/Get", + request: request, + callOptions: callOptions ?? self.defaultCallOptions + ) + } + + public func expand( + _ request: Echo_EchoRequest, + callOptions: CallOptions? = nil + ) -> GRPCAsyncResponseStream { + return self.performAsyncServerStreamingCall( + path: "/echo.Echo/Expand", + request: request, + callOptions: callOptions ?? self.defaultCallOptions + ) + } + + public func collect( + _ requests: RequestStream, + callOptions: CallOptions? = nil + ) async throws -> Echo_EchoResponse where RequestStream: Sequence, RequestStream.Element == Echo_EchoRequest { + return try await self.performAsyncClientStreamingCall( + path: "/echo.Echo/Collect", + requests: requests, + callOptions: callOptions ?? self.defaultCallOptions + ) + } + + public func collect( + _ requests: RequestStream, + callOptions: CallOptions? = nil + ) async throws -> Echo_EchoResponse where RequestStream: AsyncSequence, RequestStream.Element == Echo_EchoRequest { + return try await self.performAsyncClientStreamingCall( + path: "/echo.Echo/Collect", + requests: requests, + callOptions: callOptions ?? self.defaultCallOptions + ) + } + + public func update( + _ requests: RequestStream, + callOptions: CallOptions? = nil + ) -> GRPCAsyncResponseStream where RequestStream: Sequence, RequestStream.Element == Echo_EchoRequest { + return self.performAsyncBidirectionalStreamingCall( + path: "/echo.Echo/Update", + requests: requests, + callOptions: callOptions ?? self.defaultCallOptions + ) + } + + public func update( + _ requests: RequestStream, + callOptions: CallOptions? = nil + ) -> GRPCAsyncResponseStream where RequestStream: AsyncSequence, RequestStream.Element == Echo_EchoRequest { + return self.performAsyncBidirectionalStreamingCall( + path: "/echo.Echo/Update", + requests: requests, + callOptions: callOptions ?? self.defaultCallOptions + ) + } +} + @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) public struct Echo_EchoAsyncClient: Echo_EchoAsyncClientProtocol { public var channel: GRPCChannel diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCChannel+AsyncAwaitSupport.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCChannel+AsyncAwaitSupport.swift index e7eec9fe4..de4b561a6 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCChannel+AsyncAwaitSupport.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCChannel+AsyncAwaitSupport.swift @@ -26,7 +26,7 @@ extension GRPCChannel { /// - request: The request to send. /// - callOptions: Options for the RPC. /// - interceptors: A list of interceptors to intercept the request and response stream with. - public func makeAsyncUnaryCall( + internal func makeAsyncUnaryCall( path: String, request: Request, callOptions: CallOptions, @@ -50,7 +50,7 @@ extension GRPCChannel { /// - request: The request to send. /// - callOptions: Options for the RPC. /// - interceptors: A list of interceptors to intercept the request and response stream with. - public func makeAsyncUnaryCall( + internal func makeAsyncUnaryCall( path: String, request: Request, callOptions: CallOptions, @@ -73,7 +73,7 @@ extension GRPCChannel { /// - path: Path of the RPC, e.g. "/echo.Echo/Get" /// - callOptions: Options for the RPC. /// - interceptors: A list of interceptors to intercept the request and response stream with. - public func makeAsyncClientStreamingCall( + internal func makeAsyncClientStreamingCall( path: String, callOptions: CallOptions, interceptors: [ClientInterceptor] = [] @@ -94,7 +94,7 @@ extension GRPCChannel { /// - path: Path of the RPC, e.g. "/echo.Echo/Get" /// - callOptions: Options for the RPC. /// - interceptors: A list of interceptors to intercept the request and response stream with. - public func makeAsyncClientStreamingCall( + internal func makeAsyncClientStreamingCall( path: String, callOptions: CallOptions, interceptors: [ClientInterceptor] = [] @@ -116,7 +116,7 @@ extension GRPCChannel { /// - request: The request to send. /// - callOptions: Options for the RPC. /// - interceptors: A list of interceptors to intercept the request and response stream with. - public func makeAsyncServerStreamingCall( + internal func makeAsyncServerStreamingCall( path: String, request: Request, callOptions: CallOptions, @@ -140,7 +140,7 @@ extension GRPCChannel { /// - request: The request to send. /// - callOptions: Options for the RPC. /// - interceptors: A list of interceptors to intercept the request and response stream with. - public func makeAsyncServerStreamingCall( + internal func makeAsyncServerStreamingCall( path: String, request: Request, callOptions: CallOptions, @@ -163,7 +163,7 @@ extension GRPCChannel { /// - path: Path of the RPC, e.g. "/echo.Echo/Get" /// - callOptions: Options for the RPC. /// - interceptors: A list of interceptors to intercept the request and response stream with. - public func makeAsyncBidirectionalStreamingCall( + internal func makeAsyncBidirectionalStreamingCall( path: String, callOptions: CallOptions, interceptors: [ClientInterceptor] = [] @@ -184,7 +184,7 @@ extension GRPCChannel { /// - path: Path of the RPC, e.g. "/echo.Echo/Get" /// - callOptions: Options for the RPC. /// - interceptors: A list of interceptors to intercept the request and response stream with. - public func makeAsyncBidirectionalStreamingCall( + internal func makeAsyncBidirectionalStreamingCall( path: String, callOptions: CallOptions, interceptors: [ClientInterceptor] = [] diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCClient+AsyncAwaitSupport.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCClient+AsyncAwaitSupport.swift index 3aa3e1742..ba05e7aa4 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCClient+AsyncAwaitSupport.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCClient+AsyncAwaitSupport.swift @@ -148,4 +148,318 @@ extension GRPCClient { } } +// MARK: - "Simple, but safe" wrappers. + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension GRPCClient { + public func performAsyncUnaryCall( + path: String, + request: Request, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + responseType: Response.Type = Response.self + ) async throws -> Response { + return try await self.channel.makeAsyncUnaryCall( + path: path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ).response + } + + public func performAsyncUnaryCall( + path: String, + request: Request, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + responseType: Response.Type = Response.self + ) async throws -> Response { + return try await self.channel.makeAsyncUnaryCall( + path: path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ).response + } + + public func performAsyncServerStreamingCall< + Request: SwiftProtobuf.Message, + Response: SwiftProtobuf.Message + >( + path: String, + request: Request, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + responseType: Response.Type = Response.self + ) -> GRPCAsyncResponseStream { + return self.channel.makeAsyncServerStreamingCall( + path: path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ).responses + } + + public func performAsyncServerStreamingCall( + path: String, + request: Request, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + responseType: Response.Type = Response.self + ) -> GRPCAsyncResponseStream { + return self.channel.makeAsyncServerStreamingCall( + path: path, + request: request, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ).responses + } + + public func performAsyncClientStreamingCall< + Request: SwiftProtobuf.Message, + Response: SwiftProtobuf.Message, + RequestStream + >( + path: String, + requests: RequestStream, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + requestType: Request.Type = Request.self, + responseType: Response.Type = Response.self + ) async throws -> Response + where RequestStream: AsyncSequence, RequestStream.Element == Request { + let call = self.channel.makeAsyncClientStreamingCall( + path: path, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ) + return try await self.perform(call, with: requests) + } + + public func performAsyncClientStreamingCall< + Request: GRPCPayload, + Response: GRPCPayload, + RequestStream + >( + path: String, + requests: RequestStream, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + requestType: Request.Type = Request.self, + responseType: Response.Type = Response.self + ) async throws -> Response + where RequestStream: AsyncSequence, RequestStream.Element == Request { + let call = self.channel.makeAsyncClientStreamingCall( + path: path, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ) + return try await self.perform(call, with: requests) + } + + public func performAsyncClientStreamingCall< + Request: SwiftProtobuf.Message, + Response: SwiftProtobuf.Message, + RequestStream + >( + path: String, + requests: RequestStream, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + requestType: Request.Type = Request.self, + responseType: Response.Type = Response.self + ) async throws -> Response + where RequestStream: Sequence, RequestStream.Element == Request { + let call = self.channel.makeAsyncClientStreamingCall( + path: path, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ) + return try await self.perform(call, with: AsyncStream(wrapping: requests)) + } + + public func performAsyncClientStreamingCall< + Request: GRPCPayload, + Response: GRPCPayload, + RequestStream + >( + path: String, + requests: RequestStream, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + requestType: Request.Type = Request.self, + responseType: Response.Type = Response.self + ) async throws -> Response + where RequestStream: Sequence, RequestStream.Element == Request { + let call = self.channel.makeAsyncClientStreamingCall( + path: path, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ) + return try await self.perform(call, with: AsyncStream(wrapping: requests)) + } + + public func performAsyncBidirectionalStreamingCall< + Request: SwiftProtobuf.Message, + Response: SwiftProtobuf.Message, + RequestStream: AsyncSequence + >( + path: String, + requests: RequestStream, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + requestType: Request.Type = Request.self, + responseType: Response.Type = Response.self + ) -> GRPCAsyncResponseStream + where RequestStream.Element == Request { + let call = self.channel.makeAsyncBidirectionalStreamingCall( + path: path, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ) + return self.perform(call, with: requests) + } + + public func performAsyncBidirectionalStreamingCall< + Request: GRPCPayload, + Response: GRPCPayload, + RequestStream: AsyncSequence + >( + path: String, + requests: RequestStream, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + requestType: Request.Type = Request.self, + responseType: Response.Type = Response.self + ) -> GRPCAsyncResponseStream + where RequestStream.Element == Request { + let call = self.channel.makeAsyncBidirectionalStreamingCall( + path: path, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ) + return self.perform(call, with: requests) + } + + public func performAsyncBidirectionalStreamingCall< + Request: SwiftProtobuf.Message, + Response: SwiftProtobuf.Message, + RequestStream: Sequence + >( + path: String, + requests: RequestStream, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + requestType: Request.Type = Request.self, + responseType: Response.Type = Response.self + ) -> GRPCAsyncResponseStream + where RequestStream.Element == Request { + let call = self.channel.makeAsyncBidirectionalStreamingCall( + path: path, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ) + return self.perform(call, with: AsyncStream(wrapping: requests)) + } + + public func performAsyncBidirectionalStreamingCall< + Request: GRPCPayload, + Response: GRPCPayload, + RequestStream: Sequence + >( + path: String, + requests: RequestStream, + callOptions: CallOptions? = nil, + interceptors: [ClientInterceptor] = [], + requestType: Request.Type = Request.self, + responseType: Response.Type = Response.self + ) -> GRPCAsyncResponseStream + where RequestStream.Element == Request { + let call = self.channel.makeAsyncBidirectionalStreamingCall( + path: path, + callOptions: callOptions ?? self.defaultCallOptions, + interceptors: interceptors + ) + return self.perform(call, with: AsyncStream(wrapping: requests)) + } +} + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension GRPCClient { + @inlinable + internal func perform( + _ call: GRPCAsyncClientStreamingCall, + with requests: RequestStream + ) async throws -> Response + where RequestStream: AsyncSequence, RequestStream.Element == Request { + // We use a detached task because we use cancellation to signal early, but successful exit. + let requestsTask = Task.detached { + try Task.checkCancellation() + for try await request in requests { + try Task.checkCancellation() + try await call.requestStream.send(request) + } + try Task.checkCancellation() + try await call.requestStream.finish() + try Task.checkCancellation() + } + return try await withTaskCancellationHandler { + // Await the response, which may come before the request stream is exhausted. + let response = try await call.response + // If we have a response, we can stop sending requests. + requestsTask.cancel() + // Return the response. + return response + } onCancel: { + requestsTask.cancel() + // If this outer task is cancelled then we should also cancel the RPC. + Task.detached { + try await call.cancel() + } + } + } + + @inlinable + internal func perform( + _ call: GRPCAsyncBidirectionalStreamingCall, + with requests: RequestStream + ) + -> GRPCAsyncResponseStream + where RequestStream: AsyncSequence, RequestStream.Element == Request { + Task { + try await withTaskCancellationHandler { + try Task.checkCancellation() + for try await request in requests { + try Task.checkCancellation() + try await call.requestStream.send(request) + } + try Task.checkCancellation() + try await call.requestStream.finish() + } onCancel: { + Task.detached { + try await call.cancel() + } + } + } + return call.responses + } +} + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension AsyncStream { + /// Create an `AsyncStream` from a regular (non-async) `Sequence`. + /// + /// - Note: This is just here to avoid duplicating the above two `perform(_:with:)` functions + /// for `Sequence`. + fileprivate init(wrapping sequence: T) where T: Sequence, T.Element == Element { + self.init { continuation in + var iterator = sequence.makeIterator() + while let value = iterator.next() { + continuation.yield(value) + } + continuation.finish() + } + } +} + #endif diff --git a/Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift b/Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift index 770f6599d..9244f1ca3 100644 --- a/Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift +++ b/Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift @@ -130,6 +130,66 @@ extension Generator { } } +// MARK: - Client protocol extension: "Simple, but safe" call wrappers. + +extension Generator { + internal func printAsyncClientProtocolSafeWrappersExtension() { + self.printAvailabilityForAsyncAwait() + self.withIndentation("extension \(self.asyncClientProtocolName)", braces: .curly) { + for (i, method) in self.service.methods.enumerated() { + self.method = method + + let rpcType = streamingType(self.method) + let callTypeWithoutPrefix = Types.call(for: rpcType, withGRPCPrefix: false) + + let streamsResponses = [.serverStreaming, .bidirectionalStreaming].contains(rpcType) + let streamsRequests = [.clientStreaming, .bidirectionalStreaming].contains(rpcType) + + let sequenceProtocols = streamsRequests ? ["Sequence", "AsyncSequence"] : [nil] + + for (j, sequenceProtocol) in sequenceProtocols.enumerated() { + // Print a new line if this is not the first function in the extension. + if i > 0 || j > 0 { + self.println() + } + let functionName = streamsRequests + ? "\(self.methodFunctionName)" + : self.methodFunctionName + let requestParamName = streamsRequests ? "requests" : "request" + let requestParamType = streamsRequests ? "RequestStream" : self.methodInputName + let returnType = streamsResponses + ? Types.responseStream(of: self.methodOutputName) + : self.methodOutputName + let maybeWhereClause = sequenceProtocol.map { + "where RequestStream: \($0), RequestStream.Element == \(self.methodInputName)" + } + self.printFunction( + name: functionName, + arguments: [ + "_ \(requestParamName): \(requestParamType)", + "callOptions: \(Types.clientCallOptions)? = nil", + ], + returnType: returnType, + access: self.access, + async: !streamsResponses, + throws: !streamsResponses, + genericWhereClause: maybeWhereClause + ) { + self.withIndentation( + "return\(!streamsResponses ? " try await" : "") self.perform\(callTypeWithoutPrefix)", + braces: .round + ) { + self.println("path: \(self.methodPath),") + self.println("\(requestParamName): \(requestParamName),") + self.println("callOptions: callOptions ?? self.defaultCallOptions") + } + } + } + } + } + } +} + // MARK: - Client protocol implementation extension Generator { diff --git a/Sources/protoc-gen-grpc-swift/Generator-Client.swift b/Sources/protoc-gen-grpc-swift/Generator-Client.swift index 4f8a3e362..29d726963 100644 --- a/Sources/protoc-gen-grpc-swift/Generator-Client.swift +++ b/Sources/protoc-gen-grpc-swift/Generator-Client.swift @@ -37,6 +37,8 @@ extension Generator { self.println() self.printAsyncClientProtocolExtension() self.println() + self.printAsyncClientProtocolSafeWrappersExtension() + self.println() self.printAsyncServiceClientImplementation() self.println() self.printEndCompilerGuardForAsyncAwait() diff --git a/Tests/GRPCTests/AsyncAwaitSupport/AsyncIntegrationTests.swift b/Tests/GRPCTests/AsyncAwaitSupport/AsyncIntegrationTests.swift index 7dd7a5be0..bc5f7eb42 100644 --- a/Tests/GRPCTests/AsyncAwaitSupport/AsyncIntegrationTests.swift +++ b/Tests/GRPCTests/AsyncAwaitSupport/AsyncIntegrationTests.swift @@ -73,6 +73,13 @@ final class AsyncIntegrationTests: GRPCTestCase { } } + func testUnaryWrapper() { + XCTAsyncTest { + let response = try await self.echo.get(.with { $0.text = "hello" }) + XCTAssertEqual(response.text, "Swift echo get: hello") + } + } + func testClientStreaming() { XCTAsyncTest { let collect = self.echo.makeCollectCall() @@ -96,6 +103,19 @@ final class AsyncIntegrationTests: GRPCTestCase { } } + func testClientStreamingWrapper() { + XCTAsyncTest { + let requests: [Echo_EchoRequest] = [ + .with { $0.text = "boyle" }, + .with { $0.text = "jeffers" }, + .with { $0.text = "holt" }, + ] + + let response = try await self.echo.collect(requests) + XCTAssertEqual(response.text, "Swift echo collect: boyle jeffers holt") + } + } + func testServerStreaming() { XCTAsyncTest { let expand = self.echo.makeExpandCall(.with { $0.text = "boyle jeffers holt" }) @@ -103,8 +123,8 @@ final class AsyncIntegrationTests: GRPCTestCase { let initialMetadata = try await expand.initialMetadata initialMetadata.assertFirst("200", forName: ":status") - let respones = try await expand.responses.map { $0.text }.collect() - XCTAssertEqual(respones, [ + let responses = try await expand.responses.map { $0.text }.collect() + XCTAssertEqual(responses, [ "Swift echo expand (0): boyle", "Swift echo expand (1): jeffers", "Swift echo expand (2): holt", @@ -118,6 +138,18 @@ final class AsyncIntegrationTests: GRPCTestCase { } } + func testServerStreamingWrapper() { + XCTAsyncTest { + let responseStream = self.echo.expand(.with { $0.text = "boyle jeffers holt" }) + let responses = try await responseStream.map { $0.text }.collect() + XCTAssertEqual(responses, [ + "Swift echo expand (0): boyle", + "Swift echo expand (1): jeffers", + "Swift echo expand (2): holt", + ]) + } + } + func testBidirectionalStreaming() { XCTAsyncTest { let update = self.echo.makeUpdateCall() @@ -145,6 +177,24 @@ final class AsyncIntegrationTests: GRPCTestCase { XCTAssertTrue(status.isOk) } } + + func testBidirectionalStreamingWrapper() { + XCTAsyncTest { + let requests: [Echo_EchoRequest] = [ + .with { $0.text = "boyle" }, + .with { $0.text = "jeffers" }, + .with { $0.text = "holt" }, + ] + + let responseStream = self.echo.update(requests) + let responses = try await responseStream.map { $0.text }.collect() + XCTAssertEqual(responses, [ + "Swift echo update (0): boyle", + "Swift echo update (1): jeffers", + "Swift echo update (2): holt", + ]) + } + } } extension HPACKHeaders {