From 12c088713a180476c5abdcb006d42ef32e1c85f3 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Fri, 1 Sep 2023 09:03:29 -0700 Subject: [PATCH] Fold operators using the standard operator table in `MacroSystem` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The compiler folds operators in attributes and freestanding macro nodes but `MacroSystem` wasn’t doing that. But it should to match the compiler behavior. rdar://114786803 Fixes #2128 --- .../PluginMacroExpansionContext.swift | 2 +- .../BasicMacroExpansionContext.swift | 15 ++++ .../MacroSystem.swift | 64 +++++++++++--- .../ExpressionMacroTests.swift | 58 +++++++++++++ .../MemberMacroTests.swift | 83 +++++++++++++++++++ 5 files changed, 211 insertions(+), 11 deletions(-) diff --git a/Sources/SwiftCompilerPluginMessageHandling/PluginMacroExpansionContext.swift b/Sources/SwiftCompilerPluginMessageHandling/PluginMacroExpansionContext.swift index 63629880d5a..b574c104a70 100644 --- a/Sources/SwiftCompilerPluginMessageHandling/PluginMacroExpansionContext.swift +++ b/Sources/SwiftCompilerPluginMessageHandling/PluginMacroExpansionContext.swift @@ -85,7 +85,7 @@ class SourceManager { case .attribute: node = Syntax(AttributeSyntax.parse(from: &parser)) } - if let operatorTable = operatorTable { + if let operatorTable { node = operatorTable.foldAll(node, errorHandler: { _ in /*ignore*/ }) } diff --git a/Sources/SwiftSyntaxMacroExpansion/BasicMacroExpansionContext.swift b/Sources/SwiftSyntaxMacroExpansion/BasicMacroExpansionContext.swift index 2618aa79b4d..dd4e8f59c78 100644 --- a/Sources/SwiftSyntaxMacroExpansion/BasicMacroExpansionContext.swift +++ b/Sources/SwiftSyntaxMacroExpansion/BasicMacroExpansionContext.swift @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// import SwiftDiagnostics +import SwiftOperators import SwiftSyntax import SwiftSyntaxMacros @@ -77,6 +78,20 @@ extension BasicMacroExpansionContext { detachedNodes[Syntax(detached)] = Syntax(node) return detached } + + /// Fold all operators in `node` and associated the ``KnownSourceFile`` + /// information of `node` with the original new, folded tree. + func foldAllOperators(of node: some SyntaxProtocol, with operatorTable: OperatorTable) -> Syntax { + let folded = operatorTable.foldAll(node, errorHandler: { _ in /*ignore*/ }) + if let originalSourceFile = node.root.as(SourceFileSyntax.self), + let newSourceFile = folded.root.as(SourceFileSyntax.self) + { + // Folding operators doesn't change the source file and its associated locations + // Record the `KnownSourceFile` information for the folded tree. + sourceFiles[newSourceFile] = sourceFiles[originalSourceFile] + } + return folded + } } extension String { diff --git a/Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift b/Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift index 9659cc8d2a2..6e145d85f9f 100644 --- a/Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift +++ b/Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// import SwiftDiagnostics +import SwiftOperators @_spi(MacroExpansion) import SwiftParser import SwiftSyntax import SwiftSyntaxBuilder @@ -55,7 +56,7 @@ private func expandFreestandingMemberDeclList( let expanded = try expandFreestandingMacro( definition: definition, macroRole: inferFreestandingMacroRole(definition: definition), - node: node.detach(in: context), + node: node.detach(in: context, foldingWith: .standardOperators), in: context, indentationWidth: indentationWidth ) @@ -80,7 +81,7 @@ private func expandFreestandingCodeItemList( let expanded = try expandFreestandingMacro( definition: definition, macroRole: inferFreestandingMacroRole(definition: definition), - node: node.detach(in: context), + node: node.detach(in: context, foldingWith: .standardOperators), in: context, indentationWidth: indentationWidth ) @@ -108,7 +109,7 @@ private func expandFreestandingExpr( let expanded = expandFreestandingMacro( definition: definition, macroRole: .expression, - node: node.detach(in: context), + node: node.detach(in: context, foldingWith: .standardOperators), in: context, indentationWidth: indentationWidth ) @@ -134,7 +135,7 @@ private func expandMemberMacro( let expanded = expandAttachedMacro( definition: definition, macroRole: .member, - attributeNode: attributeNode.detach(in: context), + attributeNode: attributeNode.detach(in: context, foldingWith: .standardOperators), declarationNode: attachedTo.detach(in: context), parentDeclNode: nil, extendedType: nil, @@ -163,7 +164,7 @@ private func expandMemberAttributeMacro( let expanded = expandAttachedMacro( definition: definition, macroRole: .memberAttribute, - attributeNode: attributeNode.detach(in: context), + attributeNode: attributeNode.detach(in: context, foldingWith: .standardOperators), declarationNode: member.detach(in: context), parentDeclNode: declaration.detach(in: context), extendedType: nil, @@ -191,7 +192,7 @@ private func expandPeerMacroMember( let expanded = expandAttachedMacro( definition: definition, macroRole: .peer, - attributeNode: attributeNode.detach(in: context), + attributeNode: attributeNode.detach(in: context, foldingWith: .standardOperators), declarationNode: attachedTo.detach(in: context), parentDeclNode: nil, extendedType: nil, @@ -219,7 +220,7 @@ private func expandPeerMacroCodeItem( let expanded = expandAttachedMacro( definition: definition, macroRole: .peer, - attributeNode: attributeNode.detach(in: context), + attributeNode: attributeNode.detach(in: context, foldingWith: .standardOperators), declarationNode: attachedTo.detach(in: context), parentDeclNode: nil, extendedType: nil, @@ -251,7 +252,7 @@ private func expandAccessorMacroWithoutExistingAccessors( let expanded = expandAttachedMacro( definition: definition, macroRole: .accessor, - attributeNode: attributeNode.detach(in: context), + attributeNode: attributeNode.detach(in: context, foldingWith: .standardOperators), declarationNode: attachedTo.detach(in: context), parentDeclNode: nil, extendedType: nil, @@ -285,7 +286,7 @@ private func expandAccessorMacroWithExistingAccessors( let expanded = expandAttachedMacro( definition: definition, macroRole: .accessor, - attributeNode: attributeNode.detach(in: context), + attributeNode: attributeNode.detach(in: context, foldingWith: .standardOperators), declarationNode: attachedTo.detach(in: context), parentDeclNode: nil, extendedType: nil, @@ -322,7 +323,7 @@ private func expandExtensionMacro( let expanded = expandAttachedMacro( definition: definition, macroRole: .extension, - attributeNode: attributeNode.detach(in: context), + attributeNode: attributeNode.detach(in: context, foldingWith: .standardOperators), declarationNode: attachedTo.detach(in: context), parentDeclNode: nil, extendedType: extendedType.detach(in: context), @@ -1011,4 +1012,47 @@ private extension SyntaxProtocol { return self.detached } + + /// Fold operators in this node using the given operator table, detach the + /// node and inform the macro expansion context, if it needs to know. + func detach( + in context: MacroExpansionContext, + foldingWith operatorTable: OperatorTable? + ) -> Syntax { + let folded: Syntax + if let operatorTable { + if let basicContext = context as? BasicMacroExpansionContext { + folded = basicContext.foldAllOperators(of: self, with: operatorTable) + } else { + folded = operatorTable.foldAll(self, errorHandler: { _ in /*ignore*/ }) + } + } else { + folded = Syntax(self) + } + return folded.detach(in: context) + } +} + +private extension FreestandingMacroExpansionSyntax { + /// Same as `SyntaxProtocol.detach(in:foldingWith:)` but returns a node of type + /// `Self` since we know that operator folding doesn't change the type of any + /// `FreestandingMacroExpansionSyntax`. + func detach( + in context: MacroExpansionContext, + foldingWith operatorTable: OperatorTable? + ) -> Self { + return (detach(in: context, foldingWith: operatorTable) as Syntax).cast(Self.self) + } +} + +private extension AttributeSyntax { + /// Same as `SyntaxProtocol.detach(in:foldingWith:)` but returns a node of type + /// `Self` since we know that operator folding doesn't change the type of any + /// `AttributeSyntax`. + func detach( + in context: MacroExpansionContext, + foldingWith operatorTable: OperatorTable? + ) -> Self { + return (detach(in: context, foldingWith: operatorTable) as Syntax).cast(Self.self) + } } diff --git a/Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift b/Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift index a08399a1db9..8ca2401de95 100644 --- a/Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift +++ b/Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift @@ -234,4 +234,62 @@ final class ExpressionMacroTests: XCTestCase { ) } + func testFoldOperators() { + struct ForceSubtractMacro: ExpressionMacro { + static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + guard let argument = node.argumentList.first?.expression else { + fatalError("Must receive an argument") + } + guard var node = argument.as(InfixOperatorExprSyntax.self) else { + return argument + } + node.operator = ExprSyntax(BinaryOperatorExprSyntax(text: "-")) + return ExprSyntax(node) + } + } + assertMacroExpansion( + "#test(1 + 2)", + expandedSource: "1 - 2", + macros: ["test": ForceSubtractMacro.self] + ) + } + + func testDiagnosticFromFoldedOperators() { + struct MyError: Error {} + + struct DiagnoseFirstArgument: ExpressionMacro { + static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + guard let argument = node.argumentList.first?.expression else { + fatalError("Must receive an argument") + } + context.addDiagnostics(from: MyError(), node: argument) + return argument + } + } + + assertMacroExpansion( + """ + /// Test + func test() { + #test(1 + 2) + } + """, + expandedSource: """ + /// Test + func test() { + 1 + 2 + } + """, + diagnostics: [ + DiagnosticSpec(message: "MyError()", line: 3, column: 9, severity: .error) + ], + macros: ["test": DiagnoseFirstArgument.self] + ) + } } diff --git a/Tests/SwiftSyntaxMacroExpansionTest/MemberMacroTests.swift b/Tests/SwiftSyntaxMacroExpansionTest/MemberMacroTests.swift index 359164ce612..91b3bf8bad6 100644 --- a/Tests/SwiftSyntaxMacroExpansionTest/MemberMacroTests.swift +++ b/Tests/SwiftSyntaxMacroExpansionTest/MemberMacroTests.swift @@ -172,4 +172,87 @@ final class MemberMacroTests: XCTestCase { indentationWidth: indentationWidth ) } + + func testFoldOperators() { + struct ForceSubtractMacro: MemberMacro { + static func expansion( + of node: AttributeSyntax, + providingMembersOf declaration: some DeclGroupSyntax, + in context: some MacroExpansionContext + ) throws -> [DeclSyntax] { + guard case .argumentList(let arguments) = node.arguments, let argument = arguments.first?.expression else { + fatalError("Must receive an argument") + } + guard var node = argument.as(InfixOperatorExprSyntax.self) else { + return [] + } + node.operator = ExprSyntax(BinaryOperatorExprSyntax(text: "- ")) + return [ + DeclSyntax( + """ + var x: Int { \(node.trimmed) } + """ + ) + ] + } + } + assertMacroExpansion( + """ + /// Test + /// And another line + @Test(1 + 2) + struct Foo { + } + """, + expandedSource: """ + /// Test + /// And another line + struct Foo { + + var x: Int { + 1 - 2 + } + } + """, + macros: ["Test": ForceSubtractMacro.self], + indentationWidth: indentationWidth + ) + } + + func testDiagnosticFromFoldedOperators() { + struct MyError: Error {} + + struct DiagnoseFirstArgument: MemberMacro { + static func expansion( + of node: AttributeSyntax, + providingMembersOf declaration: some DeclGroupSyntax, + in context: some MacroExpansionContext + ) throws -> [DeclSyntax] { + guard case .argumentList(let arguments) = node.arguments, let argument = arguments.first?.expression else { + fatalError("Must receive an argument") + } + context.addDiagnostics(from: MyError(), node: argument) + return [] + } + } + + assertMacroExpansion( + """ + /// Test + /// And another line + @Test(1 + 2) + struct Foo {} + """, + expandedSource: """ + /// Test + /// And another line + struct Foo {} + """, + diagnostics: [ + DiagnosticSpec(message: "MyError()", line: 3, column: 7, severity: .error) + ], + macros: ["Test": DiagnoseFirstArgument.self] + ) + } + }