@@ -8,6 +8,7 @@ use crate::deriving::generic::ty::*;
8
8
use crate :: deriving:: generic:: * ;
9
9
use crate :: deriving:: { path_local, path_std} ;
10
10
11
+ /// Expands a `#[derive(PartialEq)]` attribute into an implementation for the target item.
11
12
pub ( crate ) fn expand_deriving_partial_eq (
12
13
cx : & ExtCtxt < ' _ > ,
13
14
span : Span ,
@@ -16,62 +17,6 @@ pub(crate) fn expand_deriving_partial_eq(
16
17
push : & mut dyn FnMut ( Annotatable ) ,
17
18
is_const : bool ,
18
19
) {
19
- fn cs_eq ( cx : & ExtCtxt < ' _ > , span : Span , substr : & Substructure < ' _ > ) -> BlockOrExpr {
20
- let base = true ;
21
- let expr = cs_fold (
22
- true , // use foldl
23
- cx,
24
- span,
25
- substr,
26
- |cx, fold| match fold {
27
- CsFold :: Single ( field) => {
28
- let [ other_expr] = & field. other_selflike_exprs [ ..] else {
29
- cx. dcx ( )
30
- . span_bug ( field. span , "not exactly 2 arguments in `derive(PartialEq)`" ) ;
31
- } ;
32
-
33
- // We received arguments of type `&T`. Convert them to type `T` by stripping
34
- // any leading `&`. This isn't necessary for type checking, but
35
- // it results in better error messages if something goes wrong.
36
- //
37
- // Note: for arguments that look like `&{ x }`, which occur with packed
38
- // structs, this would cause expressions like `{ self.x } == { other.x }`,
39
- // which isn't valid Rust syntax. This wouldn't break compilation because these
40
- // AST nodes are constructed within the compiler. But it would mean that code
41
- // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
42
- // syntax, which would be suboptimal. So we wrap these in parens, giving
43
- // `({ self.x }) == ({ other.x })`, which is valid syntax.
44
- let convert = |expr : & P < Expr > | {
45
- if let ExprKind :: AddrOf ( BorrowKind :: Ref , Mutability :: Not , inner) =
46
- & expr. kind
47
- {
48
- if let ExprKind :: Block ( ..) = & inner. kind {
49
- // `&{ x }` form: remove the `&`, add parens.
50
- cx. expr_paren ( field. span , inner. clone ( ) )
51
- } else {
52
- // `&x` form: remove the `&`.
53
- inner. clone ( )
54
- }
55
- } else {
56
- expr. clone ( )
57
- }
58
- } ;
59
- cx. expr_binary (
60
- field. span ,
61
- BinOpKind :: Eq ,
62
- convert ( & field. self_expr ) ,
63
- convert ( other_expr) ,
64
- )
65
- }
66
- CsFold :: Combine ( span, expr1, expr2) => {
67
- cx. expr_binary ( span, BinOpKind :: And , expr1, expr2)
68
- }
69
- CsFold :: Fieldless => cx. expr_bool ( span, base) ,
70
- } ,
71
- ) ;
72
- BlockOrExpr :: new_expr ( expr)
73
- }
74
-
75
20
let structural_trait_def = TraitDef {
76
21
span,
77
22
path : path_std ! ( marker:: StructuralPartialEq ) ,
@@ -97,7 +42,9 @@ pub(crate) fn expand_deriving_partial_eq(
97
42
ret_ty: Path ( path_local!( bool ) ) ,
98
43
attributes: thin_vec![ cx. attr_word( sym:: inline, span) ] ,
99
44
fieldless_variants_strategy: FieldlessVariantsStrategy :: Unify ,
100
- combine_substructure: combine_substructure( Box :: new( |a, b, c| cs_eq( a, b, c) ) ) ,
45
+ combine_substructure: combine_substructure( Box :: new( |a, b, c| {
46
+ BlockOrExpr :: new_expr( get_substructure_equality_expr( a, b, c) )
47
+ } ) ) ,
101
48
} ] ;
102
49
103
50
let trait_def = TraitDef {
@@ -113,3 +60,138 @@ pub(crate) fn expand_deriving_partial_eq(
113
60
} ;
114
61
trait_def. expand ( cx, mitem, item, push)
115
62
}
63
+
64
+ /// Generates the equality expression for a struct or enum variant when deriving `PartialEq`.
65
+ ///
66
+ /// This function constructs an expression that compares all fields of a struct or enum variant for equality.
67
+ /// It groups scalar and compound types separately, combining their comparisons efficiently:
68
+ /// - If there are no fields, returns `true` (fieldless types are always equal to themselves).
69
+ /// - Scalar fields are compared first for efficiency, then compound fields.
70
+ /// - If only one group is non-empty, returns its comparison directly.
71
+ /// - Otherwise, returns a conjunction (logical AND) of both groups' comparisons.
72
+ ///
73
+ /// For enums with discriminants, compares the discriminant first, then the rest of the fields.
74
+ ///
75
+ /// # Panics
76
+ ///
77
+ /// If called on static or all-fieldless enums/structs, which should not occur during derive expansion.
78
+ fn get_substructure_equality_expr (
79
+ cx : & ExtCtxt < ' _ > ,
80
+ span : Span ,
81
+ substructure : & Substructure < ' _ > ,
82
+ ) -> P < Expr > {
83
+ /// Combines the accumulated comparison expression with the next field's comparison using logical AND.
84
+ ///
85
+ /// If this is the first field, initializes the accumulator. Otherwise, chains with logical AND.
86
+ fn combine ( cx : & ExtCtxt < ' _ > , span : Span , acc : & mut Option < P < Expr > > , elem : P < Expr > ) {
87
+ let Some ( lhs) = acc. take ( ) else {
88
+ * acc = Some ( elem) ;
89
+ return ;
90
+ } ;
91
+ * acc = Some ( cx. expr_binary ( span, BinOpKind :: And , lhs, elem) ) ;
92
+ }
93
+
94
+ use SubstructureFields :: * ;
95
+ match substructure. fields {
96
+ EnumMatching ( .., fields) | Struct ( .., fields) => {
97
+ if fields. is_empty ( ) {
98
+ // Fieldless structs or enum variants are always equal to themselves.
99
+ return cx. expr_bool ( span, true ) ;
100
+ }
101
+
102
+ let mut scalar_ty_cmp = None ;
103
+ let mut compound_ty_cmp = None ;
104
+ // Compare scalar and compound types separately for efficiency.
105
+ for field in fields {
106
+ let is_scalar = field. is_scalar ;
107
+ let field_span = field. span ;
108
+ let rhs = get_field_equality_expr ( cx, field) ;
109
+
110
+ if is_scalar {
111
+ // Combine scalar field comparisons first (cheaper to evaluate).
112
+ combine ( cx, field_span, & mut scalar_ty_cmp, rhs) ;
113
+ continue ;
114
+ }
115
+ // Combine compound (non-scalar) field comparisons.
116
+ combine ( cx, field_span, & mut compound_ty_cmp, rhs) ;
117
+ }
118
+
119
+ // If only one group (scalar or compound) has fields, return its comparison directly.
120
+ if scalar_ty_cmp. is_some ( ) ^ compound_ty_cmp. is_some ( ) {
121
+ return scalar_ty_cmp. or ( compound_ty_cmp) . unwrap ( ) ;
122
+ }
123
+
124
+ // If both groups are non-empty, require all fields to be equal.
125
+ // Scalar fields are compared first for performance.
126
+ return cx. expr_binary (
127
+ span,
128
+ BinOpKind :: And ,
129
+ scalar_ty_cmp. unwrap ( ) ,
130
+ compound_ty_cmp. unwrap ( ) ,
131
+ ) ;
132
+ }
133
+ EnumDiscr ( disc, match_expr) => {
134
+ let lhs = get_field_equality_expr ( cx, disc) ;
135
+ let Some ( match_expr) = match_expr else {
136
+ return lhs;
137
+ } ;
138
+ // Compare the discriminant first (cheaper), then the rest of the fields.
139
+ return cx. expr_binary ( disc. span , BinOpKind :: And , lhs, match_expr. clone ( ) ) ;
140
+ }
141
+ StaticEnum ( ..) => cx. dcx ( ) . span_bug (
142
+ span,
143
+ "unexpected static enum encountered during `derive(PartialEq)` expansion" ,
144
+ ) ,
145
+ StaticStruct ( ..) => cx. dcx ( ) . span_bug (
146
+ span,
147
+ "unexpected static struct encountered during `derive(PartialEq)` expansion" ,
148
+ ) ,
149
+ AllFieldlessEnum ( ..) => cx. dcx ( ) . span_bug (
150
+ span,
151
+ "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion" ,
152
+ ) ,
153
+ }
154
+ }
155
+
156
+ /// Generates an equality comparison expression for a single struct or enum field.
157
+ ///
158
+ /// This function produces an AST expression that compares the `self` and `other` values for a field using `==`.
159
+ /// It removes any leading references from both sides for readability.
160
+ /// If the field is a block expression, it is wrapped in parentheses to ensure valid syntax.
161
+ ///
162
+ /// # Panics
163
+ ///
164
+ /// Panics if there are not exactly two arguments to compare (should be `self` and `other`).
165
+ fn get_field_equality_expr ( cx : & ExtCtxt < ' _ > , field : & FieldInfo ) -> P < Expr > {
166
+ let [ rhs] = & field. other_selflike_exprs [ ..] else {
167
+ cx. dcx ( ) . span_bug ( field. span , "not exactly 2 arguments in `derive(PartialEq)`" ) ;
168
+ } ;
169
+
170
+ cx. expr_binary (
171
+ field. span ,
172
+ BinOpKind :: Eq ,
173
+ wrap_block_expr ( cx, peel_refs ( & field. self_expr ) ) ,
174
+ wrap_block_expr ( cx, peel_refs ( rhs) ) ,
175
+ )
176
+ }
177
+
178
+ /// Removes all leading immutable references from an expression.
179
+ ///
180
+ /// This is used to strip away any number of leading `&` from an expression (e.g., `&&&T` becomes `T`).
181
+ /// Only removes immutable references; mutable references are preserved.
182
+ fn peel_refs ( mut expr : & P < Expr > ) -> P < Expr > {
183
+ while let ExprKind :: AddrOf ( BorrowKind :: Ref , Mutability :: Not , inner) = & expr. kind {
184
+ expr = & inner;
185
+ }
186
+ expr. clone ( )
187
+ }
188
+
189
+ /// Wraps a block expression in parentheses to ensure valid AST in macro expansion output.
190
+ ///
191
+ /// If the given expression is a block, it is wrapped in parentheses; otherwise, it is returned unchanged.
192
+ fn wrap_block_expr ( cx : & ExtCtxt < ' _ > , expr : P < Expr > ) -> P < Expr > {
193
+ if matches ! ( & expr. kind, ExprKind :: Block ( ..) ) {
194
+ return cx. expr_paren ( expr. span , expr) ;
195
+ }
196
+ expr
197
+ }
0 commit comments