Skip to content

Commit 2f5dd64

Browse files
authored
Merge pull request #2447 from ApolloZhu/macro/expression-as-default-argument
[Macros] Checking expression macro as caller-side default argument
2 parents 4f382a4 + f78e477 commit 2f5dd64

File tree

4 files changed

+228
-58
lines changed

4 files changed

+228
-58
lines changed

Sources/SwiftSyntaxMacroExpansion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_swift_syntax_library(SwiftSyntaxMacroExpansion
22
BasicMacroExpansionContext.swift
33
FunctionParameterUtils.swift
44
IndentationUtils.swift
5+
MacroArgument.swift
56
MacroExpansion.swift
67
MacroExpansionDiagnosticMessages.swift
78
MacroReplacement.swift
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import SwiftDiagnostics
14+
import SwiftSyntax
15+
16+
enum DeclReferenceError: DiagnosticMessage {
17+
case nonLiteral
18+
19+
var message: String {
20+
switch self {
21+
case .nonLiteral:
22+
return "only literals are permitted"
23+
}
24+
}
25+
26+
var diagnosticID: MessageID {
27+
.init(domain: "SwiftMacros", id: "\(self)")
28+
}
29+
30+
var severity: DiagnosticSeverity {
31+
.error
32+
}
33+
}
34+
35+
/// Check sub-expressions to ensure all expressions are literals, and call
36+
/// `diagnoseNonLiteral` for all other expressions.
37+
class OnlyLiteralExprChecker: SyntaxAnyVisitor {
38+
var diagnostics: [Diagnostic] = []
39+
40+
init() {
41+
super.init(viewMode: .fixedUp)
42+
}
43+
44+
// Integer literals
45+
override func visit(_ node: IntegerLiteralExprSyntax) -> SyntaxVisitorContinueKind {
46+
.visitChildren
47+
}
48+
49+
// Floating point literals
50+
override func visit(_ node: FloatLiteralExprSyntax) -> SyntaxVisitorContinueKind {
51+
.visitChildren
52+
}
53+
54+
// Negative numbers
55+
override func visit(_ node: PrefixOperatorExprSyntax) -> SyntaxVisitorContinueKind {
56+
switch node.operator.tokenKind {
57+
case .prefixOperator("-")
58+
// only allow negation on numbers, not other literal types
59+
where node.expression.is(IntegerLiteralExprSyntax.self)
60+
|| node.expression.is(FloatLiteralExprSyntax.self):
61+
return .visitChildren
62+
default:
63+
return diagnoseNonLiteral(node)
64+
}
65+
}
66+
67+
// Bool literals
68+
override func visit(_ node: BooleanLiteralExprSyntax) -> SyntaxVisitorContinueKind {
69+
.visitChildren
70+
}
71+
72+
// nil literals
73+
override func visit(_ node: NilLiteralExprSyntax) -> SyntaxVisitorContinueKind {
74+
.visitChildren
75+
}
76+
77+
// String literals
78+
override func visit(_ node: StringLiteralExprSyntax) -> SyntaxVisitorContinueKind {
79+
.visitChildren
80+
}
81+
82+
// String interpolation
83+
override func visit(_ node: StringLiteralSegmentListSyntax) -> SyntaxVisitorContinueKind {
84+
guard node.count == 1,
85+
case .stringSegment = node.first!
86+
else {
87+
return diagnoseNonLiteral(node)
88+
}
89+
return .visitChildren
90+
}
91+
92+
// Array literals
93+
override func visit(_ node: ArrayExprSyntax) -> SyntaxVisitorContinueKind {
94+
.visitChildren
95+
}
96+
97+
// Dictionary literals
98+
override func visit(_ node: DictionaryExprSyntax) -> SyntaxVisitorContinueKind {
99+
.visitChildren
100+
}
101+
102+
// Tuple literals
103+
override func visit(_ node: TupleExprSyntax) -> SyntaxVisitorContinueKind {
104+
.visitChildren
105+
}
106+
107+
// Regex literals
108+
override func visit(_ node: RegexLiteralExprSyntax) -> SyntaxVisitorContinueKind {
109+
.visitChildren
110+
}
111+
112+
// Macro uses.
113+
override func visit(_ node: MacroExpansionExprSyntax) -> SyntaxVisitorContinueKind {
114+
.visitChildren
115+
}
116+
117+
// References to declarations.
118+
override func visit(_ node: DeclReferenceExprSyntax) -> SyntaxVisitorContinueKind {
119+
return diagnoseNonLiteral(node)
120+
}
121+
122+
override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
123+
if node.is(ExprSyntax.self) {
124+
// We have an expression that is not one of the allowed forms, so
125+
// diagnose it.
126+
return diagnoseNonLiteral(node)
127+
}
128+
129+
return .visitChildren
130+
}
131+
132+
func diagnoseNonLiteral(_ node: some SyntaxProtocol) -> SyntaxVisitorContinueKind {
133+
diagnostics.append(
134+
Diagnostic(
135+
node: node,
136+
message: DeclReferenceError.nonLiteral
137+
)
138+
)
139+
140+
return .skipChildren
141+
}
142+
}
143+
144+
extension MacroExpansionExprSyntax {
145+
/// For compiler to check a macro expression used as default argument.
146+
///
147+
/// Only literals are permitted as arguments to these expressions.
148+
///
149+
/// If there are diagnostics, they will be wrapped into an error and thrown.
150+
@_spi(Compiler)
151+
public func checkDefaultArgumentMacroExpression() throws {
152+
let visitor = OnlyLiteralExprChecker()
153+
visitor.walk(arguments)
154+
155+
if !visitor.diagnostics.isEmpty {
156+
throw DiagnosticsError(diagnostics: visitor.diagnostics)
157+
}
158+
}
159+
}

Sources/SwiftSyntaxMacroExpansion/MacroReplacement.swift

Lines changed: 10 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ enum MacroExpanderError: DiagnosticMessage {
1919
case definitionNotMacroExpansion
2020
case nonParameterReference(TokenSyntax)
2121
case nonTypeReference(TokenSyntax)
22-
case nonLiteralOrParameter(ExprSyntax)
22+
case nonLiteralOrParameter
2323

2424
var message: String {
2525
switch self {
@@ -93,55 +93,14 @@ extension MacroDefinition {
9393
}
9494
}
9595

96-
fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor {
96+
fileprivate class ParameterReplacementVisitor: OnlyLiteralExprChecker {
9797
let macro: MacroDeclSyntax
9898
var replacements: [MacroDefinition.Replacement] = []
9999
var genericReplacements: [MacroDefinition.GenericArgumentReplacement] = []
100-
var diagnostics: [Diagnostic] = []
101100

102101
init(macro: MacroDeclSyntax) {
103102
self.macro = macro
104-
super.init(viewMode: .fixedUp)
105-
}
106-
107-
// Integer literals
108-
override func visit(_ node: IntegerLiteralExprSyntax) -> SyntaxVisitorContinueKind {
109-
.visitChildren
110-
}
111-
112-
// Floating point literals
113-
override func visit(_ node: FloatLiteralExprSyntax) -> SyntaxVisitorContinueKind {
114-
.visitChildren
115-
}
116-
117-
// nil literals
118-
override func visit(_ node: NilLiteralExprSyntax) -> SyntaxVisitorContinueKind {
119-
.visitChildren
120-
}
121-
122-
// String literals
123-
override func visit(_ node: StringLiteralExprSyntax) -> SyntaxVisitorContinueKind {
124-
.visitChildren
125-
}
126-
127-
// Array literals
128-
override func visit(_ node: ArrayExprSyntax) -> SyntaxVisitorContinueKind {
129-
.visitChildren
130-
}
131-
132-
// Dictionary literals
133-
override func visit(_ node: DictionaryExprSyntax) -> SyntaxVisitorContinueKind {
134-
.visitChildren
135-
}
136-
137-
// Tuple literals
138-
override func visit(_ node: TupleExprSyntax) -> SyntaxVisitorContinueKind {
139-
.visitChildren
140-
}
141-
142-
// Macro uses.
143-
override func visit(_ node: MacroExpansionExprSyntax) -> SyntaxVisitorContinueKind {
144-
.visitChildren
103+
super.init()
145104
}
146105

147106
// References to declarations. Only accept those that refer to a parameter
@@ -216,23 +175,16 @@ fileprivate class ParameterReplacementVisitor: SyntaxAnyVisitor {
216175
return .visitChildren
217176
}
218177

219-
override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
220-
if let expr = node.as(ExprSyntax.self) {
221-
// We have an expression that is not one of the allowed forms, so
222-
// diagnose it.
223-
diagnostics.append(
224-
Diagnostic(
225-
node: node,
226-
message: MacroExpanderError.nonLiteralOrParameter(expr)
227-
)
178+
override func diagnoseNonLiteral(_ node: some SyntaxProtocol) -> SyntaxVisitorContinueKind {
179+
diagnostics.append(
180+
Diagnostic(
181+
node: node,
182+
message: MacroExpanderError.nonLiteralOrParameter
228183
)
184+
)
229185

230-
return .skipChildren
231-
}
232-
233-
return .visitChildren
186+
return .skipChildren
234187
}
235-
236188
}
237189

238190
extension MacroDeclSyntax {
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2024 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import SwiftDiagnostics
14+
import SwiftSyntax
15+
import SwiftSyntaxBuilder
16+
@_spi(Compiler) import SwiftSyntaxMacroExpansion
17+
import XCTest
18+
import _SwiftSyntaxTestSupport
19+
20+
final class MacroArgumentTests: XCTestCase {
21+
func testDefaultArgumentMacroExprGood() throws {
22+
let macro: ExprSyntax =
23+
"""
24+
#otherMacro(first: (/foo/, 0x42), second: ["a": nil], third: [3.14159, -2.71828], fourth: true)
25+
"""
26+
27+
XCTAssertNoThrow(
28+
try macro.as(MacroExpansionExprSyntax.self)!
29+
.checkDefaultArgumentMacroExpression()
30+
)
31+
}
32+
33+
func testDefaultArgumentMacroExprNonLiteral() throws {
34+
let macro: ExprSyntax =
35+
#"""
36+
#otherMacro(first: b, second: "\(false)", third: 1 + 2)
37+
"""#
38+
39+
XCTAssertThrowsError(
40+
try macro.as(MacroExpansionExprSyntax.self)!
41+
.checkDefaultArgumentMacroExpression()
42+
) { error in
43+
guard let diagError = error as? DiagnosticsError else {
44+
XCTFail("should have failed with a diagnostics error")
45+
return
46+
}
47+
let diags = diagError.diagnostics
48+
49+
XCTAssertEqual(diags.count, 3)
50+
for diag in diags {
51+
XCTAssertEqual(
52+
diag.diagMessage.message,
53+
"only literals are permitted"
54+
)
55+
}
56+
}
57+
}
58+
}

0 commit comments

Comments
 (0)