Skip to content

Commit 8f9e023

Browse files
committed
Check macro expression used as default argument
1 parent f8c3ccc commit 8f9e023

File tree

4 files changed

+204
-58
lines changed

4 files changed

+204
-58
lines changed

Sources/SwiftSyntaxMacroExpansion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_swift_syntax_library(SwiftSyntaxMacroExpansion
22
BasicMacroExpansionContext.swift
33
FunctionParameterUtils.swift
44
IndentationUtils.swift
5+
MacroArgument.swift
56
MacroExpansion.swift
67
MacroExpansionDiagnosticMessages.swift
78
MacroReplacement.swift
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import SwiftDiagnostics
2+
import SwiftSyntax
3+
4+
enum DeclReferenceError: DiagnosticMessage {
5+
case nonLiteral
6+
7+
var message: String {
8+
switch self {
9+
case .nonLiteral:
10+
return "only literals are permitted"
11+
}
12+
}
13+
14+
var diagnosticID: MessageID {
15+
.init(domain: "SwiftMacros", id: "\(self)")
16+
}
17+
18+
var severity: DiagnosticSeverity {
19+
.error
20+
}
21+
}
22+
23+
class CheckDeclReferenceVisitor: SyntaxAnyVisitor {
24+
var diagnostics: [Diagnostic] = []
25+
26+
init() {
27+
super.init(viewMode: .fixedUp)
28+
}
29+
30+
// Integer literals
31+
override func visit(_ node: IntegerLiteralExprSyntax) -> SyntaxVisitorContinueKind {
32+
.visitChildren
33+
}
34+
35+
// Floating point literals
36+
override func visit(_ node: FloatLiteralExprSyntax) -> SyntaxVisitorContinueKind {
37+
.visitChildren
38+
}
39+
40+
// Negative numbers
41+
override func visit(_ node: PrefixOperatorExprSyntax) -> SyntaxVisitorContinueKind {
42+
switch node.operator.tokenKind {
43+
case .prefixOperator("-")
44+
// only allow negation on numbers, not other literal types
45+
where node.expression.is(IntegerLiteralExprSyntax.self)
46+
|| node.expression.is(FloatLiteralExprSyntax.self):
47+
return .visitChildren
48+
default:
49+
return diagnoseNonLiteral(node)
50+
}
51+
}
52+
53+
// Bool literals
54+
override func visit(_ node: BooleanLiteralExprSyntax) -> SyntaxVisitorContinueKind {
55+
.visitChildren
56+
}
57+
58+
// nil literals
59+
override func visit(_ node: NilLiteralExprSyntax) -> SyntaxVisitorContinueKind {
60+
.visitChildren
61+
}
62+
63+
// String literals
64+
override func visit(_ node: StringLiteralExprSyntax) -> SyntaxVisitorContinueKind {
65+
.visitChildren
66+
}
67+
68+
// String interpolation
69+
override func visit(_ node: StringLiteralSegmentListSyntax) -> SyntaxVisitorContinueKind {
70+
guard node.count == 1,
71+
case .stringSegment = node.first! else {
72+
return diagnoseNonLiteral(node)
73+
}
74+
return .visitChildren
75+
}
76+
77+
// Array literals
78+
override func visit(_ node: ArrayExprSyntax) -> SyntaxVisitorContinueKind {
79+
.visitChildren
80+
}
81+
82+
// Dictionary literals
83+
override func visit(_ node: DictionaryExprSyntax) -> SyntaxVisitorContinueKind {
84+
.visitChildren
85+
}
86+
87+
// Tuple literals
88+
override func visit(_ node: TupleExprSyntax) -> SyntaxVisitorContinueKind {
89+
.visitChildren
90+
}
91+
92+
// Regex literals
93+
override func visit(_ node: RegexLiteralExprSyntax) -> SyntaxVisitorContinueKind {
94+
.visitChildren
95+
}
96+
97+
// Macro uses.
98+
override func visit(_ node: MacroExpansionExprSyntax) -> SyntaxVisitorContinueKind {
99+
.visitChildren
100+
}
101+
102+
// References to declarations.
103+
override func visit(_ node: DeclReferenceExprSyntax) -> SyntaxVisitorContinueKind {
104+
return diagnoseNonLiteral(node)
105+
}
106+
107+
override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
108+
if node.is(ExprSyntax.self) {
109+
// We have an expression that is not one of the allowed forms, so
110+
// diagnose it.
111+
return diagnoseNonLiteral(node)
112+
}
113+
114+
return .visitChildren
115+
}
116+
117+
func diagnoseNonLiteral(_ node: some SyntaxProtocol) -> SyntaxVisitorContinueKind {
118+
diagnostics.append(
119+
Diagnostic(
120+
node: node,
121+
message: DeclReferenceError.nonLiteral
122+
)
123+
)
124+
125+
return .skipChildren
126+
}
127+
}
128+
129+
extension MacroExpansionExprSyntax {
130+
public func checkDefaultArgumentMacroExpression() throws {
131+
let visitor = CheckDeclReferenceVisitor()
132+
visitor.walk(arguments)
133+
134+
if !visitor.diagnostics.isEmpty {
135+
throw DiagnosticsError(diagnostics: visitor.diagnostics)
136+
}
137+
}
138+
}

Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift

Lines changed: 10 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ enum MacroExpanderError: DiagnosticMessage {
1919
case definitionNotMacroExpansion
2020
case nonParameterReference(TokenSyntax)
2121
case nonTypeReference(TokenSyntax)
22-
case nonLiteralOrParameter(ExprSyntax)
22+
case nonLiteralOrParameter
2323

2424
var message: String {
2525
switch self {
@@ -93,55 +93,14 @@ extension MacroDefinition {
9393
}
9494
}
9595

96-
fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor {
96+
fileprivate class ParameterReplacementVisitor: CheckDeclReferenceVisitor {
9797
let macro: MacroDeclSyntax
9898
var replacements: [MacroDefinition.Replacement] = []
9999
var genericReplacements: [MacroDefinition.GenericArgumentReplacement] = []
100-
var diagnostics: [Diagnostic] = []
101100

102101
init(macro: MacroDeclSyntax) {
103102
self.macro = macro
104-
super.init(viewMode: .fixedUp)
105-
}
106-
107-
// Integer literals
108-
override func visit(_ node: IntegerLiteralExprSyntax) -> SyntaxVisitorContinueKind {
109-
.visitChildren
110-
}
111-
112-
// Floating point literals
113-
override func visit(_ node: FloatLiteralExprSyntax) -> SyntaxVisitorContinueKind {
114-
.visitChildren
115-
}
116-
117-
// nil literals
118-
override func visit(_ node: NilLiteralExprSyntax) -> SyntaxVisitorContinueKind {
119-
.visitChildren
120-
}
121-
122-
// String literals
123-
override func visit(_ node: StringLiteralExprSyntax) -> SyntaxVisitorContinueKind {
124-
.visitChildren
125-
}
126-
127-
// Array literals
128-
override func visit(_ node: ArrayExprSyntax) -> SyntaxVisitorContinueKind {
129-
.visitChildren
130-
}
131-
132-
// Dictionary literals
133-
override func visit(_ node: DictionaryExprSyntax) -> SyntaxVisitorContinueKind {
134-
.visitChildren
135-
}
136-
137-
// Tuple literals
138-
override func visit(_ node: TupleExprSyntax) -> SyntaxVisitorContinueKind {
139-
.visitChildren
140-
}
141-
142-
// Macro uses.
143-
override func visit(_ node: MacroExpansionExprSyntax) -> SyntaxVisitorContinueKind {
144-
.visitChildren
103+
super.init()
145104
}
146105

147106
// References to declarations. Only accept those that refer to a parameter
@@ -216,23 +175,16 @@ fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor {
216175
return .visitChildren
217176
}
218177

219-
override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
220-
if let expr = node.as(ExprSyntax.self) {
221-
// We have an expression that is not one of the allowed forms, so
222-
// diagnose it.
223-
diagnostics.append(
224-
Diagnostic(
225-
node: node,
226-
message: MacroExpanderError.nonLiteralOrParameter(expr)
227-
)
178+
override func diagnoseNonLiteral(_ node: some SyntaxProtocol) -> SyntaxVisitorContinueKind {
179+
diagnostics.append(
180+
Diagnostic(
181+
node: node,
182+
message: MacroExpanderError.nonLiteralOrParameter
228183
)
184+
)
229185

230-
return .skipChildren
231-
}
232-
233-
return .visitChildren
186+
return .skipChildren
234187
}
235-
236188
}
237189

238190
extension MacroDeclSyntax {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import SwiftDiagnostics
14+
import SwiftSyntax
15+
import SwiftSyntaxBuilder
16+
import SwiftSyntaxMacroExpansion
17+
import XCTest
18+
import _SwiftSyntaxTestSupport
19+
20+
final class MacroArgumentTests: XCTestCase {
21+
func testDefaultArgumentMacroExprGood() throws {
22+
let macro: ExprSyntax =
23+
"""
24+
#otherMacro(first: (/foo/, 0x42), second: ["a": nil], third: [3.14159, -2.71828], fourth: true)
25+
"""
26+
27+
XCTAssertNoThrow(try macro.as(MacroExpansionExprSyntax.self)!
28+
.checkDefaultArgumentMacroExpression())
29+
}
30+
31+
func testDefaultArgumentMacroExprNonLiteral() throws {
32+
let macro: ExprSyntax =
33+
#"""
34+
#otherMacro(first: b, second: "\(false)", third: 1 + 2)
35+
"""#
36+
37+
let diags: [Diagnostic]
38+
do {
39+
try macro.as(MacroExpansionExprSyntax.self)!
40+
.checkDefaultArgumentMacroExpression()
41+
XCTFail("should have failed with an error")
42+
fatalError()
43+
} catch let diagError as DiagnosticsError {
44+
diags = diagError.diagnostics
45+
}
46+
47+
XCTAssertEqual(diags.count, 3)
48+
for diag in diags {
49+
XCTAssertEqual(
50+
diag.diagMessage.message,
51+
"only literals are permitted"
52+
)
53+
}
54+
}
55+
}

0 commit comments

Comments
 (0)