From 1aa3bab98686041b3c81552bcfcd046118d24b56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Fri, 31 Jan 2025 07:10:35 +0100 Subject: [PATCH 01/12] initial group-finding pattern handler --- compiler/rustc_middle/src/arena.rs | 1 + .../src/builder/matches/match_pair.rs | 175 ++++++++++++++++-- .../src/builder/matches/util.rs | 44 +++++ 3 files changed, 204 insertions(+), 16 deletions(-) diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 750531b638e4d..e52d6fc60eba6 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -114,6 +114,7 @@ macro_rules! arena_types { [decode] specialization_graph: rustc_middle::traits::specialization_graph::Graph, [] crate_inherent_impls: rustc_middle::ty::CrateInherentImpls, [] hir_owner_nodes: rustc_hir::OwnerNodes<'tcx>, + [] thir_pats: rustc_middle::thir::Pat<'tcx>, ]); ) } diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 9d59ffc88ba23..e67371280c6f4 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -1,6 +1,9 @@ use rustc_middle::mir::*; use rustc_middle::thir::{self, *}; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; +use std::ops; +use either::Either; +use crate::builder::matches::util::Range; use crate::builder::Builder; use crate::builder::expr::as_place::{PlaceBase, PlaceBuilder}; @@ -33,6 +36,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// Used internally by [`MatchPairTree::for_pattern`]. 
fn prefix_slice_suffix<'pat>( &mut self, + top_pattern: &'pat Pat<'tcx>, match_pairs: &mut Vec>, place: &PlaceBuilder<'tcx>, prefix: &'pat [Box>], @@ -54,11 +58,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ((prefix.len() + suffix.len()).try_into().unwrap(), false) }; - match_pairs.extend(prefix.iter().enumerate().map(|(idx, subpattern)| { - let elem = - ProjectionElem::ConstantIndex { offset: idx as u64, min_length, from_end: false }; - MatchPairTree::for_pattern(place.clone_project(elem), subpattern, self) - })); + if !prefix.is_empty() { + let bounds = Range::from_start(0..prefix.len() as u64); + let subpattern = bounds.apply(prefix); + for pair in self.build_slice_branch(bounds, place, top_pattern, subpattern) { + match_pairs.push(pair); + } + } if let Some(subslice_pat) = opt_slice { let suffix_len = suffix.len() as u64; @@ -70,16 +76,153 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { match_pairs.push(MatchPairTree::for_pattern(subslice, subslice_pat, self)); } - match_pairs.extend(suffix.iter().rev().enumerate().map(|(idx, subpattern)| { - let end_offset = (idx + 1) as u64; - let elem = ProjectionElem::ConstantIndex { - offset: if exact_size { min_length - end_offset } else { end_offset }, - min_length, - from_end: !exact_size, + if !suffix.is_empty() { + let bounds = Range::from_end(0..suffix.len() as u64); + let subpattern = bounds.apply(suffix); + for pair in self.build_slice_branch(bounds, place, top_pattern, subpattern) { + match_pairs.push(pair); + } + } + } + + fn build_slice_branch<'pat, 'b>( + &'b mut self, + bounds: Range, + place: &'b PlaceBuilder<'tcx>, + top_pattern: &'pat Pat<'tcx>, + pattern: &'pat [Box>], + ) -> impl Iterator> + use<'a, 'tcx, 'pat, 'b> { + let entries = self.find_const_groups(pattern); + + entries.into_iter().map(move |entry| { + let mut build_single = |idx| { + let subpattern = &pattern[idx as usize]; + let place = place.clone_project(ProjectionElem::ConstantIndex { + offset: bounds.shift_idx(idx), + min_length: pattern.len() as u64, + 
from_end: bounds.from_end, + }); + + MatchPairTree::for_pattern(place, subpattern, self) }; - let place = place.clone_project(elem); - MatchPairTree::for_pattern(place, subpattern, self) - })); + + match entry { + Either::Right(range) if range.end - range.start > 1 => { + assert!( + (range.start..range.end) + .all(|idx| self.is_constant_pattern(&pattern[idx as usize])) + ); + + let subpattern = &pattern[range.start as usize..range.end as usize]; + let elem_ty = subpattern[0].ty; + + let valtree = self.simplify_const_pattern_slice_into_valtree(subpattern); + self.valtree_to_match_pair( + top_pattern, + valtree, + place.clone(), + elem_ty, + bounds.shift_range(range), + ) + } + Either::Right(range) => { + let tree = build_single(range.start); + assert!(self.is_constant_pattern(&pattern[range.start as usize])); + tree + } + Either::Left(idx) => build_single(idx), + } + }) + } + + fn find_const_groups(&self, pattern: &[Box>]) -> Vec>> { + let mut entries = Vec::new(); + let mut current_seq_start = None; + + for (idx, pat) in pattern.iter().enumerate() { + if self.is_constant_pattern(pat) { + if current_seq_start.is_none() { + current_seq_start = Some(idx as u64); + } else { + continue; + } + } else { + if let Some(start) = current_seq_start { + entries.push(Either::Right(start..idx as u64)); + current_seq_start = None; + } + entries.push(Either::Left(idx as u64)); + } + } + + if let Some(start) = current_seq_start { + entries.push(Either::Right(start..pattern.len() as u64)); + } + + entries + } + + fn is_constant_pattern(&self, pat: &Pat<'tcx>) -> bool { + if let PatKind::Constant { value } = pat.kind + && let Const::Ty(_, const_) = value + && let ty::ConstKind::Value(cv) = const_.kind() + && let ty::ValTree::Leaf(_) = cv.valtree + { + true + } else { + false + } + } + + fn extract_leaf(&self, pat: &Pat<'tcx>) -> ty::ValTree<'tcx> { + if let PatKind::Constant { value } = pat.kind + && let Const::Ty(_, const_) = value + && let ty::ConstKind::Value(cv) = const_.kind() + 
&& matches!(cv.valtree, ty::ValTree::Leaf(_)) + { + cv.valtree + } else { + unreachable!() + } + } + + fn simplify_const_pattern_slice_into_valtree( + &self, + subslice: &[Box>], + ) -> ty::ValTree<'tcx> { + let leaves = subslice.iter().map(|p| self.extract_leaf(p)); + let interned = self.tcx.arena.alloc_from_iter(leaves); + ty::ValTree::Branch(interned) + } + + fn valtree_to_match_pair<'pat>( + &mut self, + source_pattern: &'pat Pat<'tcx>, + valtree: ty::ValTree<'tcx>, + place: PlaceBuilder<'tcx>, + elem_ty: Ty<'tcx>, + range: Range, + ) -> MatchPairTree<'pat, 'tcx> { + let tcx = self.tcx; + let const_ty = + Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, Ty::new_array(tcx, elem_ty, range.len())); + + let pat_ty = if do_slice { Ty::new_slice(tcx, elem_ty) } else { source_pattern.ty }; + let ty_const = ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree })); + let value = Const::Ty(const_ty, ty_const); + let test_case = TestCase::Constant { value }; + let pattern = tcx.arena.alloc(Pat { + ty: pat_ty, + span: source_pattern.span, + kind: PatKind::Constant { value }, + }); + + MatchPairTree { + place: Some(place.to_place(self)), + test_case, + subpairs: Vec::new(), + pattern, + } } } @@ -192,11 +335,11 @@ impl<'pat, 'tcx> MatchPairTree<'pat, 'tcx> { } PatKind::Array { ref prefix, ref slice, ref suffix } => { - cx.prefix_slice_suffix(&mut subpairs, &place_builder, prefix, slice, suffix); + cx.prefix_slice_suffix(pattern, &mut subpairs, &place_builder, prefix, slice, suffix); default_irrefutable() } PatKind::Slice { ref prefix, ref slice, ref suffix } => { - cx.prefix_slice_suffix(&mut subpairs, &place_builder, prefix, slice, suffix); + cx.prefix_slice_suffix(pattern, &mut subpairs, &place_builder, prefix, slice, suffix); if prefix.is_empty() && slice.is_some() && suffix.is_empty() { default_irrefutable() diff --git a/compiler/rustc_mir_build/src/builder/matches/util.rs b/compiler/rustc_mir_build/src/builder/matches/util.rs index 
1bd399e511b39..2e56a7b89152e 100644 --- a/compiler/rustc_mir_build/src/builder/matches/util.rs +++ b/compiler/rustc_mir_build/src/builder/matches/util.rs @@ -1,6 +1,7 @@ use rustc_data_structures::fx::FxIndexMap; use rustc_middle::mir::*; use rustc_middle::ty::Ty; +use std::ops; use rustc_span::Span; use tracing::debug; @@ -229,3 +230,46 @@ pub(crate) fn ref_pat_borrow_kind(ref_mutability: Mutability) -> BorrowKind { Mutability::Not => BorrowKind::Shared, } } + +#[derive(Copy, Clone, PartialEq, Debug)] +pub(super) struct Range { + pub(crate) start: u64, + pub(crate) end: u64, + pub(crate) from_end: bool, +} + +impl Range { + pub(crate) fn from_start(range: ops::Range) -> Self { + Range { start: range.start, end: range.end, from_end: false } + } + + pub(crate) fn from_end(range: ops::Range) -> Self { + Range { start: range.end, end: range.start, from_end: true } + } + + pub(crate) fn len(self) -> u64 { + if !self.from_end { self.end - self.start } else { self.start - self.end } + } + + pub(crate) fn apply(self, slice: &[T]) -> &[T] { + if !self.from_end { + &slice[self.start as usize..self.end as usize] + } else { + &slice[..self.start as usize - self.end as usize] + } + } + + pub(crate) fn shift_idx(self, idx: u64) -> u64 { + if !self.from_end { self.start + idx } else { self.start - idx } + } + + pub(crate) fn shift_range(self, range_within: ops::Range) -> Self { + if !self.from_end { + Self::from_start(self.start + range_within.start..self.start + range_within.end) + } else { + let range_within_start = range_within.end; + let range_within_end = range_within.start; + Self::from_end(self.start - range_within_start..self.start - range_within_end) + } + } +} From 913eaaa4986f750c23de72dac4985a78eb4cb809 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sat, 1 Feb 2025 15:07:45 +0100 Subject: [PATCH 02/12] split out block for handling fused groups --- .../src/builder/matches/match_pair.rs | 141 +++++++++++++++--- .../src/builder/matches/util.rs | 
3 +- 2 files changed, 120 insertions(+), 24 deletions(-) diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index e67371280c6f4..c4547d4d55cc2 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -1,12 +1,13 @@ +use std::ops; + +use either::Either; use rustc_middle::mir::*; use rustc_middle::thir::{self, *}; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; -use std::ops; -use either::Either; -use crate::builder::matches::util::Range; use crate::builder::Builder; use crate::builder::expr::as_place::{PlaceBase, PlaceBuilder}; +use crate::builder::matches::util::Range; use crate::builder::matches::{FlatPat, MatchPairTree, TestCase}; impl<'a, 'tcx> Builder<'a, 'tcx> { @@ -61,7 +62,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { if !prefix.is_empty() { let bounds = Range::from_start(0..prefix.len() as u64); let subpattern = bounds.apply(prefix); - for pair in self.build_slice_branch(bounds, place, top_pattern, subpattern) { + for pair in self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length) + { match_pairs.push(pair); } } @@ -79,7 +81,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { if !suffix.is_empty() { let bounds = Range::from_end(0..suffix.len() as u64); let subpattern = bounds.apply(suffix); - for pair in self.build_slice_branch(bounds, place, top_pattern, subpattern) { + for pair in self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length) + { match_pairs.push(pair); } } @@ -91,6 +94,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { place: &'b PlaceBuilder<'tcx>, top_pattern: &'pat Pat<'tcx>, pattern: &'pat [Box>], + min_length: u64, ) -> impl Iterator> + use<'a, 'tcx, 'pat, 'b> { let entries = self.find_const_groups(pattern); @@ -123,6 +127,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { place.clone(), elem_ty, bounds.shift_range(range), + min_length, ) } Either::Right(range) => { @@ 
-202,26 +207,101 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { place: PlaceBuilder<'tcx>, elem_ty: Ty<'tcx>, range: Range, + min_length: u64, ) -> MatchPairTree<'pat, 'tcx> { let tcx = self.tcx; - let const_ty = - Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, Ty::new_array(tcx, elem_ty, range.len())); - - let pat_ty = if do_slice { Ty::new_slice(tcx, elem_ty) } else { source_pattern.ty }; - let ty_const = ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree })); - let value = Const::Ty(const_ty, ty_const); - let test_case = TestCase::Constant { value }; - let pattern = tcx.arena.alloc(Pat { - ty: pat_ty, - span: source_pattern.span, - kind: PatKind::Constant { value }, - }); + let leaves = match valtree { + ty::ValTree::Leaf(_) => unreachable!(), + ty::ValTree::Branch(leaves) => leaves, + }; + + if range.from_end { + todo!(); + } + + assert!(range.len() == leaves.len() as u64); + let mut subpairs = Vec::new(); + + let mut were_merged = 0; + if elem_ty == tcx.types.u8 { + let groups = (0..usize::MAX).take_while(|i| i * 2 + 1 <= leaves.len()); + + let leaf_bits = |leaf: ty::ValTree<'tcx>| { + if let ty::ValTree::Leaf(scalar) = leaf { scalar.to_u8() } else { todo!() } + }; + + for g_idx in groups { + were_merged += 2; + + let lo = leaf_bits(leaves[g_idx]); + let hi = leaf_bits(leaves[g_idx + 1]); + let data = u16::from_le_bytes([lo, hi]); + let valtree = ty::ValTree::Leaf(ty::ScalarInt::from(data)); + + let elem_ty = tcx.types.u16; + let ty_const = + ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree })); + let value = Const::Ty(elem_ty, ty_const); + let test_case = TestCase::Constant { value }; + + let pattern = tcx.arena.alloc(Pat { + ty: elem_ty, + span: source_pattern.span, + kind: PatKind::Constant { value }, + }); + + let place = place + .clone_project(ProjectionElem::ConstantIndex { + offset: range.start + g_idx as u64, + min_length, + from_end: range.from_end, + }) + .to_place(self); + + subpairs.push(MatchPairTree { + 
place: Some(place), + test_case, + subpairs: Vec::new(), + pattern, + }); + } + } + + for (idx, leaf) in leaves.iter().enumerate().skip(were_merged) { + let ty_const = ty::Const::new( + tcx, + ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree: *leaf }), + ); + let value = Const::Ty(elem_ty, ty_const); + let test_case = TestCase::Constant { value }; + + let pattern = tcx.arena.alloc(Pat { + ty: elem_ty, + span: source_pattern.span, + kind: PatKind::Constant { value }, + }); + + let place = place + .clone_project(ProjectionElem::ConstantIndex { + offset: range.start + idx as u64, + min_length, + from_end: range.from_end, + }) + .to_place(self); + + subpairs.push(MatchPairTree { + place: Some(place), + test_case, + subpairs: Vec::new(), + pattern, + }); + } MatchPairTree { - place: Some(place.to_place(self)), - test_case, - subpairs: Vec::new(), - pattern, + place: None, + test_case: TestCase::Irrefutable { binding: None, ascription: None }, + subpairs, + pattern: source_pattern, } } } @@ -335,15 +415,30 @@ impl<'pat, 'tcx> MatchPairTree<'pat, 'tcx> { } PatKind::Array { ref prefix, ref slice, ref suffix } => { - cx.prefix_slice_suffix(pattern, &mut subpairs, &place_builder, prefix, slice, suffix); + cx.prefix_slice_suffix( + pattern, + &mut subpairs, + &place_builder, + prefix, + slice, + suffix, + ); default_irrefutable() } PatKind::Slice { ref prefix, ref slice, ref suffix } => { - cx.prefix_slice_suffix(pattern, &mut subpairs, &place_builder, prefix, slice, suffix); + cx.prefix_slice_suffix( + pattern, + &mut subpairs, + &place_builder, + prefix, + slice, + suffix, + ); if prefix.is_empty() && slice.is_some() && suffix.is_empty() { default_irrefutable() } else { + // TODO: do we always need this? 
TestCase::Slice { len: prefix.len() + suffix.len(), variable_length: slice.is_some(), diff --git a/compiler/rustc_mir_build/src/builder/matches/util.rs b/compiler/rustc_mir_build/src/builder/matches/util.rs index 2e56a7b89152e..8ba3c46682183 100644 --- a/compiler/rustc_mir_build/src/builder/matches/util.rs +++ b/compiler/rustc_mir_build/src/builder/matches/util.rs @@ -1,7 +1,8 @@ +use std::ops; + use rustc_data_structures::fx::FxIndexMap; use rustc_middle::mir::*; use rustc_middle::ty::Ty; -use std::ops; use rustc_span::Span; use tracing::debug; From 36c3fd5abe5bfbde59394aaaa5243af6d1541cfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sat, 1 Feb 2025 17:56:57 +0100 Subject: [PATCH 03/12] constant merge testing --- .../src/builder/expr/as_place.rs | 4 ++ .../src/builder/matches/match_pair.rs | 4 +- .../src/builder/matches/mod.rs | 3 +- .../src/builder/matches/test.rs | 39 ++++++++++++++++--- src/llvm-project | 2 +- 5 files changed, 43 insertions(+), 9 deletions(-) diff --git a/compiler/rustc_mir_build/src/builder/expr/as_place.rs b/compiler/rustc_mir_build/src/builder/expr/as_place.rs index 0086775e9f46d..f12b8a8cd15e4 100644 --- a/compiler/rustc_mir_build/src/builder/expr/as_place.rs +++ b/compiler/rustc_mir_build/src/builder/expr/as_place.rs @@ -301,6 +301,10 @@ impl<'tcx> PlaceBuilder<'tcx> { &self.projection } + pub(crate) fn projection_mut(&mut self) -> &mut [PlaceElem<'tcx>] { + &mut self.projection + } + pub(crate) fn field(self, f: FieldIdx, ty: Ty<'tcx>) -> Self { self.project(PlaceElem::Field(f, ty)) } diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index c4547d4d55cc2..1095cb6e18c65 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -224,7 +224,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let mut were_merged = 0; if elem_ty == tcx.types.u8 { - let groups = 
(0..usize::MAX).take_while(|i| i * 2 + 1 <= leaves.len()); + let groups = (0..usize::MAX).take_while(|i| i * 2 + 1 < leaves.len()); let leaf_bits = |leaf: ty::ValTree<'tcx>| { if let ty::ValTree::Leaf(scalar) = leaf { scalar.to_u8() } else { todo!() } @@ -242,7 +242,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let ty_const = ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree })); let value = Const::Ty(elem_ty, ty_const); - let test_case = TestCase::Constant { value }; + let test_case = TestCase::FusedConstant { _value: value }; let pattern = tcx.arena.alloc(Pat { ty: elem_ty, diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index b21ec8f3083b3..75dca042ea2be 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -1238,6 +1238,7 @@ enum TestCase<'pat, 'tcx> { Irrefutable { binding: Option>, ascription: Option> }, Variant { adt_def: ty::AdtDef<'tcx>, variant_index: VariantIdx }, Constant { value: mir::Const<'tcx> }, + FusedConstant { _value: mir::Const<'tcx> }, Range(&'pat PatRange<'tcx>), Slice { len: usize, variable_length: bool }, Deref { temp: Place<'tcx>, mutability: Mutability }, @@ -1304,7 +1305,7 @@ enum TestKind<'tcx> { /// /// The test's target values are not stored here; instead they are extracted /// from the [`TestCase`]s of the candidates participating in the test. - SwitchInt, + SwitchInt { fused: bool }, /// Test whether a `bool` is `true` or `false`. 
If, diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index afe6b4475be3c..f0747910510e9 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -19,7 +19,7 @@ use rustc_span::source_map::Spanned; use rustc_span::{DUMMY_SP, Span, Symbol, sym}; use tracing::{debug, instrument}; -use crate::builder::Builder; +use crate::builder::{Builder, PlaceBuilder}; use crate::builder::matches::{Candidate, MatchPairTree, Test, TestBranch, TestCase, TestKind}; impl<'a, 'tcx> Builder<'a, 'tcx> { @@ -34,9 +34,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { TestCase::Variant { adt_def, variant_index: _ } => TestKind::Switch { adt_def }, TestCase::Constant { .. } if match_pair.pattern.ty.is_bool() => TestKind::If, - TestCase::Constant { .. } if is_switch_ty(match_pair.pattern.ty) => TestKind::SwitchInt, + TestCase::Constant { .. } if is_switch_ty(match_pair.pattern.ty) => TestKind::SwitchInt { fused: false }, TestCase::Constant { value } => TestKind::Eq { value, ty: match_pair.pattern.ty }, + TestCase::FusedConstant { .. 
} => TestKind::SwitchInt { fused: true }, + TestCase::Range(range) => { assert_eq!(range.ty, match_pair.pattern.ty); TestKind::Range(Box::new(range.clone())) @@ -113,7 +115,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ); } - TestKind::SwitchInt => { + TestKind::SwitchInt { fused } => { // The switch may be inexhaustive so we have a catch-all block let otherwise_block = target_block(TestBranch::Failure); let switch_targets = SwitchTargets::new( @@ -126,6 +128,33 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }), otherwise_block, ); + + let mut place = place; + + if fused { + let tcx = self.tcx; + let source_info = self.source_info(match_start_span); + + let mut builder = PlaceBuilder::from(place); + match builder.projection_mut() { + [.., ProjectionElem::ConstantIndex { offset, ..}] => { + *offset += 1; + } + _ => todo!(), + } + let place_2 = builder.to_place(&self); + let shift = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, 8)); + let fused_temp = self.temp(tcx.types.u16, DUMMY_SP); + let fused_temp2 = self.temp(tcx.types.u16, DUMMY_SP); + let fused_temp3 = self.temp(tcx.types.u16, DUMMY_SP); + let fused_final = self.temp(tcx.types.u16, DUMMY_SP); + self.cfg.push_assign(block, source_info, fused_temp, Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place_2), tcx.types.u16)); + self.cfg.push_assign(block, source_info, fused_temp2, Rvalue::BinaryOp(BinOp::Shr, Box::new((Operand::Copy(fused_temp), shift)))); + self.cfg.push_assign(block, source_info, fused_temp3, Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), tcx.types.u16)); + self.cfg.push_assign(block, source_info, fused_final, Rvalue::BinaryOp(BinOp::BitOr, Box::new((Operand::Copy(fused_temp2), Operand::Copy(fused_temp3))))); + place = Place::from(fused_final); + } + let terminator = TerminatorKind::SwitchInt { discr: Operand::Copy(place), targets: switch_targets, @@ -557,7 +586,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // // FIXME(#29623) we could use PatKind::Range to rule // things out here, in some cases. 
- (TestKind::SwitchInt, &TestCase::Constant { value }) + (TestKind::SwitchInt { .. }, &TestCase::Constant { value } | &TestCase::FusedConstant { _value: value }) if is_switch_ty(match_pair.pattern.ty) => { // An important invariant of candidate sorting is that a candidate @@ -591,7 +620,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { Some(TestBranch::Constant(value, bits)) } } - (TestKind::SwitchInt, TestCase::Range(range)) => { + (TestKind::SwitchInt { fused: _fused}, TestCase::Range(range)) => { // When performing a `SwitchInt` test, a range pattern can be // sorted into the failure arm if it doesn't contain _any_ of // the values being tested. (This restricts what values can be diff --git a/src/llvm-project b/src/llvm-project index 7e8c93c87c611..f3b462d0c6ca1 160000 --- a/src/llvm-project +++ b/src/llvm-project @@ -1 +1 @@ -Subproject commit 7e8c93c87c611f21d9bd95100563392f4c18bfe7 +Subproject commit f3b462d0c6ca1e96853d16530b54f45760b3eb04 From edacc62bf465e0d9c05aaae693b4a1e2327494e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sat, 1 Feb 2025 18:20:31 +0100 Subject: [PATCH 04/12] fuse groups of size 2,3,4. 
matching impl 1/2 --- .../src/builder/matches/match_pair.rs | 49 ++++++++++++----- .../src/builder/matches/mod.rs | 4 +- .../src/builder/matches/test.rs | 52 ++++++++++++++----- 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 1095cb6e18c65..5cea2403e42ea 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -224,35 +224,40 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let mut were_merged = 0; if elem_ty == tcx.types.u8 { - let groups = (0..usize::MAX).take_while(|i| i * 2 + 1 < leaves.len()); - let leaf_bits = |leaf: ty::ValTree<'tcx>| { if let ty::ValTree::Leaf(scalar) = leaf { scalar.to_u8() } else { todo!() } }; - for g_idx in groups { - were_merged += 2; + let mut fuse_group = |first_idx, len| { + were_merged += len; - let lo = leaf_bits(leaves[g_idx]); - let hi = leaf_bits(leaves[g_idx + 1]); - let data = u16::from_le_bytes([lo, hi]); - let valtree = ty::ValTree::Leaf(ty::ScalarInt::from(data)); + let data = leaves[first_idx..first_idx + len] + .iter() + .copied() + .map(leaf_bits) + .fold(0u32, |acc, x| (acc << 8) | u32::from(x)); + + let fused_ty = match len { + 2 => tcx.types.u16, + 3 | 4 => tcx.types.u32, + _ => unreachable!(), + }; - let elem_ty = tcx.types.u16; + let valtree = ty::ValTree::Leaf(ty::ScalarInt::from(data)); let ty_const = - ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: elem_ty, valtree })); - let value = Const::Ty(elem_ty, ty_const); - let test_case = TestCase::FusedConstant { _value: value }; + ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: fused_ty, valtree })); + let value = Const::Ty(fused_ty, ty_const); + let test_case = TestCase::FusedConstant { value, fused: len as u64 }; let pattern = tcx.arena.alloc(Pat { - ty: elem_ty, + ty: fused_ty, span: source_pattern.span, kind: PatKind::Constant { value 
}, }); let place = place .clone_project(ProjectionElem::ConstantIndex { - offset: range.start + g_idx as u64, + offset: range.start + first_idx as u64, min_length, from_end: range.from_end, }) @@ -264,6 +269,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { subpairs: Vec::new(), pattern, }); + }; + + let x4 = (0..usize::MAX).take_while(|i| i * 4 + 1 < leaves.len()); + let x3 = (0..usize::MAX).take_while(|i| i * 3 + 1 < leaves.len()); + let x2 = (0..usize::MAX).take_while(|i| i * 2 + 1 < leaves.len()); + + for i in x4 { + fuse_group(i * 4, 4); + } + + for i in x3 { + fuse_group(i * 3, 3); + } + + for i in x2 { + fuse_group(i * 2, 2); } } diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index 75dca042ea2be..1da96c7e0dacf 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -1238,7 +1238,7 @@ enum TestCase<'pat, 'tcx> { Irrefutable { binding: Option>, ascription: Option> }, Variant { adt_def: ty::AdtDef<'tcx>, variant_index: VariantIdx }, Constant { value: mir::Const<'tcx> }, - FusedConstant { _value: mir::Const<'tcx> }, + FusedConstant { value: mir::Const<'tcx>, fused: u64 }, Range(&'pat PatRange<'tcx>), Slice { len: usize, variable_length: bool }, Deref { temp: Place<'tcx>, mutability: Mutability }, @@ -1305,7 +1305,7 @@ enum TestKind<'tcx> { /// /// The test's target values are not stored here; instead they are extracted /// from the [`TestCase`]s of the candidates participating in the test. - SwitchInt { fused: bool }, + SwitchInt { fused: u64 }, /// Test whether a `bool` is `true` or `false`. 
If, diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index f0747910510e9..68b27b5e67e00 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -19,8 +19,8 @@ use rustc_span::source_map::Spanned; use rustc_span::{DUMMY_SP, Span, Symbol, sym}; use tracing::{debug, instrument}; -use crate::builder::{Builder, PlaceBuilder}; use crate::builder::matches::{Candidate, MatchPairTree, Test, TestBranch, TestCase, TestKind}; +use crate::builder::{Builder, PlaceBuilder}; impl<'a, 'tcx> Builder<'a, 'tcx> { /// Identifies what test is needed to decide if `match_pair` is applicable. @@ -34,10 +34,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { TestCase::Variant { adt_def, variant_index: _ } => TestKind::Switch { adt_def }, TestCase::Constant { .. } if match_pair.pattern.ty.is_bool() => TestKind::If, - TestCase::Constant { .. } if is_switch_ty(match_pair.pattern.ty) => TestKind::SwitchInt { fused: false }, + TestCase::Constant { .. } if is_switch_ty(match_pair.pattern.ty) => { + TestKind::SwitchInt { fused: 1 } + } TestCase::Constant { value } => TestKind::Eq { value, ty: match_pair.pattern.ty }, - TestCase::FusedConstant { .. } => TestKind::SwitchInt { fused: true }, + TestCase::FusedConstant { fused, .. } => TestKind::SwitchInt { fused }, TestCase::Range(range) => { assert_eq!(range.ty, match_pair.pattern.ty); @@ -131,13 +133,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let mut place = place; - if fused { + if fused == 2 { let tcx = self.tcx; let source_info = self.source_info(match_start_span); let mut builder = PlaceBuilder::from(place); match builder.projection_mut() { - [.., ProjectionElem::ConstantIndex { offset, ..}] => { + [.., ProjectionElem::ConstantIndex { offset, .. 
}] => { *offset += 1; } _ => todo!(), @@ -148,10 +150,33 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let fused_temp2 = self.temp(tcx.types.u16, DUMMY_SP); let fused_temp3 = self.temp(tcx.types.u16, DUMMY_SP); let fused_final = self.temp(tcx.types.u16, DUMMY_SP); - self.cfg.push_assign(block, source_info, fused_temp, Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place_2), tcx.types.u16)); - self.cfg.push_assign(block, source_info, fused_temp2, Rvalue::BinaryOp(BinOp::Shr, Box::new((Operand::Copy(fused_temp), shift)))); - self.cfg.push_assign(block, source_info, fused_temp3, Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), tcx.types.u16)); - self.cfg.push_assign(block, source_info, fused_final, Rvalue::BinaryOp(BinOp::BitOr, Box::new((Operand::Copy(fused_temp2), Operand::Copy(fused_temp3))))); + self.cfg.push_assign( + block, + source_info, + fused_temp, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place_2), tcx.types.u16), + ); + self.cfg.push_assign( + block, + source_info, + fused_temp2, + Rvalue::BinaryOp(BinOp::Shr, Box::new((Operand::Copy(fused_temp), shift))), + ); + self.cfg.push_assign( + block, + source_info, + fused_temp3, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), tcx.types.u16), + ); + self.cfg.push_assign( + block, + source_info, + fused_final, + Rvalue::BinaryOp( + BinOp::BitOr, + Box::new((Operand::Copy(fused_temp2), Operand::Copy(fused_temp3))), + ), + ); place = Place::from(fused_final); } @@ -586,9 +611,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // // FIXME(#29623) we could use PatKind::Range to rule // things out here, in some cases. - (TestKind::SwitchInt { .. }, &TestCase::Constant { value } | &TestCase::FusedConstant { _value: value }) - if is_switch_ty(match_pair.pattern.ty) => - { + ( + TestKind::SwitchInt { .. }, + &TestCase::Constant { value } | &TestCase::FusedConstant { value, .. 
}, + ) if is_switch_ty(match_pair.pattern.ty) => { // An important invariant of candidate sorting is that a candidate // must not match in multiple branches. For `SwitchInt` tests, adding // a new value might invalidate that property for range patterns that @@ -620,7 +646,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { Some(TestBranch::Constant(value, bits)) } } - (TestKind::SwitchInt { fused: _fused}, TestCase::Range(range)) => { + (TestKind::SwitchInt { fused: _fused }, TestCase::Range(range)) => { // When performing a `SwitchInt` test, a range pattern can be // sorted into the failure arm if it doesn't contain _any_ of // the values being tested. (This restricts what values can be From 71fac2f33dc75178565f3d15999793e20cdaadf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sat, 1 Feb 2025 18:43:32 +0100 Subject: [PATCH 05/12] fuse groups of size 2,3,4. codegen impl 2/2 --- .../src/builder/matches/test.rs | 100 ++++++++++-------- 1 file changed, 56 insertions(+), 44 deletions(-) diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index 68b27b5e67e00..caf17e52ec3ee 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -131,57 +131,69 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { otherwise_block, ); - let mut place = place; - - if fused == 2 { + let discr = if fused > 1 { + assert!(fused <= 4 && place_ty.ty == self.tcx.types.u8); let tcx = self.tcx; let source_info = self.source_info(match_start_span); - let mut builder = PlaceBuilder::from(place); - match builder.projection_mut() { - [.., ProjectionElem::ConstantIndex { offset, .. 
}] => { - *offset += 1; + let fused_ty = match fused { + 2 => tcx.types.u16, + 3 | 4 => tcx.types.u32, + _ => unreachable!(), + }; + + let builder = PlaceBuilder::from(place); + + let place_for = move |b: &mut Self, idx| { + let mut builder = builder.clone(); + match builder.projection_mut() { + [.., ProjectionElem::ConstantIndex { offset, .. }] => { + *offset += idx; + } + _ => todo!(), } - _ => todo!(), + builder.to_place(b) + }; + + let mut acc = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, 0)); + + for i in 0..fused { + let new_acc = self.temp(fused_ty, DUMMY_SP); + let temp1 = self.temp(fused_ty, DUMMY_SP); + let temp2 = self.temp(fused_ty, DUMMY_SP); + let place = place_for(self, i); + let shift = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, i * 8)); + + self.cfg.push_assign( + block, + source_info, + temp1, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), + ); + self.cfg.push_assign( + block, + source_info, + temp2, + Rvalue::BinaryOp(BinOp::Shr, Box::new((Operand::Copy(temp1), shift))), + ); + self.cfg.push_assign( + block, + source_info, + new_acc, + Rvalue::BinaryOp( + BinOp::BitOr, + Box::new((acc, Operand::Copy(temp2))), + ), + ); + + acc = Operand::Copy(new_acc); } - let place_2 = builder.to_place(&self); - let shift = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, 8)); - let fused_temp = self.temp(tcx.types.u16, DUMMY_SP); - let fused_temp2 = self.temp(tcx.types.u16, DUMMY_SP); - let fused_temp3 = self.temp(tcx.types.u16, DUMMY_SP); - let fused_final = self.temp(tcx.types.u16, DUMMY_SP); - self.cfg.push_assign( - block, - source_info, - fused_temp, - Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place_2), tcx.types.u16), - ); - self.cfg.push_assign( - block, - source_info, - fused_temp2, - Rvalue::BinaryOp(BinOp::Shr, Box::new((Operand::Copy(fused_temp), shift))), - ); - self.cfg.push_assign( - block, - source_info, - fused_temp3, - Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), tcx.types.u16), - ); 
- self.cfg.push_assign( - block, - source_info, - fused_final, - Rvalue::BinaryOp( - BinOp::BitOr, - Box::new((Operand::Copy(fused_temp2), Operand::Copy(fused_temp3))), - ), - ); - place = Place::from(fused_final); - } + + acc + } else { Operand::Copy(place)}; let terminator = TerminatorKind::SwitchInt { - discr: Operand::Copy(place), + discr, targets: switch_targets, }; self.cfg.terminate(block, self.source_info(match_start_span), terminator); From f5c5b1bdf2bd620e13636b9277799ff1b071d935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sat, 1 Feb 2025 21:12:47 +0100 Subject: [PATCH 06/12] 32 bit fusion POC --- Cargo.toml | 5 +++ .../src/builder/matches/match_pair.rs | 32 +++++++++++-------- .../src/builder/matches/test.rs | 26 +++++++-------- 3 files changed, 36 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b773030b4cab4..a89b10d322ac9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,3 +93,8 @@ codegen-units = 1 # FIXME: LTO cannot be enabled for binaries in a workspace # # lto = true + +[profile.release.package.rustc_mir_build] +opt-level = 0 +[profile.release.package.rustc_driver] +opt-level = 0 diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 5cea2403e42ea..22e4ba85e97ea 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -233,6 +233,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let data = leaves[first_idx..first_idx + len] .iter() + .rev() .copied() .map(leaf_bits) .fold(0u32, |acc, x| (acc << 8) | u32::from(x)); @@ -243,7 +244,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { _ => unreachable!(), }; - let valtree = ty::ValTree::Leaf(ty::ScalarInt::from(data)); + let scalar = match len { + 2 => ty::ScalarInt::from(data as u16), + 3 | 4 => ty::ScalarInt::from(data), + _ => unreachable!(), + }; + + let valtree = ty::ValTree::Leaf(scalar); let ty_const 
= ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: fused_ty, valtree })); let value = Const::Ty(fused_ty, ty_const); @@ -271,20 +278,17 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }); }; - let x4 = (0..usize::MAX).take_while(|i| i * 4 + 1 < leaves.len()); - let x3 = (0..usize::MAX).take_while(|i| i * 3 + 1 < leaves.len()); - let x2 = (0..usize::MAX).take_while(|i| i * 2 + 1 < leaves.len()); - - for i in x4 { - fuse_group(i * 4, 4); - } - - for i in x3 { - fuse_group(i * 3, 3); - } + let indices = |group_size, skip| { + (skip..usize::MAX) + .take_while(move |i| i * group_size + (group_size - 1) < leaves.len()) + }; - for i in x2 { - fuse_group(i * 2, 2); + let mut skip = 0; + for i in (2..=4).rev() { + for idx in indices(i, skip) { + fuse_group(idx * i, i); + skip += i; + } } } diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index caf17e52ec3ee..f9a2b1e703c6a 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -9,6 +9,7 @@ use std::cmp::Ordering; use rustc_data_structures::fx::FxIndexMap; use rustc_hir::{LangItem, RangeEnd}; +use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::*; use rustc_middle::ty::adjustment::PointerCoercion; use rustc_middle::ty::util::IntTypeExt; @@ -135,7 +136,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { assert!(fused <= 4 && place_ty.ty == self.tcx.types.u8); let tcx = self.tcx; let source_info = self.source_info(match_start_span); - let fused_ty = match fused { 2 => tcx.types.u16, 3 | 4 => tcx.types.u32, @@ -143,7 +143,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }; let builder = PlaceBuilder::from(place); - let place_for = move |b: &mut Self, idx| { let mut builder = builder.clone(); match builder.projection_mut() { @@ -155,7 +154,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { builder.to_place(b) }; - let mut acc = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, 0)); + let mut acc = 
Operand::const_from_scalar( + tcx, + fused_ty, + Scalar::from_uint(0u32, fused_ty.primitive_size(tcx)), + DUMMY_SP, + ); for i in 0..fused { let new_acc = self.temp(fused_ty, DUMMY_SP); @@ -174,28 +178,24 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { block, source_info, temp2, - Rvalue::BinaryOp(BinOp::Shr, Box::new((Operand::Copy(temp1), shift))), + Rvalue::BinaryOp(BinOp::Shl, Box::new((Operand::Copy(temp1), shift))), ); self.cfg.push_assign( block, source_info, new_acc, - Rvalue::BinaryOp( - BinOp::BitOr, - Box::new((acc, Operand::Copy(temp2))), - ), + Rvalue::BinaryOp(BinOp::BitOr, Box::new((acc, Operand::Copy(temp2)))), ); acc = Operand::Copy(new_acc); } acc - } else { Operand::Copy(place)}; - - let terminator = TerminatorKind::SwitchInt { - discr, - targets: switch_targets, + } else { + Operand::Copy(place) }; + + let terminator = TerminatorKind::SwitchInt { discr, targets: switch_targets }; self.cfg.terminate(block, self.source_info(match_start_span), terminator); } From 9975e86a83cd2113738fcb89e84d9fdf59bf51d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sat, 1 Feb 2025 22:10:52 +0100 Subject: [PATCH 07/12] eliminate initial zero assignment in fused const switches --- .../src/builder/matches/test.rs | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index f9a2b1e703c6a..d26ec9c1b93d5 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -154,48 +154,52 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { builder.to_place(b) }; - let mut acc = Operand::const_from_scalar( - tcx, - fused_ty, - Scalar::from_uint(0u32, fused_ty.primitive_size(tcx)), - DUMMY_SP, - ); + let temp = self.temp(fused_ty, DUMMY_SP); + let acc = self.temp(fused_ty, DUMMY_SP); + let zero = || { + Operand::const_from_scalar( + tcx, + fused_ty, + Scalar::from_uint(0u32, 
fused_ty.primitive_size(tcx)), + DUMMY_SP, + ) + }; for i in 0..fused { - let new_acc = self.temp(fused_ty, DUMMY_SP); - let temp1 = self.temp(fused_ty, DUMMY_SP); - let temp2 = self.temp(fused_ty, DUMMY_SP); let place = place_for(self, i); let shift = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, i * 8)); + let or_lhs = if i == 0 { zero() } else { Operand::Copy(acc) }; self.cfg.push_assign( block, source_info, - temp1, + temp, Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), ); self.cfg.push_assign( block, source_info, - temp2, - Rvalue::BinaryOp(BinOp::Shl, Box::new((Operand::Copy(temp1), shift))), + temp, + Rvalue::BinaryOp(BinOp::Shl, Box::new((Operand::Copy(temp), shift))), ); + self.cfg.push_assign( block, source_info, - new_acc, - Rvalue::BinaryOp(BinOp::BitOr, Box::new((acc, Operand::Copy(temp2)))), + acc, + Rvalue::BinaryOp(BinOp::BitOr, Box::new((or_lhs, Operand::Copy(temp)))), ); - - acc = Operand::Copy(new_acc); } acc } else { - Operand::Copy(place) + place }; - let terminator = TerminatorKind::SwitchInt { discr, targets: switch_targets }; + let terminator = TerminatorKind::SwitchInt { + discr: Operand::Copy(discr), + targets: switch_targets, + }; self.cfg.terminate(block, self.source_info(match_start_span), terminator); } From bd3a1f4f801363cd99e1270048a750ee2eb52c83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sat, 1 Feb 2025 22:49:32 +0100 Subject: [PATCH 08/12] simplify fused switch gen code, making it easier to read --- .../src/builder/matches/test.rs | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index d26ec9c1b93d5..4134c1cc4a41b 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -9,7 +9,6 @@ use std::cmp::Ordering; use rustc_data_structures::fx::FxIndexMap; use 
rustc_hir::{LangItem, RangeEnd}; -use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::*; use rustc_middle::ty::adjustment::PointerCoercion; use rustc_middle::ty::util::IntTypeExt; @@ -156,19 +155,17 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let temp = self.temp(fused_ty, DUMMY_SP); let acc = self.temp(fused_ty, DUMMY_SP); - let zero = || { - Operand::const_from_scalar( - tcx, - fused_ty, - Scalar::from_uint(0u32, fused_ty.primitive_size(tcx)), - DUMMY_SP, - ) - }; - for i in 0..fused { + self.cfg.push_assign( + block, + source_info, + acc, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), + ); + + for i in 1..fused { let place = place_for(self, i); let shift = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, i * 8)); - let or_lhs = if i == 0 { zero() } else { Operand::Copy(acc) }; self.cfg.push_assign( block, @@ -182,12 +179,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { temp, Rvalue::BinaryOp(BinOp::Shl, Box::new((Operand::Copy(temp), shift))), ); - self.cfg.push_assign( block, source_info, acc, - Rvalue::BinaryOp(BinOp::BitOr, Box::new((or_lhs, Operand::Copy(temp)))), + Rvalue::BinaryOp( + BinOp::BitOr, + Box::new((Operand::Copy(acc), Operand::Copy(temp))), + ), ); } From b974dec3ae28d0ef0c12d4651e1f9a40afd3de71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sat, 1 Feb 2025 23:43:44 +0100 Subject: [PATCH 09/12] fix from-back indexing for constant patterns in slices --- .../rustc_mir_build/src/builder/matches/match_pair.rs | 7 ++----- compiler/rustc_mir_build/src/builder/matches/test.rs | 8 ++++++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 22e4ba85e97ea..4f5a187f777c9 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -215,10 +215,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { 
ty::ValTree::Branch(leaves) => leaves, }; - if range.from_end { - todo!(); - } - assert!(range.len() == leaves.len() as u64); let mut subpairs = Vec::new(); @@ -253,6 +249,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let valtree = ty::ValTree::Leaf(scalar); let ty_const = ty::Const::new(tcx, ty::ConstKind::Value(ty::Value { ty: fused_ty, valtree })); + let value = Const::Ty(fused_ty, ty_const); let test_case = TestCase::FusedConstant { value, fused: len as u64 }; @@ -264,7 +261,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let place = place .clone_project(ProjectionElem::ConstantIndex { - offset: range.start + first_idx as u64, + offset: range.shift_idx(first_idx as u64), min_length, from_end: range.from_end, }) diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index 4134c1cc4a41b..4406c627f5d7e 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -145,8 +145,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let place_for = move |b: &mut Self, idx| { let mut builder = builder.clone(); match builder.projection_mut() { - [.., ProjectionElem::ConstantIndex { offset, .. }] => { - *offset += idx; + [.., ProjectionElem::ConstantIndex { offset, ref from_end, .. 
}] => { + if !from_end { + *offset += idx; + } else { + *offset -= idx; + } } _ => todo!(), From c4344318178be559105423e6fc0a7ef837d97235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sun, 2 Feb 2025 03:08:44 +0100 Subject: [PATCH 10/12] code cleanup --- .../src/builder/matches/match_pair.rs | 51 +++--- .../src/builder/matches/test.rs | 160 ++++++++++-------- .../src/builder/matches/util.rs | 18 +- 3 files changed, 127 insertions(+), 102 deletions(-) diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 4f5a187f777c9..5e72338afcb24 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -1,6 +1,7 @@ use std::ops; use either::Either; +use rustc_middle::bug; use rustc_middle::mir::*; use rustc_middle::thir::{self, *}; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; @@ -62,10 +63,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { if !prefix.is_empty() { let bounds = Range::from_start(0..prefix.len() as u64); let subpattern = bounds.apply(prefix); - for pair in self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length) - { - match_pairs.push(pair); - } + self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length) + .for_each(|pair| match_pairs.push(pair)); } if let Some(subslice_pat) = opt_slice { @@ -81,13 +80,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { if !suffix.is_empty() { let bounds = Range::from_end(0..suffix.len() as u64); let subpattern = bounds.apply(suffix); - for pair in self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length) - { - match_pairs.push(pair); - } + self.build_slice_branch(bounds, place, top_pattern, subpattern, min_length) + .for_each(|pair| match_pairs.push(pair)); } } + // Traverses either side of a slice pattern (prefix/suffix) and yields an iterator of `MatchPairTree`s + // to cover all its constant and
non-constant subpatterns. fn build_slice_branch<'pat, 'b>( &'b mut self, bounds: Range, @@ -99,6 +98,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let entries = self.find_const_groups(pattern); entries.into_iter().map(move |entry| { + // Common case handler for both non-constant and constant subpatterns not in a range. let mut build_single = |idx| { let subpattern = &pattern[idx as usize]; let place = place.clone_project(ProjectionElem::ConstantIndex { @@ -112,14 +112,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { match entry { Either::Right(range) if range.end - range.start > 1 => { - assert!( - (range.start..range.end) - .all(|idx| self.is_constant_pattern(&pattern[idx as usize])) - ); - + // Figure out which subslice of our already sliced pattern we're looking at. let subpattern = &pattern[range.start as usize..range.end as usize]; let elem_ty = subpattern[0].ty; + // Right, we've found a group of constant patterns worth grouping for later. + // We'll collect all the leaves we can find and create a single `ValTree` out of them. let valtree = self.simplify_const_pattern_slice_into_valtree(subpattern); self.valtree_to_match_pair( top_pattern, @@ -130,16 +128,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { min_length, ) } - Either::Right(range) => { - let tree = build_single(range.start); - assert!(self.is_constant_pattern(&pattern[range.start as usize])); - tree - } + Either::Right(range) => build_single(range.start), Either::Left(idx) => build_single(idx), } }) } + // Given a partial view of the elements in a slice pattern, returns a list + // with left denoting non-constant element indices and right denoting ranges of constant elements. fn find_const_groups(&self, pattern: &[Box>]) -> Vec>> { let mut entries = Vec::new(); let mut current_seq_start = None; @@ -167,6 +163,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { entries } + // Checks if a pattern is constant and represented by a single scalar leaf.
fn is_constant_pattern(&self, pat: &Pat<'tcx>) -> bool { if let PatKind::Constant { value } = pat.kind && let Const::Ty(_, const_) = value @@ -179,6 +176,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } } + // Extract the `ValTree` from a constant pattern. + // You must ensure that the pattern is a constant pattern before calling this function or it will panic. fn extract_leaf(&self, pat: &Pat<'tcx>) -> ty::ValTree<'tcx> { if let PatKind::Constant { value } = pat.kind && let Const::Ty(_, const_) = value @@ -187,10 +186,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { { cv.valtree } else { - unreachable!() + bug!("expected constant pattern, got {:?}", pat) } } + // Simplifies a slice of constant patterns into a single flattened `ValTree`. fn simplify_const_pattern_slice_into_valtree( &self, subslice: &[Box>], @@ -200,6 +200,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ty::ValTree::Branch(interned) } + // Given a `ValTree` representing a slice of constant patterns, returns a `MatchPairTree` + // representing the slice pattern, providing as much info about subsequences in the slice as possible + // to later lowering stages. 
fn valtree_to_match_pair<'pat>( &mut self, source_pattern: &'pat Pat<'tcx>, @@ -211,17 +214,18 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ) -> MatchPairTree<'pat, 'tcx> { let tcx = self.tcx; let leaves = match valtree { - ty::ValTree::Leaf(_) => unreachable!(), + ty::ValTree::Leaf(_) => bug!("expected branch, got leaf"), ty::ValTree::Branch(leaves) => leaves, }; assert!(range.len() == leaves.len() as u64); let mut subpairs = Vec::new(); - let mut were_merged = 0; + if elem_ty == tcx.types.u8 { - let leaf_bits = |leaf: ty::ValTree<'tcx>| { - if let ty::ValTree::Leaf(scalar) = leaf { scalar.to_u8() } else { todo!() } + let leaf_bits = |leaf: ty::ValTree<'tcx>| match leaf { + ty::ValTree::Leaf(scalar) => scalar.to_u8(), + _ => bug!("found unflatted valtree"), }; let mut fuse_group = |first_idx, len| { @@ -460,7 +464,6 @@ impl<'pat, 'tcx> MatchPairTree<'pat, 'tcx> { if prefix.is_empty() && slice.is_some() && suffix.is_empty() { default_irrefutable() } else { - // TODO: do we always need this? TestCase::Slice { len: prefix.len() + suffix.len(), variable_length: slice.is_some(), diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index 4406c627f5d7e..83544fa77c288 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -131,78 +131,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { otherwise_block, ); - let discr = if fused > 1 { - assert!(fused <= 4 && place_ty.ty == self.tcx.types.u8); - let tcx = self.tcx; - let source_info = self.source_info(match_start_span); - let fused_ty = match fused { - 2 => tcx.types.u16, - 3 | 4 => tcx.types.u32, - _ => unreachable!(), - }; - - let builder = PlaceBuilder::from(place); - let place_for = move |b: &mut Self, idx| { - let mut builder = builder.clone(); - match builder.projection_mut() { - [.., ProjectionElem::ConstantIndex { offset, ref from_end, .. 
}] => { - if !from_end { - *offset += idx; - } else { - *offset -= idx; - } - } - _ => todo!(), - } - builder.to_place(b) - }; - - let temp = self.temp(fused_ty, DUMMY_SP); - let acc = self.temp(fused_ty, DUMMY_SP); - - self.cfg.push_assign( - block, - source_info, - acc, - Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), - ); - - for i in 1..fused { - let place = place_for(self, i); - let shift = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, i * 8)); - - self.cfg.push_assign( - block, - source_info, - temp, - Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), - ); - self.cfg.push_assign( - block, - source_info, - temp, - Rvalue::BinaryOp(BinOp::Shl, Box::new((Operand::Copy(temp), shift))), - ); - self.cfg.push_assign( - block, - source_info, - acc, - Rvalue::BinaryOp( - BinOp::BitOr, - Box::new((Operand::Copy(acc), Operand::Copy(temp))), - ), - ); + let discr = match fused { + 0 => span_bug!(test.span, "there must be at least one constant"), + 1 => Operand::Copy(place), + 2.. => { + self.fuse_switch_discriminant(block, place, place_ty.ty, fused, test.span) } - - acc - } else { - place }; - let terminator = TerminatorKind::SwitchInt { - discr: Operand::Copy(discr), - targets: switch_targets, - }; + let terminator = TerminatorKind::SwitchInt { discr, targets: switch_targets }; self.cfg.terminate(block, self.source_info(match_start_span), terminator); } @@ -410,6 +347,91 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { }); } + /// "Fuse" multiple small integer constants in a sequence into a single integer, possibly + /// removing unnecessary branches from the lowered match tree.
+ fn fuse_switch_discriminant( + &mut self, + block: BasicBlock, + place: Place<'tcx>, + elem_ty: Ty<'tcx>, + count: u64, + test_span: Span, + ) -> Operand<'tcx> { + let tcx = self.tcx; + let source_info = self.source_info(test_span); + match (count, elem_ty) { + (2..=4, ty) if ty == tcx.types.u8 || ty == tcx.types.i8 => (), + (2..=2, ty) if ty == tcx.types.u16 || ty == tcx.types.i16 => (), + (fused, ty) => span_bug!( + test_span, + "unsupported constant fusion combination of count {} and type {}", + ty, + fused + ), + }; + + let fused_ty = match count * elem_ty.primitive_size(tcx).bits() { + 8..=16 => tcx.types.u16, + ..=32 => tcx.types.u32, + _ => unreachable!(), + }; + + let builder = PlaceBuilder::from(place); + let place_for = move |b: &mut Self, idx| { + let mut builder = builder.clone(); + match builder.projection_mut() { + [.., ProjectionElem::ConstantIndex { offset, ref from_end, .. }] => { + if !from_end { + *offset += idx; + } else { + *offset -= idx; + } + } + _ => span_bug!(test_span, "found unexpected projections"), + } + builder.to_place(b) + }; + + let temp = self.temp(fused_ty, DUMMY_SP); + let acc = self.temp(fused_ty, DUMMY_SP); + + // Since we can freely cast up integers + the required shift is zero on the first + // iteration, we skip both the shift and OR operations the first time. + self.cfg.push_assign( + block, + source_info, + acc, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), + ); + + // Handle all but the first iterations, iteratively building up the fused integer. 
+ for i in 1..count { + let place = place_for(self, i); + let shift = self.literal_operand(DUMMY_SP, Const::from_usize(tcx, i * 8)); + + self.cfg.push_assign( + block, + source_info, + temp, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(place), fused_ty), + ); + self.cfg.push_assign( + block, + source_info, + temp, + Rvalue::BinaryOp(BinOp::Shl, Box::new((Operand::Copy(temp), shift))), + ); + self.cfg.push_assign( + block, + source_info, + acc, + Rvalue::BinaryOp(BinOp::BitOr, Box::new((Operand::Copy(acc), Operand::Copy(temp)))), + ); + } + + Operand::Copy(acc) + } + /// Compare using the provided built-in comparison operator fn compare( &mut self, diff --git a/compiler/rustc_mir_build/src/builder/matches/util.rs b/compiler/rustc_mir_build/src/builder/matches/util.rs index 8ba3c46682183..6be0ea5aa11dd 100644 --- a/compiler/rustc_mir_build/src/builder/matches/util.rs +++ b/compiler/rustc_mir_build/src/builder/matches/util.rs @@ -234,25 +234,25 @@ pub(crate) fn ref_pat_borrow_kind(ref_mutability: Mutability) -> BorrowKind { #[derive(Copy, Clone, PartialEq, Debug)] pub(super) struct Range { - pub(crate) start: u64, - pub(crate) end: u64, - pub(crate) from_end: bool, + pub(super) start: u64, + pub(super) end: u64, + pub(super) from_end: bool, } impl Range { - pub(crate) fn from_start(range: ops::Range) -> Self { + pub(super) fn from_start(range: ops::Range) -> Self { Range { start: range.start, end: range.end, from_end: false } } - pub(crate) fn from_end(range: ops::Range) -> Self { + pub(super) fn from_end(range: ops::Range) -> Self { Range { start: range.end, end: range.start, from_end: true } } - pub(crate) fn len(self) -> u64 { + pub(super) fn len(self) -> u64 { if !self.from_end { self.end - self.start } else { self.start - self.end } } - pub(crate) fn apply(self, slice: &[T]) -> &[T] { + pub(super) fn apply(self, slice: &[T]) -> &[T] { if !self.from_end { &slice[self.start as usize..self.end as usize] } else { @@ -260,11 +260,11 @@ impl Range { } } - 
pub(crate) fn shift_idx(self, idx: u64) -> u64 { + pub(super) fn shift_idx(self, idx: u64) -> u64 { if !self.from_end { self.start + idx } else { self.start - idx } } - pub(crate) fn shift_range(self, range_within: ops::Range) -> Self { + pub(super) fn shift_range(self, range_within: ops::Range) -> Self { if !self.from_end { Self::from_start(self.start + range_within.start..self.start + range_within.end) } else { From c29054e7ad773d0c4d8433ac126fd6ec3024bde5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sun, 2 Feb 2025 03:46:58 +0100 Subject: [PATCH 11/12] restore llvm --- src/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llvm-project b/src/llvm-project index f3b462d0c6ca1..7e8c93c87c611 160000 --- a/src/llvm-project +++ b/src/llvm-project @@ -1 +1 @@ -Subproject commit f3b462d0c6ca1e96853d16530b54f45760b3eb04 +Subproject commit 7e8c93c87c611f21d9bd95100563392f4c18bfe7 From ec873589f7e70a8462ba9a23dfd41276a581cf3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Wejdenst=C3=A5l?= Date: Sun, 2 Feb 2025 03:50:28 +0100 Subject: [PATCH 12/12] revert Cargo.toml --- Cargo.toml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a89b10d322ac9..b773030b4cab4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,8 +93,3 @@ codegen-units = 1 # FIXME: LTO cannot be enabled for binaries in a workspace # # lto = true - -[profile.release.package.rustc_mir_build] -opt-level = 0 -[profile.release.package.rustc_driver] -opt-level = 0