Skip to content

Commit ab46cdd

Browse files
authored
Merge pull request #2247 from gohanlon/unit-test-attribute-remover
Unit testing of `AttributeRemover`
2 parents 693a130 + 80113cd commit ab46cdd

File tree

2 files changed

+139
-175
lines changed

2 files changed

+139
-175
lines changed

Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -384,19 +384,24 @@ struct MacroSystem {
384384
// MARK: - MacroApplication
385385

386386
/// Removes attributes from a syntax tree while maintaining their surrounding trivia.
387-
private class AttributeRemover: SyntaxRewriter {
388-
var attributesToRemove: [AttributeSyntax]
387+
@_spi(Testing)
388+
public class AttributeRemover: SyntaxRewriter {
389+
let predicate: (AttributeSyntax) -> Bool
389390

390391
var triviaToAttachToNextToken: Trivia = Trivia()
391392

392-
init(attributesToRemove: [AttributeSyntax]) {
393-
self.attributesToRemove = attributesToRemove
393+
/// Initializes an attribute remover with a given predicate to determine which attributes to remove.
394+
///
395+
/// - Parameter predicate: A closure that determines whether a given `AttributeSyntax` should be removed.
396+
/// If this closure returns `true` for an attribute, that attribute will be removed.
397+
public init(removingWhere predicate: @escaping (AttributeSyntax) -> Bool) {
398+
self.predicate = predicate
394399
}
395400

396-
override func visit(_ node: AttributeListSyntax) -> AttributeListSyntax {
401+
public override func visit(_ node: AttributeListSyntax) -> AttributeListSyntax {
397402
var filteredAttributes: [AttributeListSyntax.Element] = []
398403
for case .attribute(let attribute) in node {
399-
if attributesToRemove.contains(attribute) {
404+
if self.predicate(attribute) {
400405
var leadingTrivia = attribute.leadingTrivia
401406

402407
// Don't leave behind an empty line when the attribute being removed is on its own line,
@@ -450,7 +455,7 @@ private class AttributeRemover: SyntaxRewriter {
450455
return AttributeListSyntax(filteredAttributes)
451456
}
452457

453-
override func visit(_ token: TokenSyntax) -> TokenSyntax {
458+
public override func visit(_ token: TokenSyntax) -> TokenSyntax {
454459
return prependAndClearAccumulatedTrivia(to: token)
455460
}
456461

@@ -573,7 +578,7 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
573578

574579
let attributesToRemove = self.macroAttributes(attachedTo: visitedNode).map(\.attributeNode)
575580

576-
return AttributeRemover(attributesToRemove: attributesToRemove).rewrite(visitedNode)
581+
return AttributeRemover(removingWhere: { attributesToRemove.contains($0) }).rewrite(visitedNode)
577582
}
578583

579584
return nil

0 commit comments

Comments
 (0)