Skip to content

Infer the types of function/closure arguments when captured by an exit test. #1130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 75 additions & 13 deletions Sources/TestingMacros/ConditionMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -437,21 +437,23 @@ extension ExitTestConditionMacro {
var bodyArgumentExpr = arguments[trailingClosureIndex].expression
bodyArgumentExpr = removeParentheses(from: bodyArgumentExpr) ?? bodyArgumentExpr

// Find any captured values and extract them from the trailing closure.
var capturedValues = [CapturedValueInfo]()
if ExitTestExpectMacro.isValueCapturingEnabled {
// The source file imports @_spi(Experimental), so allow value capturing.
if var closureExpr = bodyArgumentExpr.as(ClosureExprSyntax.self),
let captureList = closureExpr.signature?.capture?.items {
closureExpr.signature?.capture = ClosureCaptureClauseSyntax(items: [], trailingTrivia: .space)
capturedValues = captureList.map { CapturedValueInfo($0, in: context) }
bodyArgumentExpr = ExprSyntax(closureExpr)
// Before building the macro expansion, look for any problems and return
// early if found.
guard _diagnoseIssues(with: macro, body: bodyArgumentExpr, in: context) else {
if Self.isThrowing {
return #"{ () async throws -> Testing.ExitTest.Result in Swift.fatalError("Unreachable") }()"#
} else {
return #"{ () async -> Testing.ExitTest.Result in Swift.fatalError("Unreachable") }()"#
}
}

} else if let closureExpr = bodyArgumentExpr.as(ClosureExprSyntax.self),
let captureClause = closureExpr.signature?.capture,
!captureClause.items.isEmpty {
context.diagnose(.captureClauseUnsupported(captureClause, in: closureExpr, inExitTest: macro))
// Find any captured values and extract them from the trailing closure.
var capturedValues = [CapturedValueInfo]()
if var closureExpr = bodyArgumentExpr.as(ClosureExprSyntax.self),
let captureList = closureExpr.signature?.capture?.items {
closureExpr.signature?.capture = ClosureCaptureClauseSyntax(items: [], trailingTrivia: .space)
capturedValues = captureList.map { CapturedValueInfo($0, in: context) }
bodyArgumentExpr = ExprSyntax(closureExpr)
}

// Generate a unique identifier for this exit test.
Expand Down Expand Up @@ -610,6 +612,66 @@ extension ExitTestConditionMacro {
return ExprSyntax(tupleExpr)
}
}

/// Diagnose issues with an exit test macro call.
///
/// - Parameters:
/// - macro: The exit test macro call.
/// - bodyArgumentExpr: The exit test's body.
/// - context: The macro context in which the expression is being parsed.
///
/// - Returns: Whether or not macro expansion should continue (i.e. stopping
/// if a fatal error was diagnosed.)
private static func _diagnoseIssues(
with macro: some FreestandingMacroExpansionSyntax,
body bodyArgumentExpr: ExprSyntax,
in context: some MacroExpansionContext
) -> Bool {
var diagnostics = [DiagnosticMessage]()

var hasCaptureList = false
if let closureExpr = bodyArgumentExpr.as(ClosureExprSyntax.self),
let captureClause = closureExpr.signature?.capture,
!captureClause.items.isEmpty {
hasCaptureList = true

// Disallow capture lists if the experimental feature is not enabled.
if !ExitTestExpectMacro.isValueCapturingEnabled {
diagnostics.append(.captureClauseUnsupported(captureClause, in: closureExpr, inExitTest: macro))
}
}

for lexicalContext in context.lexicalContext {
// Disallow exit tests in generic functions as they cannot be correctly
// expanded.
if let functionDecl = lexicalContext.as(FunctionDeclSyntax.self) {
if let genericClause = functionDecl.genericParameterClause {
diagnostics.append(.expressionMacroUnsupported(macro, inGenericContextBecauseOf: genericClause, on: functionDecl))
} else if let whereClause = functionDecl.genericWhereClause {
diagnostics.append(.expressionMacroUnsupported(macro, inGenericContextBecauseOf: whereClause, on: functionDecl))
} else {
for parameter in functionDecl.signature.parameterClause.parameters {
if parameter.type.isSome {
diagnostics.append(.expressionMacroUnsupported(macro, inGenericContextBecauseOf: parameter, on: functionDecl))
}
}
}
} else if hasCaptureList, let lexicalContext = lexicalContext.asProtocol((any WithGenericParametersSyntax).self) {
// Disallow exit tests in generic types if they have capture lists (because
// the types may be ambiguous.)
if let genericClause = lexicalContext.genericParameterClause {
diagnostics.append(.expressionMacroUnsupported(macro, inGenericContextBecauseOf: genericClause, on: lexicalContext))
} else if let whereClause = lexicalContext.genericWhereClause {
diagnostics.append(.expressionMacroUnsupported(macro, inGenericContextBecauseOf: whereClause, on: lexicalContext))
}
}
}

for diagnostic in diagnostics {
context.diagnose(diagnostic)
}
return diagnostics.isEmpty
}
}

extension ExitTestExpectMacro {
Expand Down
60 changes: 57 additions & 3 deletions Sources/TestingMacros/Support/ClosureCaptureListParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ struct CapturedValueInfo {

// Potentially get the name of the type comprising the current lexical
// context (i.e. whatever `Self` is.)
lazy var lexicalContext = context.lexicalContext
lazy var typeNameOfLexicalContext = {
let lexicalContext = context.lexicalContext.drop { !$0.isProtocol((any DeclGroupSyntax).self) }
let lexicalContext = lexicalContext.drop { !$0.isProtocol((any DeclGroupSyntax).self) }
return context.type(ofLexicalContext: lexicalContext)
}()

Expand All @@ -71,18 +72,71 @@ struct CapturedValueInfo {
// Copying self.
self.type = typeNameOfLexicalContext
} else {
context.diagnose(.typeOfCaptureIsAmbiguous(capture, initializedWith: initializer))
// Handle literals. Any other types are ambiguous.
switch self.expression.kind {
case .integerLiteralExpr:
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("IntegerLiteralType")))
case .floatLiteralExpr:
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("FloatLiteralType")))
case .booleanLiteralExpr:
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("BooleanLiteralType")))
case .stringLiteralExpr, .simpleStringLiteralExpr:
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("StringLiteralType")))
default:
context.diagnose(.typeOfCaptureIsAmbiguous(capture, initializedWith: initializer))
}
}

} else if capture.name.tokenKind == .keyword(.self),
let typeNameOfLexicalContext {
// Capturing self.
self.expression = "self"
self.type = typeNameOfLexicalContext

} else if let parameterType = Self._findTypeOfParameter(named: capture.name, in: lexicalContext) {
self.expression = ExprSyntax(DeclReferenceExprSyntax(baseName: capture.name.trimmed))
self.type = parameterType
} else {
// Not enough contextual information to derive the type here.
context.diagnose(.typeOfCaptureIsAmbiguous(capture))
}
}

/// Find a function or closure parameter in the given lexical context with a
/// given name and return its type.
///
/// - Parameters:
/// - parameterName: The name of the parameter of interest.
/// - lexicalContext: The lexical context to examine.
///
/// - Returns: The Swift type of first parameter found whose name matches, or
/// `nil` if none was found. The lexical context is searched in the order
/// provided which, by default, starts with the innermost scope.
private static func _findTypeOfParameter(named parameterName: TokenSyntax, in lexicalContext: [Syntax]) -> TypeSyntax? {
for lexicalContext in lexicalContext {
var parameterType: TypeSyntax?
if let functionDecl = lexicalContext.as(FunctionDeclSyntax.self) {
parameterType = functionDecl.signature.parameterClause.parameters
.first { ($0.secondName ?? $0.firstName).tokenKind == parameterName.tokenKind }
.map(\.type)
} else if let closureExpr = lexicalContext.as(ClosureExprSyntax.self) {
if case let .parameterClause(parameterClause) = closureExpr.signature?.parameterClause {
parameterType = parameterClause.parameters
.first { ($0.secondName ?? $0.firstName).tokenKind == parameterName.tokenKind }
.flatMap(\.type)
}
} else if lexicalContext.is(DeclSyntax.self) {
// If we've reached any other enclosing declaration, then any parameters
// beyond it won't be capturable and thus it isn't possible to infer
// types from them (any capture of `x`, for instance, must refer to some
// more-local variable with that name, not to a parameter named `x`.)
return nil
}

if let parameterType {
return parameterType
}
}

return nil
}
}
32 changes: 32 additions & 0 deletions Sources/TestingMacros/Support/DiagnosticMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -870,4 +870,36 @@ extension DiagnosticMessage {
]
)
}

/// Create a diagnostic message stating that an expression macro is not
/// supported in a generic context.
///
/// - Parameters:
/// - macro: The invalid macro.
/// - genericClause: The child node on `genericDecl` that makes it generic.
/// - genericDecl: The generic declaration to which `genericClause` is
/// attached, possibly equal to `decl`.
///
/// - Returns: A diagnostic message.
static func expressionMacroUnsupported(_ macro: some FreestandingMacroExpansionSyntax, inGenericContextBecauseOf genericClause: some SyntaxProtocol, on genericDecl: some SyntaxProtocol) -> Self {
if let functionDecl = genericDecl.as(FunctionDeclSyntax.self) {
return Self(
syntax: Syntax(macro),
message: "Cannot call macro '\(_macroName(macro))' within generic function '\(functionDecl.completeName)'",
severity: .error
)
} else if let namedDecl = genericDecl.asProtocol((any NamedDeclSyntax).self) {
return Self(
syntax: Syntax(macro),
message: "Cannot call macro '\(_macroName(macro))' within generic \(_kindString(for: genericDecl)) '\(namedDecl.name.trimmed)'",
severity: .error
)
} else {
return Self(
syntax: Syntax(macro),
message: "Cannot call macro '\(_macroName(macro))' within a generic \(_kindString(for: genericDecl))",
severity: .error
)
}
}
}
18 changes: 18 additions & 0 deletions Tests/TestingMacrosTests/ConditionMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,22 @@ struct ConditionMacroTests {
#expect(diagnostic.message.contains("is redundant"))
}

@Test("#expect(processExitsWith:) diagnostics",
arguments: [
"func f<T>() { #expectExitTest(processExitsWith: x) {} }":
"Cannot call macro ''#expectExitTest(processExitsWith:_:)'' within generic function 'f()'",
]
)
func exitTestDiagnostics(input: String, expectedMessage: String) throws {
let (_, diagnostics) = try parse(input)

#expect(diagnostics.count > 0)
for diagnostic in diagnostics {
#expect(diagnostic.diagMessage.severity == .error)
#expect(diagnostic.message == expectedMessage)
}
}

#if ExperimentalExitTestValueCapture
@Test("#expect(processExitsWith:) produces a diagnostic for a bad capture",
arguments: [
Expand All @@ -445,6 +461,8 @@ struct ConditionMacroTests {
"Type of captured value 'a' is ambiguous",
"#expectExitTest(processExitsWith: x) { [a = b] in }":
"Type of captured value 'a' is ambiguous",
"struct S<T> { func f() { #expectExitTest(processExitsWith: x) { [a] in } } }":
"Cannot call macro ''#expectExitTest(processExitsWith:_:)'' within generic structure 'S'",
]
)
func exitTestCaptureDiagnostics(input: String, expectedMessage: String) throws {
Expand Down
39 changes: 39 additions & 0 deletions Tests/TestingTests/ExitTestTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,45 @@ private import _TestingInternals
#expect(instance.x == 123)
}
}

@Test("Capturing a parameter to the test function")
func captureListWithParameter() async {
let i = Int.random(in: 0 ..< 1000)

func f(j: Int) async {
await #expect(processExitsWith: .success) { [i = i as Int, j] in
#expect(i == j)
#expect(j >= 0)
#expect(j < 1000)
}
}
await f(j: i)

await { (j: Int) in
_ = await #expect(processExitsWith: .success) { [i = i as Int, j] in
#expect(i == j)
#expect(j >= 0)
#expect(j < 1000)
}
}(i)

// FAILS TO COMPILE: shadowing `i` with a variable of a different type will
// prevent correct expansion (we need an equivalent of decltype() for that.)
// let i = String(i)
// await #expect(processExitsWith: .success) { [i] in
// #expect(!i.isEmpty)
// }
}

@Test("Capturing a literal expression")
func captureListWithLiterals() async {
await #expect(processExitsWith: .success) { [i = 0, f = 1.0, s = "", b = true] in
#expect(i == 0)
#expect(f == 1.0)
#expect(s == "")
#expect(b == true)
}
}
#endif
}

Expand Down