diff --git a/compiler/rustc_abi/src/layout.rs b/compiler/rustc_abi/src/layout.rs
index 6e1299944a09d..8b94e88a31a7a 100644
--- a/compiler/rustc_abi/src/layout.rs
+++ b/compiler/rustc_abi/src/layout.rs
@@ -196,7 +196,7 @@ impl LayoutCalculator {
     pub fn layout_of_struct_or_enum<
         'a,
         FieldIdx: Idx,
-        VariantIdx: Idx,
+        VariantIdx: Idx + PartialOrd,
         F: Deref> + fmt::Debug + Copy,
     >(
         &self,
@@ -468,7 +468,7 @@ impl LayoutCalculator {
     fn layout_of_enum<
         'a,
         FieldIdx: Idx,
-        VariantIdx: Idx,
+        VariantIdx: Idx + PartialOrd,
         F: Deref> + fmt::Debug + Copy,
     >(
         &self,
@@ -528,8 +528,16 @@ impl LayoutCalculator {
 
         let niche_variants = all_indices.clone().find(|v| needs_disc(*v)).unwrap()
             ..=all_indices.rev().find(|v| needs_disc(*v)).unwrap();
-        let count =
-            (niche_variants.end().index() as u128 - niche_variants.start().index() as u128) + 1;
+        let count = {
+            let niche_variants_len = (niche_variants.end().index() as u128
+                - niche_variants.start().index() as u128)
+                + 1;
+            if niche_variants.contains(&largest_variant_index) {
+                niche_variants_len - 1
+            } else {
+                niche_variants_len
+            }
+        };
 
         // Use the largest niche in the largest variant.
         let niche = variant_layouts[largest_variant_index].largest_niche?;
diff --git a/compiler/rustc_abi/src/lib.rs b/compiler/rustc_abi/src/lib.rs
index 84d756b6d517c..246936c4b7252 100644
--- a/compiler/rustc_abi/src/lib.rs
+++ b/compiler/rustc_abi/src/lib.rs
@@ -1498,15 +1498,59 @@ pub enum TagEncoding {
     Direct,
 
     /// Niche (values invalid for a type) encoding the discriminant:
-    /// Discriminant and variant index coincide.
+    /// Discriminant and variant index don't always coincide.
+    ///
     /// The variant `untagged_variant` contains a niche at an arbitrary
     /// offset (field `tag_field` of the enum), which for a variant with
-    /// discriminant `d` is set to
-    /// `(d - niche_variants.start).wrapping_add(niche_start)`.
+    /// discriminant `d` is set to `d.wrapping_add(niche_start)`.
     ///
-    /// For example, `Option<(usize, &T)>` is represented such that
-    /// `None` has a null pointer for the second tuple field, and
-    /// `Some` is the identity function (with a non-null reference).
+    /// As an optimization, when the `untagged_variant` lies inside the `niche_variants` range,
+    /// we allocate discriminant values starting from the variant right after the
+    /// `untagged_variant`, wrapping around, so the `untagged_variant` is not assigned an
+    /// unneeded discriminant. The motivation for this is issue #117238.
+    /// For example,
+    /// ```
+    /// enum SomeEnum {
+    ///     A,       // has a discriminant of 1
+    ///     B,       // has a discriminant of 2
+    ///     C(bool), // untagged_variant, no discriminant
+    ///     D,       // has a discriminant of 0
+    /// }
+    /// ```
+    /// The algorithm is as follows:
+    /// ```rust,ignore (pseudo-code)
+    /// // We ignore leading and trailing variants that don't need discriminants.
+    /// adjusted_len = niche_variants.end - niche_variants.start + 1
+    /// adjusted_index = variant_index - niche_variants.start
+    /// d = if niche_variants.contains(untagged_variant) {
+    ///     adjusted_untagged_index = untagged_variant - niche_variants.start
+    ///     (adjusted_index + adjusted_len - adjusted_untagged_index) % adjusted_len - 1
+    /// } else {
+    ///     adjusted_index
+    /// }
+    /// tag_value = d.wrapping_add(niche_start)
+    /// ```
+    /// To load the variant index from a tag value:
+    /// ```rust,ignore (pseudo-code)
+    /// adjusted_len = niche_variants.end - niche_variants.start + 1
+    /// d = tag_value.wrapping_sub(niche_start)
+    /// variant_index = if niche_variants.contains(untagged_variant) {
+    ///     if d < adjusted_len - 1 {
+    ///         adjusted_untagged_index = untagged_variant - niche_variants.start
+    ///         (d + 1 + adjusted_untagged_index) % adjusted_len + niche_variants.start
+    ///     } else {
+    ///         // When the discriminant is at least the number of variants that have a
+    ///         // discriminant, we know it represents the untagged_variant.
+    ///         untagged_variant
+    ///     }
+    /// } else {
+    ///     if d < adjusted_len {
+    ///         d + niche_variants.start
+    ///     } else {
+    ///         untagged_variant
+    ///     }
+    /// }
+    /// ```
     Niche {
         untagged_variant: VariantIdx,
         niche_variants: RangeInclusive<VariantIdx>,
diff --git a/compiler/rustc_codegen_cranelift/src/discriminant.rs b/compiler/rustc_codegen_cranelift/src/discriminant.rs
index d462dcd63a925..74dd94f14bac5 100644
--- a/compiler/rustc_codegen_cranelift/src/discriminant.rs
+++ b/compiler/rustc_codegen_cranelift/src/discriminant.rs
@@ -52,10 +52,20 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
             variants: _,
         } => {
             if variant_index != untagged_variant {
+                let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;
+                let adj_idx = variant_index.index() - niche_variants.start().index();
+
                 let niche = place.place_field(fx, FieldIdx::new(tag_field));
                 let niche_type = fx.clif_type(niche.layout().ty).unwrap();
-                let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
-                let niche_value = (niche_value as u128).wrapping_add(niche_start);
+
+                let discr = if niche_variants.contains(&untagged_variant) {
+                    let adj_untagged_idx =
+                        untagged_variant.index() - niche_variants.start().index();
+                    (adj_idx + discr_len - adj_untagged_idx) % discr_len - 1
+                } else {
+                    adj_idx
+                };
+                let niche_value = (discr as u128).wrapping_add(niche_start);
                 let niche_value = match niche_type {
                     types::I128 => {
                         let lsb = fx.bcx.ins().iconst(types::I64, niche_value as u64 as i64);
@@ -131,72 +141,91 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
             dest.write_cvalue(fx, res);
         }
         TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
-            let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
-
-            // We have a subrange `niche_start..=niche_end` inside `range`.
-            // If the value of the tag is inside this subrange, it's a
-            // "niche value", an increment of the discriminant. Otherwise it
-            // indicates the untagged variant.
-            // A general algorithm to extract the discriminant from the tag
-            // is:
-            //     relative_tag = tag - niche_start
-            //     is_niche = relative_tag <= (ule) relative_max
-            //     discr = if is_niche {
-            //         cast(relative_tag) + niche_variants.start()
-            //     } else {
-            //         untagged_variant
-            //     }
-            // However, we will likely be able to emit simpler code.
-
-            let (is_niche, tagged_discr, delta) = if relative_max == 0 {
-                // Best case scenario: only one tagged variant. This will
-                // likely become just a comparison and a jump.
- // The algorithm is: - // is_niche = tag == niche_start - // discr = if is_niche { - // niche_start - // } else { - // untagged_variant - // } + // See the algorithm explanation in the definition of `TagEncoding::Niche`. + let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1; + + let niche_start_value = match fx.bcx.func.dfg.value_type(tag) { + types::I128 => { + let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64); + let msb = fx.bcx.ins().iconst(types::I64, (niche_start >> 64) as u64 as i64); + fx.bcx.ins().iconcat(lsb, msb) + } + ty => fx.bcx.ins().iconst(ty, niche_start as i64), + }; + + let (is_niche, tagged_discr) = if discr_len == 1 { + // Special case where we only have a single tagged variant. + // The untagged variant can't be contained in niche_variant's range in this case. + // Thus the discriminant of the only tagged variant is 0 and its variant index + // is the start of niche_variants. let is_niche = codegen_icmp_imm(fx, IntCC::Equal, tag, niche_start as i128); let tagged_discr = fx.bcx.ins().iconst(cast_to, niche_variants.start().as_u32() as i64); - (is_niche, tagged_discr, 0) + (is_niche, tagged_discr) } else { - // The special cases don't apply, so we'll have to go with - // the general algorithm. - let niche_start = match fx.bcx.func.dfg.value_type(tag) { - types::I128 => { - let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64); - let msb = - fx.bcx.ins().iconst(types::I64, (niche_start >> 64) as u64 as i64); - fx.bcx.ins().iconcat(lsb, msb) - } - ty => fx.bcx.ins().iconst(ty, niche_start as i64), - }; - let relative_discr = fx.bcx.ins().isub(tag, niche_start); - let cast_tag = clif_intcast(fx, relative_discr, cast_to, false); - let is_niche = crate::common::codegen_icmp_imm( - fx, - IntCC::UnsignedLessThanOrEqual, - relative_discr, - i128::from(relative_max), - ); - (is_niche, cast_tag, niche_variants.start().as_u32() as u128) - }; + // General case. 
+ let discr = fx.bcx.ins().isub(tag, niche_start_value); + let tagged_discr = clif_intcast(fx, discr, cast_to, false); + if niche_variants.contains(&untagged_variant) { + let is_niche = crate::common::codegen_icmp_imm( + fx, + IntCC::UnsignedLessThan, + discr, + (discr_len - 1) as i128, + ); + let adj_untagged_idx = + untagged_variant.index() - niche_variants.start().index(); + let untagged_delta = 1 + adj_untagged_idx; + let untagged_delta = match cast_to { + types::I128 => { + let lsb = fx.bcx.ins().iconst(types::I64, untagged_delta as i64); + let msb = fx.bcx.ins().iconst(types::I64, 0); + fx.bcx.ins().iconcat(lsb, msb) + } + ty => fx.bcx.ins().iconst(ty, untagged_delta as i64), + }; + let tagged_discr = fx.bcx.ins().iadd(tagged_discr, untagged_delta); - let tagged_discr = if delta == 0 { - tagged_discr - } else { - let delta = match cast_to { - types::I128 => { - let lsb = fx.bcx.ins().iconst(types::I64, delta as u64 as i64); - let msb = fx.bcx.ins().iconst(types::I64, (delta >> 64) as u64 as i64); - fx.bcx.ins().iconcat(lsb, msb) - } - ty => fx.bcx.ins().iconst(ty, delta as i64), - }; - fx.bcx.ins().iadd(tagged_discr, delta) + let discr_len = match cast_to { + types::I128 => { + let lsb = fx.bcx.ins().iconst(types::I64, discr_len as i64); + let msb = fx.bcx.ins().iconst(types::I64, 0); + fx.bcx.ins().iconcat(lsb, msb) + } + ty => fx.bcx.ins().iconst(ty, discr_len as i64), + }; + let tagged_discr = fx.bcx.ins().urem(tagged_discr, discr_len); + + let niche_variants_start = niche_variants.start().index(); + let niche_variants_start = match cast_to { + types::I128 => { + let lsb = fx.bcx.ins().iconst(types::I64, niche_variants_start as i64); + let msb = fx.bcx.ins().iconst(types::I64, 0); + fx.bcx.ins().iconcat(lsb, msb) + } + ty => fx.bcx.ins().iconst(ty, niche_variants_start as i64), + }; + let tagged_discr = fx.bcx.ins().iadd(tagged_discr, niche_variants_start); + (is_niche, tagged_discr) + } else { + let is_niche = crate::common::codegen_icmp_imm( + fx, + IntCC::UnsignedLessThan, + discr, + (discr_len - 1) as i128, + ); + let niche_variants_start = niche_variants.start().index(); + let niche_variants_start = match cast_to { + types::I128 => { + let lsb = fx.bcx.ins().iconst(types::I64, niche_variants_start as i64); + let msb = fx.bcx.ins().iconst(types::I64, 0); + fx.bcx.ins().iconcat(lsb, msb) + } + ty => fx.bcx.ins().iconst(ty, niche_variants_start as i64), + }; + let tagged_discr = fx.bcx.ins().iadd(tagged_discr, niche_variants_start); + (is_niche, tagged_discr) + } }; let untagged_variant = if cast_to == types::I128 { diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs index fe1634146ff83..2b0f57692444e 100644 --- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs +++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs @@ -391,9 +391,21 @@ fn compute_discriminant_value<'ll, 'tcx>( DiscrResult::Range(min, max) } else { - let value = (variant_index.as_u32() as u128) - .wrapping_sub(niche_variants.start().as_u32() as u128) - .wrapping_add(niche_start); + let discr_len = niche_variants.end().as_u32() as u128 + - niche_variants.start().as_u32() as u128 + + 1; + // FIXME: Why do we even return discriminant for absent variants? 
+ let adj_idx = (variant_index.as_u32() as u128) + .wrapping_sub(niche_variants.start().as_u32() as u128); + + let discr = if niche_variants.contains(&untagged_variant) { + let adj_untagged_idx = + (untagged_variant.as_u32() - niche_variants.start().as_u32()) as u128; + (adj_idx + discr_len - adj_untagged_idx) % discr_len - 1 + } else { + adj_idx + }; + let value = discr.wrapping_add(niche_start); let value = tag.size(cx).truncate(value); DiscrResult::Value(value) } diff --git a/compiler/rustc_codegen_ssa/src/mir/place.rs b/compiler/rustc_codegen_ssa/src/mir/place.rs index a7d5541481a6c..4959f3a90e47f 100644 --- a/compiler/rustc_codegen_ssa/src/mir/place.rs +++ b/compiler/rustc_codegen_ssa/src/mir/place.rs @@ -287,54 +287,53 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> { _ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)), }; - let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32(); - - // We have a subrange `niche_start..=niche_end` inside `range`. - // If the value of the tag is inside this subrange, it's a - // "niche value", an increment of the discriminant. Otherwise it - // indicates the untagged variant. - // A general algorithm to extract the discriminant from the tag - // is: - // relative_tag = tag - niche_start - // is_niche = relative_tag <= (ule) relative_max - // discr = if is_niche { - // cast(relative_tag) + niche_variants.start() - // } else { - // untagged_variant - // } - // However, we will likely be able to emit simpler code. - let (is_niche, tagged_discr, delta) = if relative_max == 0 { - // Best case scenario: only one tagged variant. This will - // likely become just a comparison and a jump. - // The algorithm is: - // is_niche = tag == niche_start - // discr = if is_niche { - // niche_start - // } else { - // untagged_variant - // } - let niche_start = bx.cx().const_uint_big(tag_llty, niche_start); + // See the algorithm explanation in the definition of `TagEncoding::Niche`. + let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1; + let niche_start = bx.cx().const_uint_big(tag_llty, niche_start); + let (is_niche, tagged_discr) = if discr_len == 1 { + // Special case where we only have a single tagged variant. + // The untagged variant can't be contained in niche_variant's range in this case. + // Thus the discriminant of the only tagged variant is 0 and its variant index + // is the start of niche_variants. let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start); let tagged_discr = bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64); - (is_niche, tagged_discr, 0) + (is_niche, tagged_discr) } else { - // The special cases don't apply, so we'll have to go with - // the general algorithm. - let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start)); - let cast_tag = bx.intcast(relative_discr, cast_to, false); - let is_niche = bx.icmp( - IntPredicate::IntULE, - relative_discr, - bx.cx().const_uint(tag_llty, relative_max as u64), - ); - (is_niche, cast_tag, niche_variants.start().as_u32() as u128) - }; - - let tagged_discr = if delta == 0 { - tagged_discr - } else { - bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta)) + // General case. 
+ let discr = bx.sub(tag, niche_start); + let tagged_discr = bx.intcast(discr, cast_to, false); + if niche_variants.contains(&untagged_variant) { + let is_niche = bx.icmp( + IntPredicate::IntULT, + discr, + bx.cx().const_uint(tag_llty, (discr_len - 1) as u64), + ); + let adj_untagged_idx = + untagged_variant.index() - niche_variants.start().index(); + let tagged_discr = bx.add( + tagged_discr, + bx.cx().const_uint_big(cast_to, (1 + adj_untagged_idx) as u128), + ); + let tagged_discr = bx + .urem(tagged_discr, bx.cx().const_uint_big(cast_to, discr_len as u128)); + let tagged_discr = bx.add( + tagged_discr, + bx.cx().const_uint_big(cast_to, niche_variants.start().index() as u128), + ); + (is_niche, tagged_discr) + } else { + let is_niche = bx.icmp( + IntPredicate::IntULT, + discr, + bx.cx().const_uint(tag_llty, discr_len as u64), + ); + let tagged_discr = bx.add( + tagged_discr, + bx.cx().const_uint_big(cast_to, niche_variants.start().index() as u128), + ); + (is_niche, tagged_discr) + } }; let discr = bx.select( @@ -384,10 +383,20 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> { .. } => { if variant_index != untagged_variant { + let discr_len = + niche_variants.end().index() - niche_variants.start().index() + 1; + let adj_idx = variant_index.index() - niche_variants.start().index(); + let niche = self.project_field(bx, tag_field); let niche_llty = bx.cx().immediate_backend_type(niche.layout); - let niche_value = variant_index.as_u32() - niche_variants.start().as_u32(); - let niche_value = (niche_value as u128).wrapping_add(niche_start); + let discr = if niche_variants.contains(&untagged_variant) { + let adj_untagged_idx = + untagged_variant.index() - niche_variants.start().index(); + (adj_idx + discr_len - adj_untagged_idx) % discr_len - 1 + } else { + adj_idx + }; + let niche_value = (discr as u128).wrapping_add(niche_start); // FIXME(eddyb): check the actual primitive type here. let niche_llval = if niche_value == 0 { // HACK(eddyb): using `c_null` as it works on all types. diff --git a/compiler/rustc_const_eval/src/interpret/discriminant.rs b/compiler/rustc_const_eval/src/interpret/discriminant.rs index 81e0b1e12caf4..2cff779243db8 100644 --- a/compiler/rustc_const_eval/src/interpret/discriminant.rs +++ b/compiler/rustc_const_eval/src/interpret/discriminant.rs @@ -166,9 +166,13 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { untagged_variant } Ok(tag_bits) => { + // See the algorithm explanation in the definition of `TagEncoding::Niche`. + let discr_len = (variants_end - variants_start) + .checked_add(1) + .expect("the number of niche variants fits into u32"); + let tag_bits = tag_bits.to_bits(tag_layout.size); // We need to use machine arithmetic to get the relative variant idx: - // variant_index_relative = tag_val - niche_start_val let tag_val = ImmTy::from_uint(tag_bits, tag_layout); let niche_start_val = ImmTy::from_uint(niche_start, tag_layout); let variant_index_relative_val = @@ -176,21 +180,45 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let variant_index_relative = variant_index_relative_val.to_scalar().to_bits(tag_val.layout.size)?; // Check if this is in the range that indicates an actual discriminant. - if variant_index_relative <= u128::from(variants_end - variants_start) { - let variant_index_relative = u32::try_from(variant_index_relative) - .expect("we checked that this fits into a u32"); - // Then computing the absolute variant idx should not overflow any more. 
- let variant_index = VariantIdx::from_u32( - variants_start - .checked_add(variant_index_relative) - .expect("overflow computing absolute variant idx"), - ); - let variants = - ty.ty_adt_def().expect("tagged layout for non adt").variants(); - assert!(variant_index < variants.next_index()); - variant_index + if niche_variants.contains(&untagged_variant) { + if variant_index_relative < u128::from(discr_len - 1) { + let adj_untagged_idx = untagged_variant.as_u32() - variants_start; + let variant_index_relative = u32::try_from(variant_index_relative) + .expect("we checked that this fits into a u32"); + let variant_index_to_modulo = variant_index_relative + .checked_add(1) + .expect("overflow computing absolute variant idx") + .checked_add(adj_untagged_idx) + .expect("overflow computing absolute variant idx"); + let variant_index = VariantIdx::from_u32( + variants_start + .checked_add(variant_index_to_modulo % discr_len) + .expect("overflow computing absolute variant idx"), + ); + let variants = + ty.ty_adt_def().expect("tagged layout for non adt").variants(); + assert!(variant_index < variants.next_index()); + variant_index + } else { + untagged_variant + } } else { - untagged_variant + if variant_index_relative < u128::from(discr_len) { + let variant_index_relative = u32::try_from(variant_index_relative) + .expect("we checked that this fits into a u32"); + // Then computing the absolute variant idx should not overflow any more. + let variant_index = VariantIdx::from_u32( + variants_start + .checked_add(variant_index_relative) + .expect("overflow computing absolute variant idx"), + ); + let variants = + ty.ty_adt_def().expect("tagged layout for non adt").variants(); + assert!(variant_index < variants.next_index()); + variant_index + } else { + untagged_variant + } } } }; @@ -286,11 +314,24 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { .. 
} => { assert!(variant_index != untagged_variant); + let discr_len = (niche_variants.end().as_u32() - niche_variants.start().as_u32()) + .checked_add(1) + .expect("the number of niche variants fits into u32"); let variants_start = niche_variants.start().as_u32(); - let variant_index_relative = variant_index + let adj_idx = variant_index .as_u32() .checked_sub(variants_start) .expect("overflow computing relative variant idx"); + + let variant_index_relative = if niche_variants.contains(&untagged_variant) { + let adj_untagged_idx = untagged_variant.as_u32() - variants_start; + let adj_idx_to_modulo = adj_idx + .checked_add(discr_len - adj_untagged_idx) + .expect("overflow computing relative variant idx"); + adj_idx_to_modulo % discr_len - 1 + } else { + adj_idx + }; // We need to use machine arithmetic when taking into account `niche_start`: // tag_val = variant_index_relative + niche_start_val let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?; diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/layout.rs b/src/tools/rust-analyzer/crates/hir-ty/src/layout.rs index 4cdc0db46a15f..9fe80f186eeca 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/layout.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/layout.rs @@ -34,7 +34,7 @@ pub use self::{ mod adt; mod target; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd)] pub struct RustcEnumVariantIdx(pub usize); impl rustc_index::Idx for RustcEnumVariantIdx { diff --git a/tests/codegen/enum/enum-match.rs b/tests/codegen/enum/enum-match.rs index a24b98050d232..a73f146c73f0f 100644 --- a/tests/codegen/enum/enum-match.rs +++ b/tests/codegen/enum/enum-match.rs @@ -25,31 +25,6 @@ pub fn match0(e: Enum0) -> u8 { } } -// Case 1: Niche values are on a boundary for `range`. -pub enum Enum1 { - A(bool), - B, - C, -} - -// CHECK: define noundef{{( range\(i8 [0-9]+, [0-9]+\))?}} i8 @match1{{.*}} -// CHECK-NEXT: start: -// CHECK-NEXT: %1 = add{{( nsw)?}} i8 %0, -2 -// CHECK-NEXT: %2 = zext i8 %1 to i64 -// CHECK-NEXT: %3 = icmp ult i8 %1, 2 -// CHECK-NEXT: %4 = add nuw nsw i64 %2, 1 -// CHECK-NEXT: %_2 = select i1 %3, i64 %4, i64 0 -#[no_mangle] -pub fn match1(e: Enum1) -> u8 { - use Enum1::*; - match e { - A(b) => b as u8, - B => 13, - C => 100, - } -} - -// Case 2: Special cases don't apply. #[rustfmt::skip] pub enum X { _2=2, _3, _4, _5, _6, _7, _8, _9, _10, _11, @@ -84,6 +59,37 @@ pub enum X { _246, _247, _248, _249, _250, _251, _252, _253, } +// Case 1: Special case doesn't apply. And the untagged variant is contained in niche_variants. +pub enum Enum1 { + A, + B, + C(X), + D, + E, +} + +// CHECK: define noundef{{( range\(i8 [0-9]+, [0-9]+\))?}} i8 @match1{{.*}} +// CHECK-NEXT: start: +// CHECK-NEXT: %1 = add i8 %0, 2 +// CHECK-NEXT: %2 = zext i8 %1 to i64 +// CHECK-NEXT: %3 = icmp ult i8 %1, 4 +// CHECK-NEXT: %4 = add nuw nsw i64 %2, 3 +// CHECK-NEXT: %5 = urem i64 %4, 5 +// CHECK-NEXT: %_2 = select i1 %3, i64 %5, i64 2 +// CHECK-NEXT: switch i64 %_2, label {{.*}} [ +#[no_mangle] +pub fn match1(e: Enum1) -> u8 { + use Enum1::*; + match e { + A => 0, + B => 1, + C(c) => c as u8, + D => 254, + E => 255, + } +} + +// Case 2: Special cases don't apply. And the untagged variant is not contained in niche_variants. 
pub enum Enum2 { A(X), B, diff --git a/tests/ui/enum-discriminant/get_discr.rs b/tests/ui/enum-discriminant/get_discr.rs index d7d11274de40f..ba9b6d3866e5a 100644 --- a/tests/ui/enum-discriminant/get_discr.rs +++ b/tests/ui/enum-discriminant/get_discr.rs @@ -1,4 +1,5 @@ //@ run-pass +use std::mem; // Now that there are several variations on the code generated in // `codegen_get_discr`, let's make sure the various cases yield the correct @@ -6,6 +7,8 @@ // To get the discriminant of an E value, there are no shortcuts - we must // do the full algorithm. +// `X1` is u8 with two niche values. +#[derive(Clone, Copy)] #[repr(u8)] pub enum X1 { _1 = 1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, @@ -31,6 +34,7 @@ pub enum X2 { _1 = -1, _2 = 0, _3 = 1, } +#[derive(Clone, Copy)] #[repr(i8)] pub enum X3 { _1 = -128, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, @@ -86,6 +90,116 @@ pub fn match_e(e: E) -> u8 { } } +#[derive(Clone, Copy)] +pub enum Void {} +// Special case that there's only one tagged variant. +#[derive(Clone, Copy)] +pub enum E1 { + A(X), // the untagged variant. + B, +} + +pub const fn match_e1(e: E1) -> u8 { + use E1::*; + match e { + A(_) => 0, + B => 1, + } +} + +#[derive(Clone, Copy)] +pub enum E1WithAbsent { + V1(Void), + A(X), + V2(Void), + B, + V3(Void), +} + +pub const fn match_e1_with_absent(e: E1WithAbsent) -> u8 { + use E1WithAbsent::*; + match e { + A(_) => 0, + B => 1, + _ => unreachable!(), + } +} + +// General case. And the untagged variant is contained in the niche_variants. +#[derive(Clone, Copy)] +pub enum E2 { + A, + B(X), + C, +} + +pub const fn match_e2(e: E2) -> u8 { + use E2::*; + match e { + A => 0, + B(_) => 1, + C => 2, + } +} + +#[derive(Clone, Copy)] +pub enum E2WithAbsent { + V1(Void), + V2(Void), + A, + B(X), + C, + V3(Void), +} + +pub const fn match_e2_with_absent(e: E2WithAbsent) -> u8 { + use E2WithAbsent::*; + match e { + A => 0, + B(_) => 1, + C => 2, + _ => unreachable!(), + } +} + +// General case. And the untagged variant is not contained in the niche_variants. +#[derive(Clone, Copy)] +pub enum E3 { + A, + B, + C(X) +} + +pub const fn match_e3(e: E3) -> u8 { + use E3::*; + match e { + A => 0, + B => 1, + C(_) => 2, + } +} + +#[derive(Clone, Copy)] +pub enum E3WithAbsent { + V1(Void), + V2(Void), + A, + B, + C(X), + V3(Void), +} + +pub const fn match_e3_with_absent(e: E3WithAbsent) -> u8 { + use E3WithAbsent::*; + match e { + A => 0, + B => 1, + C(_) => 2, + _ => unreachable!(), + } +} + + fn main() { assert_eq!(match_e(E::A(X1::_1)), 0); assert_eq!(match_e(E::A(X1::_2)), 0); @@ -111,4 +225,159 @@ fn main() { assert_eq!(match_e(E::A(true)), 0); assert_eq!(match_e(E::::B), 1); assert_eq!(match_e(E::::C), 2); + + // Check `u8` primitive type as discriminant. + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e1(E1::A(X1::_1)), 0); + assert_eq!(match_e1(E1::A(X1::_2)), 0); + assert_eq!(match_e1(E1::A(X1::_254)), 0); + assert_eq!(match_e1(E1::::B), 1); + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e1_with_absent(E1WithAbsent::A(X1::_1)), 0); + assert_eq!(match_e1_with_absent(E1WithAbsent::A(X1::_2)), 0); + assert_eq!(match_e1_with_absent(E1WithAbsent::A(X1::_254)), 0); + assert_eq!(match_e1_with_absent(E1WithAbsent::::B), 1); + + // Check `i8` primitive type as discriminant. 
+ assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e1(E1::A(X3::_1)), 0); + assert_eq!(match_e1(E1::A(X3::_2)), 0); + assert_eq!(match_e1(E1::A(X3::_254)), 0); + assert_eq!(match_e1(E1::::B), 1); + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e1_with_absent(E1WithAbsent::A(X3::_1)), 0); + assert_eq!(match_e1_with_absent(E1WithAbsent::A(X3::_2)), 0); + assert_eq!(match_e1_with_absent(E1WithAbsent::A(X3::_254)), 0); + assert_eq!(match_e1_with_absent(E1WithAbsent::::B), 1); + + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e2(E2::::A), 0); + assert_eq!(match_e2(E2::B(X1::_1)), 1); + assert_eq!(match_e2(E2::B(X1::_2)), 1); + assert_eq!(match_e2(E2::B(X1::_254)), 1); + assert_eq!(match_e2(E2::::C), 2); + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e2_with_absent(E2WithAbsent::::A), 0); + assert_eq!(match_e2_with_absent(E2WithAbsent::B(X1::_1)), 1); + assert_eq!(match_e2_with_absent(E2WithAbsent::B(X1::_2)), 1); + assert_eq!(match_e2_with_absent(E2WithAbsent::B(X1::_254)), 1); + assert_eq!(match_e2_with_absent(E2WithAbsent::::C), 2); + + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e2(E2::::A), 0); + assert_eq!(match_e2(E2::B(X3::_1)), 1); + assert_eq!(match_e2(E2::B(X3::_2)), 1); + assert_eq!(match_e2(E2::B(X3::_254)), 1); + assert_eq!(match_e2(E2::::C), 2); + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e2_with_absent(E2WithAbsent::::A), 0); + assert_eq!(match_e2_with_absent(E2WithAbsent::B(X3::_1)), 1); + assert_eq!(match_e2_with_absent(E2WithAbsent::B(X3::_2)), 1); + assert_eq!(match_e2_with_absent(E2WithAbsent::B(X3::_254)), 1); + assert_eq!(match_e2_with_absent(E2WithAbsent::::C), 2); + + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e3(E3::::A), 0); + assert_eq!(match_e3(E3::::B), 1); + assert_eq!(match_e3(E3::C(X1::_1)), 2); + assert_eq!(match_e3(E3::C(X1::_2)), 2); + assert_eq!(match_e3(E3::C(X1::_254)), 2); + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e3_with_absent(E3WithAbsent::::A), 0); + assert_eq!(match_e3_with_absent(E3WithAbsent::::B), 1); + assert_eq!(match_e3_with_absent(E3WithAbsent::C(X1::_1)), 2); + assert_eq!(match_e3_with_absent(E3WithAbsent::C(X1::_2)), 2); + assert_eq!(match_e3_with_absent(E3WithAbsent::C(X1::_254)), 2); + + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e3(E3::::A), 0); + assert_eq!(match_e3(E3::::B), 1); + assert_eq!(match_e3(E3::C(X3::_1)), 2); + assert_eq!(match_e3(E3::C(X3::_2)), 2); + assert_eq!(match_e3(E3::C(X3::_254)), 2); + assert_eq!(mem::size_of::>(), 1); + assert_eq!(match_e3_with_absent(E3WithAbsent::::A), 0); + assert_eq!(match_e3_with_absent(E3WithAbsent::::B), 1); + assert_eq!(match_e3_with_absent(E3WithAbsent::C(X3::_1)), 2); + assert_eq!(match_e3_with_absent(E3WithAbsent::C(X3::_2)), 2); + assert_eq!(match_e3_with_absent(E3WithAbsent::C(X3::_254)), 2); + + // Check set_discr and get_discr work as intended in const eval too. + const _: () = { + // Check `u8` primitive type as discriminant. + assert!(mem::size_of::>() == 1); + assert!(match_e1(E1::A(X1::_1)) == 0); + assert!(match_e1(E1::A(X1::_2)) == 0); + assert!(match_e1(E1::A(X1::_254)) == 0); + assert!(match_e1(E1::::B) == 1); + assert!(mem::size_of::>() == 1); + assert!(match_e1_with_absent(E1WithAbsent::A(X1::_1)) == 0); + assert!(match_e1_with_absent(E1WithAbsent::A(X1::_2)) == 0); + assert!(match_e1_with_absent(E1WithAbsent::A(X1::_254)) == 0); + assert!(match_e1_with_absent(E1WithAbsent::::B) == 1); + + // Check `i8` primitive type as discriminant. 
+ assert!(mem::size_of::>() == 1); + assert!(match_e1(E1::A(X3::_1)) == 0); + assert!(match_e1(E1::A(X3::_2)) == 0); + assert!(match_e1(E1::A(X3::_254)) == 0); + assert!(match_e1(E1::::B) == 1); + assert!(mem::size_of::>() == 1); + assert!(match_e1_with_absent(E1WithAbsent::A(X3::_1)) == 0); + assert!(match_e1_with_absent(E1WithAbsent::A(X3::_2)) == 0); + assert!(match_e1_with_absent(E1WithAbsent::A(X3::_254)) == 0); + assert!(match_e1_with_absent(E1WithAbsent::::B) == 1); + + assert!(mem::size_of::>() == 1); + assert!(match_e2(E2::::A) == 0); + assert!(match_e2(E2::B(X1::_1)) == 1); + assert!(match_e2(E2::B(X1::_2)) == 1); + assert!(match_e2(E2::B(X1::_254)) == 1); + assert!(match_e2(E2::::C) == 2); + assert!(mem::size_of::>() == 1); + assert!(match_e2_with_absent(E2WithAbsent::::A) == 0); + assert!(match_e2_with_absent(E2WithAbsent::B(X1::_1)) == 1); + assert!(match_e2_with_absent(E2WithAbsent::B(X1::_2)) == 1); + assert!(match_e2_with_absent(E2WithAbsent::B(X1::_254)) == 1); + assert!(match_e2_with_absent(E2WithAbsent::::C) == 2); + + assert!(mem::size_of::>() == 1); + assert!(match_e2(E2::::A) == 0); + assert!(match_e2(E2::B(X3::_1)) == 1); + assert!(match_e2(E2::B(X3::_2)) == 1); + assert!(match_e2(E2::B(X3::_254)) == 1); + assert!(match_e2(E2::::C) == 2); + assert!(mem::size_of::>() == 1); + assert!(match_e2_with_absent(E2WithAbsent::::A) == 0); + assert!(match_e2_with_absent(E2WithAbsent::B(X3::_1)) == 1); + assert!(match_e2_with_absent(E2WithAbsent::B(X3::_2)) == 1); + assert!(match_e2_with_absent(E2WithAbsent::B(X3::_254)) == 1); + assert!(match_e2_with_absent(E2WithAbsent::::C) == 2); + + assert!(mem::size_of::>() == 1); + assert!(match_e3(E3::::A) == 0); + assert!(match_e3(E3::::B) == 1); + assert!(match_e3(E3::C(X1::_1)) == 2); + assert!(match_e3(E3::C(X1::_2)) == 2); + assert!(match_e3(E3::C(X1::_254)) == 2); + assert!(mem::size_of::>() == 1); + assert!(match_e3_with_absent(E3WithAbsent::::A) == 0); + assert!(match_e3_with_absent(E3WithAbsent::::B) == 1); + assert!(match_e3_with_absent(E3WithAbsent::C(X1::_1)) == 2); + assert!(match_e3_with_absent(E3WithAbsent::C(X1::_2)) == 2); + assert!(match_e3_with_absent(E3WithAbsent::C(X1::_254)) == 2); + + assert!(mem::size_of::>() == 1); + assert!(match_e3(E3::::A) == 0); + assert!(match_e3(E3::::B) == 1); + assert!(match_e3(E3::C(X3::_1)) == 2); + assert!(match_e3(E3::C(X3::_2)) == 2); + assert!(match_e3(E3::C(X3::_254)) == 2); + assert!(mem::size_of::>() == 1); + assert!(match_e3_with_absent(E3WithAbsent::::A) == 0); + assert!(match_e3_with_absent(E3WithAbsent::::B) == 1); + assert!(match_e3_with_absent(E3WithAbsent::C(X3::_1)) == 2); + assert!(match_e3_with_absent(E3WithAbsent::C(X3::_2)) == 2); + assert!(match_e3_with_absent(E3WithAbsent::C(X3::_254)) == 2); + }; }
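
As a quick sanity check of the encoding documented in the new `TagEncoding::Niche` comment above, here is a small standalone sketch (not part of the patch) that mirrors the encode/decode pseudo-code using plain `u32` variant indices. The `niche_start` value and the variant layout are made-up example numbers taken from the `SomeEnum` doc example, not a real layout computed by `LayoutCalculator`.

```rust
// Standalone sketch of the niche-tag encode/decode described in the
// `TagEncoding::Niche` documentation. Variant indices are plain `u32`s and
// `niche_start` is an arbitrary example value.
use std::ops::RangeInclusive;

fn encode_tag(
    variant_index: u32,
    untagged_variant: u32,
    niche_variants: &RangeInclusive<u32>,
    niche_start: u128,
) -> u128 {
    // Never called for the untagged variant itself (codegen checks
    // `variant_index != untagged_variant` before setting the tag).
    assert_ne!(variant_index, untagged_variant);
    let adjusted_len = u128::from(niche_variants.end() - niche_variants.start()) + 1;
    let adjusted_index = u128::from(variant_index - niche_variants.start());
    let d = if niche_variants.contains(&untagged_variant) {
        // Rotate so that discriminant counting starts right after the untagged variant.
        let adjusted_untagged_index = u128::from(untagged_variant - niche_variants.start());
        (adjusted_index + adjusted_len - adjusted_untagged_index) % adjusted_len - 1
    } else {
        adjusted_index
    };
    d.wrapping_add(niche_start)
}

fn decode_tag(
    tag_value: u128,
    untagged_variant: u32,
    niche_variants: &RangeInclusive<u32>,
    niche_start: u128,
) -> u32 {
    let adjusted_len = u128::from(niche_variants.end() - niche_variants.start()) + 1;
    let d = tag_value.wrapping_sub(niche_start);
    if niche_variants.contains(&untagged_variant) {
        if d < adjusted_len - 1 {
            // Undo the rotation applied in `encode_tag`.
            let adjusted_untagged_index = u128::from(untagged_variant - niche_variants.start());
            ((d + 1 + adjusted_untagged_index) % adjusted_len) as u32 + niche_variants.start()
        } else {
            untagged_variant
        }
    } else if d < adjusted_len {
        d as u32 + niche_variants.start()
    } else {
        untagged_variant
    }
}

fn main() {
    // Mirrors the `SomeEnum` doc example: A = 0, B = 1, C = 2 (untagged), D = 3,
    // with every variant inside `niche_variants`.
    let untagged_variant = 2;
    let niche_variants = 0..=3;
    let niche_start = 1u128; // arbitrary example value

    // Expected discriminants from the doc comment: A -> 1, B -> 2, D -> 0.
    for (variant_index, expected_d) in [(0u32, 1u128), (1, 2), (3, 0)] {
        let tag = encode_tag(variant_index, untagged_variant, &niche_variants, niche_start);
        assert_eq!(tag, expected_d.wrapping_add(niche_start));
        assert_eq!(
            decode_tag(tag, untagged_variant, &niche_variants, niche_start),
            variant_index
        );
    }
    println!("niche tag encode/decode round-trips");
}
```

Running the sketch asserts the A -> 1, B -> 2, D -> 0 assignment from the doc comment and that decoding a tag recovers the original variant index.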