Skip to content

Commit ea595b5

Browse files
authored
Code generation for async-await (#1259)
Motivation: Manually constructing clients and servers is an error prone nightmare. We should generate them instead! Modifications: - Add async-await code-generation for server and client. - The client code generation is missing "simple-safe" wrappers for now, this can be added later. - Naming represents the current state of the branch rather than anything final - Add options for "ExperimentalAsyncClient" and "ExperimentalAsyncServer" -- these may be used in conjunction with the 'regular' "Client" and "Server" options. Result: We can generate async-await style grpc clients and servers.
1 parent af81568 commit ea595b5

File tree

8 files changed

+552
-11
lines changed

8 files changed

+552
-11
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Copyright 2021, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
import SwiftProtobuf
17+
import SwiftProtobufPluginLibrary
18+
19+
// MARK: - Client protocol
20+
21+
extension Generator {
22+
internal func printAsyncServiceClientProtocol() {
23+
let comments = self.service.protoSourceComments()
24+
if !comments.isEmpty {
25+
// Source comments already have the leading '///'
26+
self.println(comments, newline: false)
27+
}
28+
29+
self.printAvailabilityForAsyncAwait()
30+
self.println("\(self.access) protocol \(self.asyncClientProtocolName): GRPCClient {")
31+
self.withIndentation {
32+
self.println("var serviceName: String { get }")
33+
self.println("var interceptors: \(self.clientInterceptorProtocolName)? { get }")
34+
35+
for method in service.methods {
36+
self.println()
37+
self.method = method
38+
39+
let rpcType = streamingType(self.method)
40+
let callType = Types.call(for: rpcType)
41+
42+
let arguments: [String]
43+
switch rpcType {
44+
case .unary, .serverStreaming:
45+
arguments = [
46+
"_ request: \(self.methodInputName)",
47+
"callOptions: \(Types.clientCallOptions)?",
48+
]
49+
50+
case .clientStreaming, .bidirectionalStreaming:
51+
arguments = [
52+
"callOptions: \(Types.clientCallOptions)?",
53+
]
54+
}
55+
56+
self.printFunction(
57+
name: "make\(self.method.name)Call",
58+
arguments: arguments,
59+
returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
60+
bodyBuilder: nil
61+
)
62+
}
63+
}
64+
self.println("}") // protocol
65+
}
66+
}
67+
68+
// MARK: - Client protocol default implementation: Calls
69+
70+
extension Generator {
71+
internal func printAsyncClientProtocolExtension() {
72+
self.printAvailabilityForAsyncAwait()
73+
self.withIndentation("extension \(self.asyncClientProtocolName)", braces: .curly) {
74+
// Service name. TODO: use static metadata.
75+
self.withIndentation("\(self.access) var serviceName: String", braces: .curly) {
76+
self.println("return \"\(self.servicePath)\"")
77+
}
78+
self.println()
79+
80+
// Interceptor factory.
81+
self.withIndentation(
82+
"\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?",
83+
braces: .curly
84+
) {
85+
self.println("return nil")
86+
}
87+
88+
// 'Unsafe' calls.
89+
for method in self.service.methods {
90+
self.println()
91+
self.method = method
92+
93+
let rpcType = streamingType(self.method)
94+
let callType = Types.call(for: rpcType)
95+
let callTypeWithoutPrefix = Types.call(for: rpcType, withGRPCPrefix: false)
96+
97+
switch rpcType {
98+
case .unary, .serverStreaming:
99+
self.printFunction(
100+
name: "make\(self.method.name)Call",
101+
arguments: [
102+
"_ request: \(self.methodInputName)",
103+
"callOptions: \(Types.clientCallOptions)? = nil",
104+
],
105+
returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
106+
access: self.access
107+
) {
108+
self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) {
109+
self.println("path: \(self.methodPath),")
110+
self.println("request: request,")
111+
self.println("callOptions: callOptions ?? self.defaultCallOptions")
112+
}
113+
}
114+
115+
case .clientStreaming, .bidirectionalStreaming:
116+
self.printFunction(
117+
name: "make\(self.method.name)Call",
118+
arguments: ["callOptions: \(Types.clientCallOptions)? = nil"],
119+
returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>",
120+
access: self.access
121+
) {
122+
self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) {
123+
self.println("path: \(self.methodPath),")
124+
self.println("callOptions: callOptions ?? self.defaultCallOptions")
125+
}
126+
}
127+
}
128+
}
129+
}
130+
}
131+
}
132+
133+
// MARK: - Client protocol implementation
134+
135+
extension Generator {
136+
internal func printAsyncServiceClientImplementation() {
137+
self.printAvailabilityForAsyncAwait()
138+
self.withIndentation(
139+
"\(self.access) struct \(self.asyncClientClassName): \(self.asyncClientProtocolName)",
140+
braces: .curly
141+
) {
142+
self.println("\(self.access) var channel: GRPCChannel")
143+
self.println("\(self.access) var defaultCallOptions: CallOptions")
144+
self.println("\(self.access) var interceptors: \(self.clientInterceptorProtocolName)?")
145+
self.println()
146+
147+
self.println("\(self.access) init(")
148+
self.withIndentation {
149+
self.println("channel: GRPCChannel,")
150+
self.println("defaultCallOptions: CallOptions = CallOptions(),")
151+
self.println("interceptors: \(self.clientInterceptorProtocolName)? = nil")
152+
}
153+
self.println(") {")
154+
self.withIndentation {
155+
self.println("self.channel = channel")
156+
self.println("self.defaultCallOptions = defaultCallOptions")
157+
self.println("self.interceptors = interceptors")
158+
}
159+
self.println("}")
160+
}
161+
}
162+
}

Sources/protoc-gen-grpc-swift/Generator-Client.swift

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@ extension Generator {
3030
self.printServiceClientImplementation()
3131
}
3232

33+
if self.options.generateAsyncClient {
34+
self.println()
35+
self.printIfCompilerGuardForAsyncAwait()
36+
self.printAsyncServiceClientProtocol()
37+
self.println()
38+
self.printAsyncClientProtocolExtension()
39+
self.println()
40+
self.printAsyncServiceClientImplementation()
41+
self.println()
42+
self.printEndCompilerGuardForAsyncAwait()
43+
}
44+
3345
if self.options.generateTestClient {
3446
self.println()
3547
self.printTestClient()
@@ -41,19 +53,39 @@ extension Generator {
4153
arguments: [String],
4254
returnType: String?,
4355
access: String? = nil,
56+
sendable: Bool = false,
57+
async: Bool = false,
58+
throws: Bool = false,
59+
genericWhereClause: String? = nil,
4460
bodyBuilder: (() -> Void)?
4561
) {
4662
// Add a space after access, if it exists.
47-
let accessOrEmpty = access.map { $0 + " " } ?? ""
48-
let `return` = returnType.map { "-> " + $0 } ?? ""
63+
let functionHead = (access.map { $0 + " " } ?? "") + (sendable ? "@Sendable " : "")
64+
let `return` = returnType.map { " -> " + $0 } ?? ""
65+
let genericWhere = genericWhereClause.map { " " + $0 } ?? ""
66+
67+
let asyncThrows: String
68+
switch (async, `throws`) {
69+
case (true, true):
70+
asyncThrows = " async throws"
71+
case (true, false):
72+
asyncThrows = " async"
73+
case (false, true):
74+
asyncThrows = " throws"
75+
case (false, false):
76+
asyncThrows = ""
77+
}
4978

5079
let hasBody = bodyBuilder != nil
5180

5281
if arguments.isEmpty {
5382
// Don't bother splitting across multiple lines if there are no arguments.
54-
self.println("\(accessOrEmpty)func \(name)() \(`return`)", newline: !hasBody)
83+
self.println(
84+
"\(functionHead)func \(name)()\(asyncThrows)\(`return`)\(genericWhere)",
85+
newline: !hasBody
86+
)
5587
} else {
56-
self.println("\(accessOrEmpty)func \(name)(")
88+
self.println("\(functionHead)func \(name)(")
5789
self.withIndentation {
5890
// Add a comma after each argument except the last.
5991
arguments.forEach(beforeLast: {
@@ -62,7 +94,7 @@ extension Generator {
6294
self.println($0)
6395
})
6496
}
65-
self.println(") \(`return`)", newline: !hasBody)
97+
self.println(")\(asyncThrows)\(`return`)\(genericWhere)", newline: !hasBody)
6698
}
6799

68100
if let bodyBuilder = bodyBuilder {

Sources/protoc-gen-grpc-swift/Generator-Names.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,18 @@ extension Generator {
7474
return nameForPackageService(file, service) + "Provider"
7575
}
7676

77+
internal var asyncProviderName: String {
78+
return nameForPackageService(file, service) + "AsyncProvider"
79+
}
80+
7781
internal var clientClassName: String {
7882
return nameForPackageService(file, service) + "Client"
7983
}
8084

85+
internal var asyncClientClassName: String {
86+
return nameForPackageService(file, service) + "AsyncClient"
87+
}
88+
8189
internal var testClientClassName: String {
8290
return nameForPackageService(self.file, self.service) + "TestClient"
8391
}
@@ -86,6 +94,10 @@ extension Generator {
8694
return nameForPackageService(file, service) + "ClientProtocol"
8795
}
8896

97+
internal var asyncClientProtocolName: String {
98+
return nameForPackageService(file, service) + "AsyncClientProtocol"
99+
}
100+
89101
internal var clientInterceptorProtocolName: String {
90102
return nameForPackageService(file, service) + "ClientInterceptorFactoryProtocol"
91103
}

0 commit comments

Comments
 (0)