Skip to content

Commit 843083a

Browse files
committed
Separately check equality of the scalar types and compound types in the order of declaration.
1 parent ebe9b00 commit 843083a

File tree

5 files changed

+445
-58
lines changed

5 files changed

+445
-58
lines changed

compiler/rustc_ast/src/ast.rs

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2195,6 +2195,34 @@ impl FloatTy {
21952195
}
21962196
}
21972197

2198+
impl<'a> TryFrom<&'a str> for FloatTy {
2199+
type Error = ();
2200+
2201+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2202+
Ok(match value {
2203+
"f16" => Self::F16,
2204+
"f32" => Self::F32,
2205+
"f64" => Self::F64,
2206+
"f128" => Self::F128,
2207+
_ => return Err(()),
2208+
})
2209+
}
2210+
}
2211+
2212+
impl TryFrom<Symbol> for FloatTy {
2213+
type Error = ();
2214+
2215+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2216+
Ok(match value {
2217+
sym::f16 => Self::F16,
2218+
sym::f32 => Self::F32,
2219+
sym::f64 => Self::F64,
2220+
sym::f128 => Self::F128,
2221+
_ => return Err(()),
2222+
})
2223+
}
2224+
}
2225+
21982226
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
21992227
#[derive(Encodable, Decodable, HashStable_Generic)]
22002228
pub enum IntTy {
@@ -2230,6 +2258,38 @@ impl IntTy {
22302258
}
22312259
}
22322260

2261+
impl<'a> TryFrom<&'a str> for IntTy {
2262+
type Error = ();
2263+
2264+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2265+
Ok(match value {
2266+
"isize" => Self::Isize,
2267+
"i8" => Self::I8,
2268+
"i16" => Self::I16,
2269+
"i32" => Self::I32,
2270+
"i64" => Self::I64,
2271+
"i128" => Self::I128,
2272+
_ => return Err(()),
2273+
})
2274+
}
2275+
}
2276+
2277+
impl TryFrom<Symbol> for IntTy {
2278+
type Error = ();
2279+
2280+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2281+
Ok(match value {
2282+
sym::isize => Self::Isize,
2283+
sym::i8 => Self::I8,
2284+
sym::i16 => Self::I16,
2285+
sym::i32 => Self::I32,
2286+
sym::i64 => Self::I64,
2287+
sym::i128 => Self::I128,
2288+
_ => return Err(()),
2289+
})
2290+
}
2291+
}
2292+
22332293
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Copy, Debug)]
22342294
#[derive(Encodable, Decodable, HashStable_Generic)]
22352295
pub enum UintTy {
@@ -2265,6 +2325,38 @@ impl UintTy {
22652325
}
22662326
}
22672327

2328+
impl<'a> TryFrom<&'a str> for UintTy {
2329+
type Error = ();
2330+
2331+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2332+
Ok(match value {
2333+
"usize" => Self::Usize,
2334+
"u8" => Self::U8,
2335+
"u16" => Self::U16,
2336+
"u32" => Self::U32,
2337+
"u64" => Self::U64,
2338+
"u128" => Self::U128,
2339+
_ => return Err(()),
2340+
})
2341+
}
2342+
}
2343+
2344+
impl TryFrom<Symbol> for UintTy {
2345+
type Error = ();
2346+
2347+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2348+
Ok(match value {
2349+
sym::usize => Self::Usize,
2350+
sym::u8 => Self::U8,
2351+
sym::u16 => Self::U16,
2352+
sym::u32 => Self::U32,
2353+
sym::u64 => Self::U64,
2354+
sym::u128 => Self::U128,
2355+
_ => return Err(()),
2356+
})
2357+
}
2358+
}
2359+
22682360
/// A constraint on an associated item.
22692361
///
22702362
/// ### Examples
@@ -2452,6 +2544,21 @@ impl TyKind {
24522544
None
24532545
}
24542546
}
2547+
2548+
pub fn is_scalar(&self) -> bool {
2549+
let Some(ty_kind) = self.is_simple_path() else {
2550+
match self {
2551+
TyKind::Tup(tys) => return tys.is_empty(), // unit type
2552+
_ => return false,
2553+
}
2554+
};
2555+
2556+
ty_kind == sym::bool
2557+
|| ty_kind == sym::char
2558+
|| IntTy::try_from(ty_kind).is_ok()
2559+
|| UintTy::try_from(ty_kind).is_ok()
2560+
|| FloatTy::try_from(ty_kind).is_ok()
2561+
}
24552562
}
24562563

24572564
/// A pattern type pattern.

compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs

Lines changed: 139 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::deriving::generic::ty::*;
88
use crate::deriving::generic::*;
99
use crate::deriving::{path_local, path_std};
1010

11+
/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the target item.
1112
pub(crate) fn expand_deriving_partial_eq(
1213
cx: &ExtCtxt<'_>,
1314
span: Span,
@@ -16,62 +17,6 @@ pub(crate) fn expand_deriving_partial_eq(
1617
push: &mut dyn FnMut(Annotatable),
1718
is_const: bool,
1819
) {
19-
fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr {
20-
let base = true;
21-
let expr = cs_fold(
22-
true, // use foldl
23-
cx,
24-
span,
25-
substr,
26-
|cx, fold| match fold {
27-
CsFold::Single(field) => {
28-
let [other_expr] = &field.other_selflike_exprs[..] else {
29-
cx.dcx()
30-
.span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
31-
};
32-
33-
// We received arguments of type `&T`. Convert them to type `T` by stripping
34-
// any leading `&`. This isn't necessary for type checking, but
35-
// it results in better error messages if something goes wrong.
36-
//
37-
// Note: for arguments that look like `&{ x }`, which occur with packed
38-
// structs, this would cause expressions like `{ self.x } == { other.x }`,
39-
// which isn't valid Rust syntax. This wouldn't break compilation because these
40-
// AST nodes are constructed within the compiler. But it would mean that code
41-
// printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
42-
// syntax, which would be suboptimal. So we wrap these in parens, giving
43-
// `({ self.x }) == ({ other.x })`, which is valid syntax.
44-
let convert = |expr: &P<Expr>| {
45-
if let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) =
46-
&expr.kind
47-
{
48-
if let ExprKind::Block(..) = &inner.kind {
49-
// `&{ x }` form: remove the `&`, add parens.
50-
cx.expr_paren(field.span, inner.clone())
51-
} else {
52-
// `&x` form: remove the `&`.
53-
inner.clone()
54-
}
55-
} else {
56-
expr.clone()
57-
}
58-
};
59-
cx.expr_binary(
60-
field.span,
61-
BinOpKind::Eq,
62-
convert(&field.self_expr),
63-
convert(other_expr),
64-
)
65-
}
66-
CsFold::Combine(span, expr1, expr2) => {
67-
cx.expr_binary(span, BinOpKind::And, expr1, expr2)
68-
}
69-
CsFold::Fieldless => cx.expr_bool(span, base),
70-
},
71-
);
72-
BlockOrExpr::new_expr(expr)
73-
}
74-
7520
let structural_trait_def = TraitDef {
7621
span,
7722
path: path_std!(marker::StructuralPartialEq),
@@ -97,7 +42,9 @@ pub(crate) fn expand_deriving_partial_eq(
9742
ret_ty: Path(path_local!(bool)),
9843
attributes: thin_vec![cx.attr_word(sym::inline, span)],
9944
fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
100-
combine_substructure: combine_substructure(Box::new(|a, b, c| cs_eq(a, b, c))),
45+
combine_substructure: combine_substructure(Box::new(|a, b, c| {
46+
BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
47+
})),
10148
}];
10249

10350
let trait_def = TraitDef {
@@ -113,3 +60,138 @@ pub(crate) fn expand_deriving_partial_eq(
11360
};
11461
trait_def.expand(cx, mitem, item, push)
11562
}
63+
64+
/// Generates the equality expression for a struct or enum variant when deriving `PartialEq`.
65+
///
66+
/// This function constructs an expression that compares all fields of a struct or enum variant for equality.
67+
/// It groups scalar and compound types separately, combining their comparisons efficiently:
68+
/// - If there are no fields, returns `true` (fieldless types are always equal to themselves).
69+
/// - Scalar fields are compared first for efficiency, then compound fields.
70+
/// - If only one group is non-empty, returns its comparison directly.
71+
/// - Otherwise, returns a conjunction (logical AND) of both groups' comparisons.
72+
///
73+
/// For enums with discriminants, compares the discriminant first, then the rest of the fields.
74+
///
75+
/// # Panics
76+
///
77+
/// If called on static or all-fieldless enums/structs, which should not occur during derive expansion.
78+
fn get_substructure_equality_expr(
79+
cx: &ExtCtxt<'_>,
80+
span: Span,
81+
substructure: &Substructure<'_>,
82+
) -> P<Expr> {
83+
/// Combines the accumulated comparison expression with the next field's comparison using logical AND.
84+
///
85+
/// If this is the first field, initializes the accumulator. Otherwise, chains with logical AND.
86+
fn combine(cx: &ExtCtxt<'_>, span: Span, acc: &mut Option<P<Expr>>, elem: P<Expr>) {
87+
let Some(lhs) = acc.take() else {
88+
*acc = Some(elem);
89+
return;
90+
};
91+
*acc = Some(cx.expr_binary(span, BinOpKind::And, lhs, elem));
92+
}
93+
94+
use SubstructureFields::*;
95+
match substructure.fields {
96+
EnumMatching(.., fields) | Struct(.., fields) => {
97+
if fields.is_empty() {
98+
// Fieldless structs or enum variants are always equal to themselves.
99+
return cx.expr_bool(span, true);
100+
}
101+
102+
let mut scalar_ty_cmp = None;
103+
let mut compound_ty_cmp = None;
104+
// Compare scalar and compound types separately for efficiency.
105+
for field in fields {
106+
let is_scalar = field.is_scalar;
107+
let field_span = field.span;
108+
let rhs = get_field_equality_expr(cx, field);
109+
110+
if is_scalar {
111+
// Combine scalar field comparisons first (cheaper to evaluate).
112+
combine(cx, field_span, &mut scalar_ty_cmp, rhs);
113+
continue;
114+
}
115+
// Combine compound (non-scalar) field comparisons.
116+
combine(cx, field_span, &mut compound_ty_cmp, rhs);
117+
}
118+
119+
// If only one group (scalar or compound) has fields, return its comparison directly.
120+
if scalar_ty_cmp.is_some() ^ compound_ty_cmp.is_some() {
121+
return scalar_ty_cmp.or(compound_ty_cmp).unwrap();
122+
}
123+
124+
// If both groups are non-empty, require all fields to be equal.
125+
// Scalar fields are compared first for performance.
126+
return cx.expr_binary(
127+
span,
128+
BinOpKind::And,
129+
scalar_ty_cmp.unwrap(),
130+
compound_ty_cmp.unwrap(),
131+
);
132+
}
133+
EnumDiscr(disc, match_expr) => {
134+
let lhs = get_field_equality_expr(cx, disc);
135+
let Some(match_expr) = match_expr else {
136+
return lhs;
137+
};
138+
// Compare the discriminant first (cheaper), then the rest of the fields.
139+
return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
140+
}
141+
StaticEnum(..) => cx.dcx().span_bug(
142+
span,
143+
"unexpected static enum encountered during `derive(PartialEq)` expansion",
144+
),
145+
StaticStruct(..) => cx.dcx().span_bug(
146+
span,
147+
"unexpected static struct encountered during `derive(PartialEq)` expansion",
148+
),
149+
AllFieldlessEnum(..) => cx.dcx().span_bug(
150+
span,
151+
"unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
152+
),
153+
}
154+
}
155+
156+
/// Generates an equality comparison expression for a single struct or enum field.
157+
///
158+
/// This function produces an AST expression that compares the `self` and `other` values for a field using `==`.
159+
/// It removes any leading references from both sides for readability.
160+
/// If the field is a block expression, it is wrapped in parentheses to ensure valid syntax.
161+
///
162+
/// # Panics
163+
///
164+
/// Panics if there are not exactly two arguments to compare (should be `self` and `other`).
165+
fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> {
166+
let [rhs] = &field.other_selflike_exprs[..] else {
167+
cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
168+
};
169+
170+
cx.expr_binary(
171+
field.span,
172+
BinOpKind::Eq,
173+
wrap_block_expr(cx, peel_refs(&field.self_expr)),
174+
wrap_block_expr(cx, peel_refs(rhs)),
175+
)
176+
}
177+
178+
/// Removes all leading immutable references from an expression.
179+
///
180+
/// This is used to strip away any number of leading `&` from an expression (e.g., `&&&T` becomes `T`).
181+
/// Only removes immutable references; mutable references are preserved.
182+
fn peel_refs(mut expr: &P<Expr>) -> P<Expr> {
183+
while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
184+
expr = &inner;
185+
}
186+
expr.clone()
187+
}
188+
189+
/// Wraps a block expression in parentheses to ensure valid AST in macro expansion output.
190+
///
191+
/// If the given expression is a block, it is wrapped in parentheses; otherwise, it is returned unchanged.
192+
fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> {
193+
if matches!(&expr.kind, ExprKind::Block(..)) {
194+
return cx.expr_paren(expr.span, expr);
195+
}
196+
expr
197+
}

compiler/rustc_builtin_macros/src/deriving/generic/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ pub(crate) struct FieldInfo {
284284
/// The expressions corresponding to references to this field in
285285
/// the other selflike arguments.
286286
pub other_selflike_exprs: Vec<P<Expr>>,
287+
pub is_scalar: bool,
287288
}
288289

289290
#[derive(Copy, Clone)]
@@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> {
12201221

12211222
let self_expr = discr_exprs.remove(0);
12221223
let other_selflike_exprs = discr_exprs;
1223-
let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs };
1224+
let discr_field =
1225+
FieldInfo { span, name: None, self_expr, other_selflike_exprs, is_scalar: true };
12241226

12251227
let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args)
12261228
.map(|(&ident, selflike_arg)| {
@@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> {
15331535
name: struct_field.ident,
15341536
self_expr,
15351537
other_selflike_exprs,
1538+
is_scalar: struct_field.ty.peel_refs().kind.is_scalar(),
15361539
}
15371540
})
15381541
.collect()

0 commit comments

Comments
 (0)