Skip to content

Commit 31785d8

Browse files
committed
Get collection diffing working
1 parent 3965d35 commit 31785d8

File tree

10 files changed

+335
-69
lines changed

10 files changed

+335
-69
lines changed

Sources/Testing/Expectations/Expectation.swift

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
public struct Expectation: Sendable {
1313
/// The expression evaluated by this expectation.
1414
@_spi(ForToolsIntegrationOnly)
15-
public var evaluatedExpression: Expression
15+
public internal(set) var evaluatedExpression: Expression
1616

1717
/// A description of the error mismatch that occurred, if any.
1818
///
1919
/// If this expectation passed, the value of this property is `nil` because no
2020
/// error mismatch occurred.
2121
@_spi(Experimental) @_spi(ForToolsIntegrationOnly)
22-
public var mismatchedErrorDescription: String?
22+
public internal(set) var mismatchedErrorDescription: String?
2323

2424
/// A description of the difference between the operands in the expression
2525
/// evaluated by this expectation, if the difference could be determined.
@@ -28,7 +28,9 @@ public struct Expectation: Sendable {
2828
/// the difference is only computed when necessary to assist with diagnosing
2929
/// test failures.
3030
@_spi(Experimental) @_spi(ForToolsIntegrationOnly)
31-
public var differenceDescription: String?
31+
public var differenceDescription: String? {
32+
evaluatedExpression.differenceDescription
33+
}
3234

3335
/// A description of the exit condition that was expected to be matched.
3436
///

Sources/Testing/Expectations/ExpectationChecking+Macro.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public func __checkCondition(
111111
isRequired: Bool,
112112
sourceLocation: SourceLocation
113113
) rethrows -> Result<Void, any Error> {
114-
var expectationContext = __ExpectationContext(sourceCode: sourceCode)
114+
var expectationContext = __ExpectationContext.init(sourceCode: sourceCode)
115115
let condition = try condition(&expectationContext)
116116

117117
return check(

Sources/Testing/Expectations/ExpectationContext.swift

Lines changed: 178 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,65 @@ public struct __ExpectationContext: ~Copyable {
3333
/// will not be assigned a runtime value.
3434
var runtimeValues: [__ExpressionID: () -> Expression.Value?]
3535

36-
init(sourceCode: [__ExpressionID: String] = [:], runtimeValues: [__ExpressionID: () -> Expression.Value?] = [:]) {
36+
/// Computed differences between the operands or arguments of expressions.
37+
///
38+
/// The values in this dictionary are gathered at runtime as subexpressions
39+
/// are evaluated, much like ``runtimeValues``.
40+
var differences: [__ExpressionID: () -> CollectionDifference<Any>?]
41+
42+
init(
43+
sourceCode: [__ExpressionID: String] = [:],
44+
runtimeValues: [__ExpressionID: () -> Expression.Value?] = [:],
45+
differences: [__ExpressionID: () -> CollectionDifference<Any>?] = [:]
46+
) {
3747
self.sourceCode = sourceCode
3848
self.runtimeValues = runtimeValues
49+
self.differences = differences
50+
}
51+
52+
/// Convert an instance of `CollectionDifference` to one that is type-erased
53+
/// over elements of type `Any`.
54+
///
55+
/// - Parameters:
56+
/// - difference: The difference to convert.
57+
///
58+
/// - Returns: A type-erased copy of `difference`.
59+
private static func _typeEraseCollectionDifference(_ difference: CollectionDifference<some Any>) -> CollectionDifference<Any> {
60+
CollectionDifference<Any>(
61+
difference.lazy.map { change in
62+
switch change {
63+
case let .insert(offset, element, associatedWith):
64+
return .insert(offset: offset, element: element as Any, associatedWith: associatedWith)
65+
case let .remove(offset, element, associatedWith):
66+
return .remove(offset: offset, element: element as Any, associatedWith: associatedWith)
67+
}
68+
}
69+
)!
70+
}
71+
72+
/// Generate a description of a previously-computed collection difference.
73+
///
74+
/// - Parameters:
75+
/// - difference: The difference to describe.
76+
///
77+
/// - Returns: A human-readable string describing `difference`.
78+
private borrowing func _description(of difference: CollectionDifference<some Any>) -> String {
79+
let insertions: [String] = difference.insertions.lazy
80+
.map(\.element)
81+
.map(String.init(describingForTest:))
82+
let removals: [String] = difference.removals.lazy
83+
.map(\.element)
84+
.map(String.init(describingForTest:))
85+
86+
var resultComponents = [String]()
87+
if !insertions.isEmpty {
88+
resultComponents.append("inserted [\(insertions.joined(separator: ", "))]")
89+
}
90+
if !removals.isEmpty {
91+
resultComponents.append("removed [\(removals.joined(separator: ", "))]")
92+
}
93+
94+
return resultComponents.joined(separator: ", ")
3995
}
4096

4197
/// Collapse the given expression graph into one or more expressions with
@@ -102,6 +158,15 @@ public struct __ExpectationContext: ~Copyable {
102158
expressionGraph[keyPath] = expression
103159
}
104160
}
161+
162+
for (id, difference) in differences {
163+
let keyPath = id.keyPath
164+
if var expression = expressionGraph[keyPath], let difference = difference() {
165+
let differenceDescription = _description(of: difference)
166+
expression.differenceDescription = differenceDescription
167+
expressionGraph[keyPath] = expression
168+
}
169+
}
105170
}
106171

107172
// Flatten the expression graph.
@@ -154,11 +219,12 @@ extension __ExpectationContext {
154219
///
155220
/// - Warning: This function is used to implement the `#expect()` and
156221
/// `#require()` macros. Do not call it directly.
157-
public mutating func callAsFunction<T>(_ value: T, _ id: __ExpressionID) -> T where T: Copyable {
222+
public mutating func callAsFunction<T>(_ value: T, _ id: __ExpressionID) -> T {
158223
runtimeValues[id] = { Expression.Value(reflecting: value) }
159224
return value
160225
}
161226

227+
#if SWT_SUPPORTS_MOVE_ONLY_EXPRESSION_EXPANSION
162228
/// Capture information about a value for use if the expectation currently
163229
/// being evaluated fails.
164230
///
@@ -176,7 +242,113 @@ extension __ExpectationContext {
176242
// TODO: add support for borrowing non-copyable expressions (need @lifetime)
177243
return value
178244
}
245+
#endif
246+
}
247+
248+
// MARK: - Collection comparison
249+
250+
extension __ExpectationContext {
251+
/// Compare two values using `==` or `!=`.
252+
///
253+
/// - Parameters:
254+
/// - lhs: The left-hand operand.
255+
/// - lhsID: A value that uniquely identifies the expression represented by
256+
/// `lhs` in the context of the expectation currently being evaluated.
257+
/// - rhs: The left-hand operand.
258+
/// - rhsID: A value that uniquely identifies the expression represented by
259+
/// `rhs` in the context of the expectation currently being evaluated.
260+
/// - op: A function that performs an operation on `lhs` and `rhs`.
261+
/// - opID: A value that uniquely identifies the expression represented by
262+
/// `op` in the context of the expectation currently being evaluated.
263+
///
264+
/// - Returns: The result of calling `op(lhs, rhs)`.
265+
///
266+
/// This overload of `__cmp()` serves as a catch-all for operands that are not
267+
/// collections or otherwise are not interesting to the testing library.
268+
///
269+
/// - Warning: This function is used to implement the `#expect()` and
270+
/// `#require()` macros. Do not call it directly.
271+
public mutating func __cmp<T, U, R>(
272+
_ lhs: T,
273+
_ lhsID: __ExpressionID,
274+
_ rhs: U,
275+
_ rhsID: __ExpressionID,
276+
_ op: (T, U) throws -> R,
277+
_ opID: __ExpressionID
278+
) rethrows -> R {
279+
try self(op(self(lhs, lhsID), self(rhs, rhsID)), opID)
280+
}
281+
282+
public mutating func __cmp<C>(
283+
_ lhs: C,
284+
_ lhsID: __ExpressionID,
285+
_ rhs: C,
286+
_ rhsID: __ExpressionID,
287+
_ op: (C, C) -> Bool,
288+
_ opID: __ExpressionID
289+
) -> Bool where C: BidirectionalCollection, C.Element: Equatable {
290+
let result = self(op(self(lhs, lhsID), self(rhs, rhsID)), opID)
291+
292+
if !result {
293+
differences[opID] = { [lhs, rhs] in
294+
Self._typeEraseCollectionDifference(lhs.difference(from: rhs))
295+
}
296+
}
297+
298+
return result
299+
}
300+
301+
public mutating func __cmp<R>(
302+
_ lhs: R,
303+
_ lhsID: __ExpressionID,
304+
_ rhs: R,
305+
_ rhsID: __ExpressionID,
306+
_ op: (R, R) -> Bool,
307+
_ opID: __ExpressionID
308+
) -> Bool where R: RangeExpression & BidirectionalCollection, R.Element: Equatable {
309+
self(op(self(lhs, lhsID), self(rhs, rhsID)), opID)
310+
}
311+
312+
public mutating func __cmp<S>(
313+
_ lhs: S,
314+
_ lhsID: __ExpressionID,
315+
_ rhs: S,
316+
_ rhsID: __ExpressionID,
317+
_ op: (S, S) -> Bool,
318+
_ opID: __ExpressionID
319+
) -> Bool where S: StringProtocol {
320+
let result = self(op(self(lhs, lhsID), self(rhs, rhsID)), opID)
179321

322+
if !result {
323+
differences[opID] = { [lhs, rhs] in
324+
// Compare strings by line, not by character.
325+
let lhsLines = String(lhs).split(whereSeparator: \.isNewline)
326+
let rhsLines = String(rhs).split(whereSeparator: \.isNewline)
327+
328+
if lhsLines.count == 1 && rhsLines.count == 1 {
329+
// There are no newlines in either string, so there's no meaningful
330+
// per-line difference. Bail.
331+
return nil
332+
}
333+
334+
let diff = lhsLines.difference(from: rhsLines)
335+
if diff.isEmpty {
336+
// The strings must have compared on a per-character basis, or this
337+
// operator doesn't behave the way we expected. Bail.
338+
return nil
339+
}
340+
341+
return Self._typeEraseCollectionDifference(diff)
342+
}
343+
}
344+
345+
return result
346+
}
347+
}
348+
349+
// MARK: - Casting
350+
351+
extension __ExpectationContext {
180352
/// Perform a conditional cast (`as?`) on a value.
181353
///
182354
/// - Parameters:
@@ -258,15 +430,15 @@ extension __ExpectationContext {
258430
///
259431
/// - Warning: This function is used to implement the `#expect()` and
260432
/// `#require()` macros. Do not call it directly.
261-
public mutating func callAsFunction<T, U>(_ value: T, _ id: __ExpressionID) -> U where T: StringProtocol, U: _Pointer {
433+
public mutating func callAsFunction<P>(_ value: String, _ id: __ExpressionID) -> P where P: _Pointer {
262434
// Perform the normal value capture.
263435
let result = self(value, id)
264436

265437
// Create a C string copy of `value`.
266438
#if os(Windows)
267-
let resultCString = _strdup(String(result))!
439+
let resultCString = _strdup(result)!
268440
#else
269-
let resultCString = strdup(String(result))!
441+
let resultCString = strdup(result)!
270442
#endif
271443

272444
// Store the C string pointer so we can free it later when this context is
@@ -277,7 +449,7 @@ extension __ExpectationContext {
277449
_transformedCStrings.append(resultCString)
278450

279451
// Return the C string as whatever pointer type the caller wants.
280-
return U(bitPattern: Int(bitPattern: resultCString)).unsafelyUnwrapped
452+
return P(bitPattern: Int(bitPattern: resultCString)).unsafelyUnwrapped
281453
}
282454
}
283455
#endif

Sources/Testing/SourceAttribution/Expression.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,18 @@ public struct __Expression: Sendable {
240240
@_spi(ForToolsIntegrationOnly)
241241
public internal(set) var subexpressions = [Self]()
242242

243+
/// A description of the difference between the operands in this expression,
244+
/// if that difference could be determined.
245+
///
246+
/// The value of this property is set for the binary operators `==` and `!=`
247+
/// when used to compare collections.
248+
///
249+
/// If the containing expectation passed, the value of this property is `nil`
250+
/// because the difference is only computed when necessary to assist with
251+
/// diagnosing test failures.
252+
@_spi(Experimental) @_spi(ForToolsIntegrationOnly)
253+
public internal(set) var differenceDescription: String?
254+
243255
@_spi(ForToolsIntegrationOnly)
244256
@available(*, deprecated, message: "The value of this property is always nil.")
245257
public var stringLiteralValue: String? {

Sources/TestingMacros/ConditionMacro.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,11 @@ extension ConditionMacro {
198198

199199
checkArguments.append(Argument(expression: argumentExpr))
200200

201-
let sourceCodeNodeIDs = rewrittenNodes.compactMap { $0.expressionID(rootedAt: originalArgumentExpr) }
202-
let sourceCodeExprs = rewrittenNodes.map { StringLiteralExprSyntax(content: $0.trimmedDescription) }
201+
// Sort the rewritten nodes. This isn't strictly necessary for
202+
// correctness but it does make the produced code more consistent.
203+
let sortedRewrittenNodes = rewrittenNodes.sorted { $0.id < $1.id }
204+
let sourceCodeNodeIDs = sortedRewrittenNodes.compactMap { $0.expressionID(rootedAt: originalArgumentExpr) }
205+
let sourceCodeExprs = sortedRewrittenNodes.map { StringLiteralExprSyntax(content: $0.trimmedDescription) }
203206
let sourceCodeExpr = DictionaryExprSyntax {
204207
for (nodeID, sourceCodeExpr) in zip(sourceCodeNodeIDs, sourceCodeExprs) {
205208
DictionaryElementSyntax(key: nodeID, value: sourceCodeExpr)

Sources/TestingMacros/Support/ConditionArgumentParsing.swift

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
109109
/// A list of any syntax nodes that have been rewritten.
110110
///
111111
/// The nodes in this array are the _original_ nodes, not the rewritten nodes.
112-
var rewrittenNodes = [Syntax]()
112+
var rewrittenNodes = Set<Syntax>()
113113

114114
init(in context: C, for macro: M, rootedAt effectiveRootNode: Syntax, expressionContextName: TokenSyntax) {
115115
self.context = context
@@ -131,7 +131,14 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
131131
/// - Returns: A rewritten copy of `node` that calls into the expression
132132
/// context when it is evaluated at runtime.
133133
private func _rewrite<E>(_ node: E, originalWas originalNode: some SyntaxProtocol) -> ExprSyntax where E: ExprSyntaxProtocol {
134-
rewrittenNodes.append(Syntax(originalNode))
134+
if rewrittenNodes.contains(Syntax(originalNode)) {
135+
// If this node has already been rewritten, we don't need to rewrite it
136+
// again. (Currently, this can only happen when expanding binary operators
137+
// which need a bit of extra help.)
138+
return ExprSyntax(node)
139+
}
140+
141+
rewrittenNodes.insert(Syntax(originalNode))
135142

136143
var result = FunctionCallExprSyntax(calledExpression: expressionContextNameExpr) {
137144
LabeledExprSyntax(expression: node.trimmed)
@@ -316,7 +323,47 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
316323
}
317324

318325
override func visit(_ node: InfixOperatorExprSyntax) -> ExprSyntax {
319-
_rewrite(
326+
if let op = node.operator.as(BinaryOperatorExprSyntax.self)?.operator.textWithoutBackticks,
327+
op == "==" || op == "!=" || op == "===" || op == "!==" {
328+
329+
rewrittenNodes.insert(Syntax(node))
330+
rewrittenNodes.insert(Syntax(node.leftOperand))
331+
rewrittenNodes.insert(Syntax(node.rightOperand))
332+
333+
var result = FunctionCallExprSyntax(
334+
calledExpression: MemberAccessExprSyntax(
335+
base: expressionContextNameExpr,
336+
name: .identifier("__cmp")
337+
)
338+
) {
339+
LabeledExprSyntax(expression: visit(node.leftOperand))
340+
LabeledExprSyntax(expression: node.leftOperand.expressionID(rootedAt: effectiveRootNode))
341+
LabeledExprSyntax(expression: visit(node.rightOperand))
342+
LabeledExprSyntax(expression: node.rightOperand.expressionID(rootedAt: effectiveRootNode))
343+
LabeledExprSyntax(
344+
expression: ClosureExprSyntax {
345+
InfixOperatorExprSyntax(
346+
leftOperand: DeclReferenceExprSyntax(
347+
baseName: .dollarIdentifier("$0")
348+
).with(\.trailingTrivia, .space),
349+
operator: BinaryOperatorExprSyntax(text: op),
350+
rightOperand: DeclReferenceExprSyntax(
351+
baseName: .dollarIdentifier("$1")
352+
).with(\.leadingTrivia, .space)
353+
)
354+
}
355+
)
356+
LabeledExprSyntax(expression: node.expressionID(rootedAt: effectiveRootNode))
357+
}
358+
result.leftParen = .leftParenToken()
359+
result.rightParen = .rightParenToken()
360+
result.leadingTrivia = node.leadingTrivia
361+
result.trailingTrivia = node.trailingTrivia
362+
363+
return ExprSyntax(result)
364+
}
365+
366+
return _rewrite(
320367
node
321368
.with(\.leftOperand, visit(node.leftOperand))
322369
.with(\.rightOperand, visit(node.rightOperand)),
@@ -481,7 +528,7 @@ func insertCalls(
481528
for macro: some FreestandingMacroExpansionSyntax,
482529
rootedAt effectiveRootNode: some SyntaxProtocol,
483530
in context: some MacroExpansionContext
484-
) -> (Syntax, rewrittenNodes: [Syntax]) {
531+
) -> (Syntax, rewrittenNodes: Set<Syntax>) {
485532
if let node = node.as(ExprSyntax.self) {
486533
_diagnoseTrivialBooleanValue(from: node, for: macro, in: context)
487534
}

0 commit comments

Comments
 (0)