Skip to content

Commit f6e7f0f

Browse files
authored
Merge pull request #2924 from rintaro/syntax-unsafe-casting
[Syntax] Add `init(_unsafeCasting: Syntax)` to concrete node types
2 parents eae92fe + a25f490 commit f6e7f0f

19 files changed

+6829
-2896
lines changed

CodeGeneration/Sources/generate-swift-syntax/templates/swiftsyntax/SyntaxCollectionsFile.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ let syntaxCollectionsFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
5050
"""
5151
)
5252

53+
DeclSyntax(
54+
"""
55+
@_transparent
56+
init(unsafeCasting node: Syntax) {
57+
self._syntaxNode = node
58+
}
59+
"""
60+
)
61+
5362
DeclSyntax("public static let syntaxKind = SyntaxKind.\(node.memberCallName)")
5463
}
5564
}

CodeGeneration/Sources/generate-swift-syntax/templates/swiftsyntax/SyntaxNodesFile.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ func syntaxNode(nodesStartingWith: [Character]) -> SourceFileSyntax {
5252
"""
5353
)
5454

55+
DeclSyntax(
56+
"""
57+
@_transparent
58+
init(unsafeCasting node: Syntax) {
59+
self._syntaxNode = node
60+
}
61+
"""
62+
)
63+
5564
let initSignature = InitSignature(node)
5665

5766
try! InitializerDeclSyntax(

CodeGeneration/Sources/generate-swift-syntax/templates/swiftsyntax/SyntaxRewriterFile.swift

Lines changed: 57 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
7272
"""
7373
/// Rewrite `node`, keeping its parent unless `detach` is `true`.
7474
public func rewrite(_ node: some SyntaxProtocol, detach: Bool = false) -> Syntax {
75-
var rewritten = Syntax(node)
76-
self.dispatchVisit(&rewritten)
75+
let rewritten = self.visitImpl(Syntax(node))
7776
if detach {
7877
return rewritten
7978
}
@@ -87,11 +86,20 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
8786

8887
DeclSyntax(
8988
"""
90-
/// Visit a ``TokenSyntax``.
91-
/// - Parameter token: the token that is being visited
89+
/// Visit any Syntax node.
90+
/// - Parameter node: the node that is being visited
9291
/// - Returns: the rewritten node
93-
open func visit(_ token: TokenSyntax) -> TokenSyntax {
94-
return token
92+
@available(*, deprecated, renamed: "rewrite(_:detach:)")
93+
public func visit(_ node: Syntax) -> Syntax {
94+
return visitImpl(node)
95+
}
96+
"""
97+
)
98+
99+
DeclSyntax(
100+
"""
101+
public func visit<T: SyntaxChildChoices>(_ node: T) -> T {
102+
visitImpl(Syntax(node)).cast(T.self)
95103
}
96104
"""
97105
)
@@ -133,24 +141,11 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
133141

134142
DeclSyntax(
135143
"""
136-
/// Visit any Syntax node.
137-
/// - Parameter node: the node that is being visited
144+
/// Visit a ``TokenSyntax``.
145+
/// - Parameter token: the token that is being visited
138146
/// - Returns: the rewritten node
139-
@available(*, deprecated, renamed: "rewrite(_:detach:)")
140-
public func visit(_ node: Syntax) -> Syntax {
141-
var rewritten = node
142-
dispatchVisit(&rewritten)
143-
return rewritten
144-
}
145-
"""
146-
)
147-
148-
DeclSyntax(
149-
"""
150-
public func visit<T: SyntaxChildChoices>(_ node: T) -> T {
151-
var rewritten = Syntax(node)
152-
dispatchVisit(&rewritten)
153-
return rewritten.cast(T.self)
147+
open func visit(_ token: TokenSyntax) -> TokenSyntax {
148+
return token
154149
}
155150
"""
156151
)
@@ -164,7 +159,7 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
164159
/// - Returns: the rewritten node
165160
\(node.apiAttributes())\
166161
open func visit(_ node: \(node.kind.syntaxType)) -> \(node.kind.syntaxType) {
167-
return visitChildren(node._syntaxNode).cast(\(node.kind.syntaxType).self)
162+
return \(node.kind.syntaxType)(unsafeCasting: visitChildren(node._syntaxNode))
168163
}
169164
"""
170165
)
@@ -176,7 +171,7 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
176171
/// - Returns: the rewritten node
177172
\(node.apiAttributes())\
178173
open func visit(_ node: \(node.kind.syntaxType)) -> \(node.baseType.syntaxBaseName) {
179-
return \(node.baseType.syntaxBaseName)(visitChildren(node._syntaxNode).cast(\(node.kind.syntaxType).self))
174+
return \(node.baseType.syntaxBaseName)(\(node.kind.syntaxType)(unsafeCasting: visitChildren(node._syntaxNode)))
180175
}
181176
"""
182177
)
@@ -193,32 +188,35 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
193188
/// - Returns: the rewritten node
194189
\(baseNode.apiAttributes())\
195190
public func visit(_ node: \(baseKind.syntaxType)) -> \(baseKind.syntaxType) {
196-
var node: Syntax = Syntax(node)
197-
dispatchVisit(&node)
198-
return node.cast(\(baseKind.syntaxType).self)
191+
visitImpl(Syntax(node)).cast(\(baseKind.syntaxType).self)
199192
}
200193
"""
201194
)
202195
}
203196

197+
// NOTE: '@inline(never)' because perf tests showed the best results.
198+
// It keeps 'dispatchVisit(_:)' function small, and make all 'case' bodies exactly the same pattern.
199+
// Which enables some optimizations.
204200
DeclSyntax(
205201
"""
206-
/// Interpret `node` as a node of type `nodeType`, visit it, calling
207-
/// the `visit` to transform the node.
208-
@inline(__always)
209-
private func visitImpl<NodeType: SyntaxProtocol>(
210-
_ node: inout Syntax,
211-
_ nodeType: NodeType.Type,
212-
_ visit: (NodeType) -> some SyntaxProtocol
213-
) {
214-
let origNode = node
215-
visitPre(origNode)
216-
node = visitAny(origNode) ?? Syntax(visit(origNode.cast(NodeType.self)))
217-
visitPost(origNode)
202+
@inline(never)
203+
private func visitTokenSyntaxImpl(_ node: Syntax) -> Syntax {
204+
Syntax(visit(TokenSyntax(unsafeCasting: node)))
218205
}
219206
"""
220207
)
221208

209+
for node in NON_BASE_SYNTAX_NODES {
210+
DeclSyntax(
211+
"""
212+
@inline(never)
213+
private func visit\(node.kind.syntaxType)Impl(_ node: Syntax) -> Syntax {
214+
Syntax(visit(\(node.kind.syntaxType)(unsafeCasting: node)))
215+
}
216+
"""
217+
)
218+
}
219+
222220
try IfConfigDeclSyntax(
223221
leadingTrivia:
224222
"""
@@ -255,26 +253,26 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
255253
/// that determines the correct visitation function will be popped of the
256254
/// stack before the function is being called, making the switch's stack
257255
/// space transient instead of having it linger in the call stack.
258-
private func visitationFunc(for node: Syntax) -> ((inout Syntax) -> Void)
256+
private func visitationFunc(for node: Syntax) -> (Syntax) -> Syntax
259257
"""
260258
) {
261259
try SwitchExprSyntax("switch node.raw.kind") {
262260
SwitchCaseSyntax("case .token:") {
263-
StmtSyntax("return { self.visitImpl(&$0, TokenSyntax.self, self.visit) }")
261+
StmtSyntax("return self.visitTokenSyntaxImpl(_:)")
264262
}
265263

266264
for node in NON_BASE_SYNTAX_NODES {
267265
SwitchCaseSyntax("case .\(node.enumCaseCallName):") {
268-
StmtSyntax("return { self.visitImpl(&$0, \(node.kind.syntaxType).self, self.visit) }")
266+
StmtSyntax("return self.visit\(node.kind.syntaxType)Impl(_:)")
269267
}
270268
}
271269
}
272270
}
273271

274272
DeclSyntax(
275273
"""
276-
private func dispatchVisit(_ node: inout Syntax) {
277-
visitationFunc(for: node)(&node)
274+
private func dispatchVisit(_ node: Syntax) -> Syntax {
275+
visitationFunc(for: node)(node)
278276
}
279277
"""
280278
)
@@ -285,15 +283,15 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
285283
poundKeyword: .poundElseToken(),
286284
elements: .statements(
287285
CodeBlockItemListSyntax {
288-
try! FunctionDeclSyntax("private func dispatchVisit(_ node: inout Syntax)") {
286+
try! FunctionDeclSyntax("private func dispatchVisit(_ node: Syntax) -> Syntax") {
289287
try SwitchExprSyntax("switch node.raw.kind") {
290288
SwitchCaseSyntax("case .token:") {
291-
StmtSyntax("return visitImpl(&node, TokenSyntax.self, visit)")
289+
StmtSyntax("return visitTokenSyntaxImpl(node)")
292290
}
293291

294292
for node in NON_BASE_SYNTAX_NODES {
295293
SwitchCaseSyntax("case .\(node.enumCaseCallName):") {
296-
StmtSyntax("return visitImpl(&node, \(node.kind.syntaxType).self, visit)")
294+
StmtSyntax("return visit\(node.kind.syntaxType)Impl(node)")
297295
}
298296
}
299297
}
@@ -304,6 +302,16 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
304302
}
305303
)
306304

305+
DeclSyntax(
306+
"""
307+
private func visitImpl(_ node: Syntax) -> Syntax {
308+
visitPre(node)
309+
defer { visitPost(node) }
310+
return visitAny(node) ?? dispatchVisit(node)
311+
}
312+
"""
313+
)
314+
307315
DeclSyntax(
308316
"""
309317
private func visitChildren(_ node: Syntax) -> Syntax {
@@ -325,9 +333,7 @@ let syntaxRewriterFile = SourceFileSyntax(leadingTrivia: copyrightHeader) {
325333
for case let (child?, info) in RawSyntaxChildren(node) where viewMode.shouldTraverse(node: child) {
326334
327335
// Build the Syntax node to rewrite
328-
var childNode = nodeFactory.create(parent: node, raw: child, absoluteInfo: info)
329-
330-
dispatchVisit(&childNode)
336+
var childNode = visitImpl(nodeFactory.create(parent: node, raw: child, absoluteInfo: info))
331337
if childNode.raw.id != child.id {
332338
// The node was rewritten, let's handle it
333339

0 commit comments

Comments
 (0)