diff --git a/Sources/SwiftOperators/OperatorTable+Folding.swift b/Sources/SwiftOperators/OperatorTable+Folding.swift index 06371c9f50a..c1c4f7f6025 100644 --- a/Sources/SwiftOperators/OperatorTable+Folding.swift +++ b/Sources/SwiftOperators/OperatorTable+Folding.swift @@ -155,6 +155,8 @@ extension OperatorTable { ) ) } + // NOTE: If you add a new try/await/unsafe-like hoisting case here, make + // sure to also update `allMacroLexicalContexts` to handle it. // The form of the binary operation depends on the operator itself, // which will be one of the unresolved infix operators. diff --git a/Sources/SwiftSyntaxMacros/Syntax+LexicalContext.swift b/Sources/SwiftSyntaxMacros/Syntax+LexicalContext.swift index 5f17e391db4..f9132040dff 100644 --- a/Sources/SwiftSyntaxMacros/Syntax+LexicalContext.swift +++ b/Sources/SwiftSyntaxMacros/Syntax+LexicalContext.swift @@ -67,6 +67,24 @@ extension SyntaxProtocol { case let freestandingMacro as FreestandingMacroExpansionSyntax: return Syntax(freestandingMacro.detached) as Syntax + // `try`, `await`, and `unsafe` are preserved: A freestanding expression + // macro may need to know whether those keywords are present so it can + // propagate them to any expressions in its expansion which were passed as + // arguments to the macro. The sub-expression is replaced with a trivial + // placeholder, though. + case var tryExpr as TryExprSyntax: + tryExpr = tryExpr.detached + tryExpr.expression = ExprSyntax(TypeExprSyntax(type: IdentifierTypeSyntax(name: .wildcardToken()))) + return Syntax(tryExpr) + case var awaitExpr as AwaitExprSyntax: + awaitExpr = awaitExpr.detached + awaitExpr.expression = ExprSyntax(TypeExprSyntax(type: IdentifierTypeSyntax(name: .wildcardToken()))) + return Syntax(awaitExpr) + case var unsafeExpr as UnsafeExprSyntax: + unsafeExpr = unsafeExpr.detached + unsafeExpr.expression = ExprSyntax(TypeExprSyntax(type: IdentifierTypeSyntax(name: .wildcardToken()))) + return Syntax(unsafeExpr) + default: return nil } @@ -92,6 +110,43 @@ extension SyntaxProtocol { if let parentContext = parentNode.asMacroLexicalContext() { parentContexts.append(parentContext) } + // Unfolded sequence expressions require special handling - effect marker + // nodes like `try`, `await`, and `unsafe` are treated as lexical contexts + // for all the nodes on their right. Cases where they don't end up + // covering nodes to their right in the folded tree are invalid and will + // be diagnosed by the compiler. This matches the compiler's ASTScope + // handling logic. + if let sequence = parentNode.as(SequenceExprSyntax.self) { + var sequenceExprContexts: [Syntax] = [] + for elt in sequence.elements { + if elt.range.contains(self.position) { + // `sequenceExprContexts` is built from the top-down, but we + // build the rest of the contexts bottom-up. Reverse for + // consistency. + parentContexts += sequenceExprContexts.reversed() + break + } + var elt = elt + while true { + if let tryElt = elt.as(TryExprSyntax.self) { + sequenceExprContexts.append(tryElt.asMacroLexicalContext()!) + elt = tryElt.expression + continue + } + if let awaitElt = elt.as(AwaitExprSyntax.self) { + sequenceExprContexts.append(awaitElt.asMacroLexicalContext()!) + elt = awaitElt.expression + continue + } + if let unsafeElt = elt.as(UnsafeExprSyntax.self) { + sequenceExprContexts.append(unsafeElt.asMacroLexicalContext()!) + elt = unsafeElt.expression + continue + } + break + } + } + } currentNode = parentNode } diff --git a/Tests/SwiftSyntaxMacroExpansionTest/LexicalContextTests.swift b/Tests/SwiftSyntaxMacroExpansionTest/LexicalContextTests.swift index 533a2a4d19f..e0cbe3ef9e2 100644 --- a/Tests/SwiftSyntaxMacroExpansionTest/LexicalContextTests.swift +++ b/Tests/SwiftSyntaxMacroExpansionTest/LexicalContextTests.swift @@ -531,7 +531,7 @@ final class LexicalContextTests: XCTestCase { struct S { let arg: C var contextDescription: String { - #lexicalContextDescription + unsafe try await #lexicalContextDescription } } return S(arg: c) @@ -542,7 +542,10 @@ final class LexicalContextTests: XCTestCase { struct S { let arg: C var contextDescription: String { - """ + unsafe try await """ + await _ + try _ + unsafe _ contextDescription: String struct S {} { c in @@ -551,7 +554,7 @@ final class LexicalContextTests: XCTestCase { struct S { let arg: C var contextDescription: String { - #lexicalContextDescription + unsafe try await #lexicalContextDescription } } return S(arg: c) @@ -565,4 +568,182 @@ final class LexicalContextTests: XCTestCase { macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] ) } + + func testEffectMarkersInSequenceLexicalContext() { + // Valid cases. + assertMacroExpansion( + "unsafe try await #lexicalContextDescription + #lexicalContextDescription", + expandedSource: #""" + unsafe try await """ + await _ + try _ + unsafe _ + """ + """ + await _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + assertMacroExpansion( + "try unsafe await 0 + 1 + foo(#lexicalContextDescription) + 2", + expandedSource: #""" + try unsafe await 0 + 1 + foo(""" + await _ + unsafe _ + try _ + """) + 2 + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + assertMacroExpansion( + "x = try await unsafe 0 + 1 + foo(#lexicalContextDescription) + 2", + expandedSource: #""" + x = try await unsafe 0 + 1 + foo(""" + unsafe _ + await _ + try _ + """) + 2 + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + // `unsafe try await` in the 'then' branch doesn't cover condition or else. + assertMacroExpansion( + "#lexicalContextDescription ? unsafe try await #lexicalContextDescription : #lexicalContextDescription", + expandedSource: #""" + """ + """ ? unsafe try await """ + await _ + try _ + unsafe _ + """ : """ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + // Same for else. + assertMacroExpansion( + "#lexicalContextDescription ? #lexicalContextDescription : unsafe try await #lexicalContextDescription", + expandedSource: #""" + """ + """ ? """ + """ : unsafe try await """ + await _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + // 'unsafe try await' in the condition here covers the entire expression + assertMacroExpansion( + "unsafe try await #lexicalContextDescription ? #lexicalContextDescription : #lexicalContextDescription ~~ #lexicalContextDescription", + expandedSource: #""" + unsafe try await """ + await _ + try _ + unsafe _ + """ ? """ + await _ + try _ + unsafe _ + """ : """ + await _ + try _ + unsafe _ + """ ~~ """ + await _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + assertMacroExpansion( + "x = unsafe try try! await 0 + #lexicalContextDescription", + expandedSource: #""" + x = unsafe try try! await 0 + """ + await _ + try! _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + + // Invalid cases + assertMacroExpansion( + "0 + unsafe try await #lexicalContextDescription", + expandedSource: #""" + 0 + unsafe try await """ + await _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + // The `unsafe try await` may not actually cover `lexicalContextDescription` + // here, but this will be rejected by the compiler. + assertMacroExpansion( + "0 + unsafe try await 1 ^ #lexicalContextDescription", + expandedSource: #""" + 0 + unsafe try await 1 ^ """ + await _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + // Invalid if '^' has a lower precedence than '='. + assertMacroExpansion( + "x = unsafe try await 0 ^ #lexicalContextDescription", + expandedSource: #""" + x = unsafe try await 0 ^ """ + await _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + // Unassignable + assertMacroExpansion( + "#lexicalContextDescription = unsafe try await 0 + 1", + expandedSource: #""" + """ + """ = unsafe try await 0 + 1 + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + assertMacroExpansion( + "unsafe try await #lexicalContextDescription = 0 + #lexicalContextDescription", + expandedSource: #""" + unsafe try await """ + await _ + try _ + unsafe _ + """ = 0 + """ + await _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + assertMacroExpansion( + "unsafe try await foo() ? 0 : 1 = #lexicalContextDescription", + expandedSource: #""" + unsafe try await foo() ? 0 : 1 = """ + await _ + try _ + unsafe _ + """ + """#, + macros: ["lexicalContextDescription": LexicalContextDescriptionMacro.self] + ) + } }