Skip to content

Commit 6e72625

Browse files
committed
Check equality of scalar types first.
1 parent ebe9b00 commit 6e72625

File tree

5 files changed

+136
-12
lines changed

5 files changed

+136
-12
lines changed

compiler/rustc_ast/src/ast.rs

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2195,6 +2195,34 @@ impl FloatTy {
21952195
}
21962196
}
21972197

2198+
impl<'a> TryFrom<&'a str> for FloatTy {
2199+
type Error = ();
2200+
2201+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2202+
Ok(match value {
2203+
"f16" => Self::F16,
2204+
"f32" => Self::F32,
2205+
"f64" => Self::F64,
2206+
"f128" => Self::F128,
2207+
_ => return Err(()),
2208+
})
2209+
}
2210+
}
2211+
2212+
impl TryFrom<Symbol> for FloatTy {
2213+
type Error = ();
2214+
2215+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2216+
Ok(match value {
2217+
sym::f16 => Self::F16,
2218+
sym::f32 => Self::F32,
2219+
sym::f64 => Self::F64,
2220+
sym::f128 => Self::F128,
2221+
_ => return Err(()),
2222+
})
2223+
}
2224+
}
2225+
21982226
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
21992227
#[derive(Encodable, Decodable, HashStable_Generic)]
22002228
pub enum IntTy {
@@ -2230,6 +2258,38 @@ impl IntTy {
22302258
}
22312259
}
22322260

2261+
impl<'a> TryFrom<&'a str> for IntTy {
2262+
type Error = ();
2263+
2264+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2265+
Ok(match value {
2266+
"isize" => Self::Isize,
2267+
"i8" => Self::I8,
2268+
"i16" => Self::I16,
2269+
"i32" => Self::I32,
2270+
"i64" => Self::I64,
2271+
"i128" => Self::I128,
2272+
_ => return Err(()),
2273+
})
2274+
}
2275+
}
2276+
2277+
impl TryFrom<Symbol> for IntTy {
2278+
type Error = ();
2279+
2280+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2281+
Ok(match value {
2282+
sym::isize => Self::Isize,
2283+
sym::i8 => Self::I8,
2284+
sym::i16 => Self::I16,
2285+
sym::i32 => Self::I32,
2286+
sym::i64 => Self::I64,
2287+
sym::i128 => Self::I128,
2288+
_ => return Err(()),
2289+
})
2290+
}
2291+
}
2292+
22332293
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Copy, Debug)]
22342294
#[derive(Encodable, Decodable, HashStable_Generic)]
22352295
pub enum UintTy {
@@ -2265,6 +2325,38 @@ impl UintTy {
22652325
}
22662326
}
22672327

2328+
impl<'a> TryFrom<&'a str> for UintTy {
2329+
type Error = ();
2330+
2331+
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
2332+
Ok(match value {
2333+
"usize" => Self::Usize,
2334+
"u8" => Self::U8,
2335+
"u16" => Self::U16,
2336+
"u32" => Self::U32,
2337+
"u64" => Self::U64,
2338+
"u128" => Self::U128,
2339+
_ => return Err(()),
2340+
})
2341+
}
2342+
}
2343+
2344+
impl TryFrom<Symbol> for UintTy {
2345+
type Error = ();
2346+
2347+
fn try_from(value: Symbol) -> Result<Self, Self::Error> {
2348+
Ok(match value {
2349+
sym::usize => Self::Usize,
2350+
sym::u8 => Self::U8,
2351+
sym::u16 => Self::U16,
2352+
sym::u32 => Self::U32,
2353+
sym::u64 => Self::U64,
2354+
sym::u128 => Self::U128,
2355+
_ => return Err(()),
2356+
})
2357+
}
2358+
}
2359+
22682360
/// A constraint on an associated item.
22692361
///
22702362
/// ### Examples
@@ -2452,6 +2544,21 @@ impl TyKind {
24522544
None
24532545
}
24542546
}
2547+
2548+
pub fn is_scalar(&self) -> bool {
2549+
let Some(ty_kind) = self.is_simple_path() else {
2550+
match self {
2551+
TyKind::Tup(tys) => return tys.is_empty(), // unit type
2552+
_ => return false,
2553+
}
2554+
};
2555+
2556+
ty_kind == sym::bool
2557+
|| ty_kind == sym::char
2558+
|| IntTy::try_from(ty_kind).is_ok()
2559+
|| UintTy::try_from(ty_kind).is_ok()
2560+
|| FloatTy::try_from(ty_kind).is_ok()
2561+
}
24552562
}
24562563

24572564
/// A pattern type pattern.

compiler/rustc_builtin_macros/src/deriving/cmp/ord.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ pub(crate) fn cs_cmp(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) ->
6666
let args = thin_vec![field.self_expr.clone(), other_expr.clone()];
6767
cx.expr_call_global(field.span, cmp_path.clone(), args)
6868
}
69-
CsFold::Combine(span, expr1, expr2) => {
69+
CsFold::Combine { field_info, field_expr: expr2, other_expr: expr1 } => {
70+
let span = field_info.span;
7071
let eq_arm = cx.arm(span, cx.pat_path(span, equal_path.clone()), expr1);
7172
let neq_arm =
7273
cx.arm(span, cx.pat_ident(span, test_id), cx.expr_ident(span, test_id));

compiler/rustc_builtin_macros/src/deriving/cmp/partial_eq.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub(crate) fn expand_deriving_partial_eq(
1818
) {
1919
fn cs_eq(cx: &ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr {
2020
let base = true;
21+
2122
let expr = cs_fold(
2223
true, // use foldl
2324
cx,
@@ -63,8 +64,12 @@ pub(crate) fn expand_deriving_partial_eq(
6364
convert(other_expr),
6465
)
6566
}
66-
CsFold::Combine(span, expr1, expr2) => {
67-
cx.expr_binary(span, BinOpKind::And, expr1, expr2)
67+
CsFold::Combine { field_info, mut field_expr, mut other_expr } => {
68+
// Comparison of numeric primitives is cheap, therefore compare them first.
69+
if field_info.is_scalar {
70+
(field_expr, other_expr) = (other_expr, field_expr);
71+
}
72+
cx.expr_binary(field_info.span, BinOpKind::And, other_expr, field_expr)
6873
}
6974
CsFold::Fieldless => cx.expr_bool(span, base),
7075
},

compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ fn cs_partial_cmp(
100100
let args = thin_vec![field.self_expr.clone(), other_expr.clone()];
101101
cx.expr_call_global(field.span, partial_cmp_path.clone(), args)
102102
}
103-
CsFold::Combine(span, mut expr1, expr2) => {
103+
CsFold::Combine { field_info, field_expr: expr2, other_expr: mut expr1 } => {
104104
// When the item is an enum, this expands to
105105
// ```
106106
// match (expr2) {
@@ -130,6 +130,7 @@ fn cs_partial_cmp(
130130
// }
131131
// ```
132132
// Reference: https://github.com/rust-lang/rust/pull/103659#issuecomment-1328126354
133+
let span = field_info.span;
133134

134135
if !discr_then_data
135136
&& let ExprKind::Match(_, arms, _) = &mut expr1.kind

compiler/rustc_builtin_macros/src/deriving/generic/mod.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ pub(crate) struct FieldInfo {
284284
/// The expressions corresponding to references to this field in
285285
/// the other selflike arguments.
286286
pub other_selflike_exprs: Vec<P<Expr>>,
287+
pub is_scalar: bool,
287288
}
288289

289290
#[derive(Copy, Clone)]
@@ -1220,7 +1221,8 @@ impl<'a> MethodDef<'a> {
12201221

12211222
let self_expr = discr_exprs.remove(0);
12221223
let other_selflike_exprs = discr_exprs;
1223-
let discr_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs };
1224+
let discr_field =
1225+
FieldInfo { span, name: None, self_expr, other_selflike_exprs, is_scalar: true };
12241226

12251227
let discr_let_stmts: ThinVec<_> = iter::zip(&discr_idents, &selflike_args)
12261228
.map(|(&ident, selflike_arg)| {
@@ -1533,6 +1535,7 @@ impl<'a> TraitDef<'a> {
15331535
name: struct_field.ident,
15341536
self_expr,
15351537
other_selflike_exprs,
1538+
is_scalar: struct_field.ty.peel_refs().kind.is_scalar(),
15361539
}
15371540
})
15381541
.collect()
@@ -1607,7 +1610,11 @@ pub(crate) enum CsFold<'a> {
16071610

16081611
/// The combination of two field expressions. E.g. for `PartialEq::eq` this
16091612
/// is something like `<field1 equality> && <field2 equality>`.
1610-
Combine(Span, P<Expr>, P<Expr>),
1613+
Combine {
1614+
field_info: &'a FieldInfo,
1615+
field_expr: P<Expr>,
1616+
other_expr: P<Expr>,
1617+
},
16111618

16121619
// The fallback case for a struct or enum variant with no fields.
16131620
Fieldless,
@@ -1641,7 +1648,7 @@ where
16411648

16421649
let op = |old, field: &FieldInfo| {
16431650
let new = f(cx, CsFold::Single(field));
1644-
f(cx, CsFold::Combine(field.span, old, new))
1651+
f(cx, CsFold::Combine { field_info: field, field_expr: new, other_expr: old })
16451652
};
16461653

16471654
if use_foldl {
@@ -1653,11 +1660,14 @@ where
16531660
EnumDiscr(discr_field, match_expr) => {
16541661
let discr_check_expr = f(cx, CsFold::Single(discr_field));
16551662
if let Some(match_expr) = match_expr {
1656-
if use_foldl {
1657-
f(cx, CsFold::Combine(trait_span, discr_check_expr, match_expr.clone()))
1658-
} else {
1659-
f(cx, CsFold::Combine(trait_span, match_expr.clone(), discr_check_expr))
1660-
}
1663+
f(
1664+
cx,
1665+
CsFold::Combine {
1666+
field_info: discr_field,
1667+
field_expr: discr_check_expr,
1668+
other_expr: match_expr.clone(),
1669+
},
1670+
)
16611671
} else {
16621672
discr_check_expr
16631673
}

0 commit comments

Comments
 (0)