Skip to content

Commit 96db5e9

Browse files
committed
Add comments
Still need to make it so that it maps discriminants to variant indexes. Maybe instead I can map the variant indexes to discriminants?
1 parent 18144b6 commit 96db5e9

File tree

1 file changed

+49
-33
lines changed

1 file changed

+49
-33
lines changed

compiler/rustc_mir/src/transform/large_enums.rs

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_data_structures::stable_map::FxHashMap;
44
use rustc_middle::mir::*;
55
use rustc_middle::ty::{self, Const, List, Ty, TyCtxt};
66
use rustc_span::def_id::DefId;
7-
use rustc_target::abi::{Size, Variants};
7+
use rustc_target::abi::{Size, TagEncoding, Variants};
88

99
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
1010
/// enough discrepanc between them
@@ -31,25 +31,33 @@ impl<const D: u64> EnumSizeOpt<D> {
3131
match variants {
3232
Variants::Single { .. } => None,
3333
Variants::Multiple { variants, .. } if variants.len() <= 1 => None,
34+
Variants::Multiple { tag_encoding, .. }
35+
if matches!(tag_encoding, TagEncoding::Niche { .. }) =>
36+
{
37+
None
38+
}
3439
Variants::Multiple { variants, .. } => {
3540
let min = variants.iter().map(|v| v.size).min().unwrap();
3641
let max = variants.iter().map(|v| v.size).max().unwrap();
3742
if max.bytes() - min.bytes() < D {
3843
return None;
3944
}
40-
Some((
41-
layout.size,
42-
variants.len() as u64,
43-
variants.iter().map(|v| v.size).collect(),
44-
))
45+
let mut discr_sizes = vec![Size::ZERO; adt_def.discriminants(tcx).count()];
46+
for (var_idx, layout) in variants.iter_enumerated() {
47+
let disc_idx =
48+
adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
49+
assert_eq!(discr_sizes[disc_idx], Size::ZERO);
50+
discr_sizes[disc_idx] = layout.size;
51+
}
52+
Some((layout.size, variants.len() as u64, discr_sizes))
4553
}
4654
}
4755
}
4856
_ => None,
4957
}
5058
}
5159
fn optim(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
52-
let mut match_cache = FxHashMap::default();
60+
let mut alloc_cache = FxHashMap::default();
5361
let body_did = body.source.def_id();
5462
let mut patch = MirPatch::new(body);
5563
let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut();
@@ -61,39 +69,45 @@ impl<const D: u64> EnumSizeOpt<D> {
6169
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
6270
)) => {
6371
let ty = lhs.ty(local_decls, tcx).ty;
72+
let source_info = st.source_info;
73+
let span = source_info.span;
74+
6475
let (total_size, num_variants, sizes) =
65-
if let Some((ts, nv, s)) = match_cache.get(ty) {
66-
(*ts, *nv, s)
67-
} else if let Some((ts, nv, s)) = Self::candidate(tcx, ty, body_did) {
68-
// FIXME(jknodt) use entry API.
69-
match_cache.insert(ty, (ts, nv, s));
70-
let (ts, nv, s) = match_cache.get(ty).unwrap();
71-
(*ts, *nv, s)
76+
if let Some((ts, nv, s)) = Self::candidate(tcx, ty, body_did) {
77+
(ts, nv, s)
7278
} else {
7379
return None;
7480
};
7581

76-
let source_info = st.source_info;
77-
let span = source_info.span;
82+
let alloc = if let Some(alloc) = alloc_cache.get(ty) {
83+
alloc
84+
} else {
85+
let mut data =
86+
vec![0; std::mem::size_of::<usize>() * num_variants as usize];
87+
data.copy_from_slice(unsafe { std::mem::transmute(&sizes[..]) });
88+
let alloc = interpret::Allocation::from_bytes(
89+
data,
90+
tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
91+
Mutability::Not,
92+
);
93+
let alloc = tcx.intern_const_alloc(alloc);
94+
alloc_cache.insert(ty, alloc);
95+
// FIXME(jknodt) use entry API
96+
alloc_cache.get(ty).unwrap()
97+
};
7898

7999
let tmp_ty = tcx.mk_ty(ty::Array(
80100
tcx.types.usize,
81101
Const::from_usize(tcx, num_variants),
82102
));
83103

84-
let new_local = patch.new_temp(tmp_ty, span);
85-
let store_live =
86-
Statement { source_info, kind: StatementKind::StorageLive(new_local) };
87-
88-
let place = Place { local: new_local, projection: List::empty() };
89-
let mut data =
90-
vec![0; std::mem::size_of::<usize>() * num_variants as usize];
91-
data.copy_from_slice(unsafe { std::mem::transmute(&sizes[..]) });
92-
let alloc = interpret::Allocation::from_bytes(
93-
data,
94-
tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
95-
);
96-
let alloc = tcx.intern_const_alloc(alloc);
104+
let size_array_local = patch.new_temp(tmp_ty, span);
105+
let store_live = Statement {
106+
source_info,
107+
kind: StatementKind::StorageLive(size_array_local),
108+
};
109+
110+
let place = Place { local: size_array_local, projection: List::empty() };
97111
let constant_vals = Constant {
98112
span,
99113
user_ty: None,
@@ -134,9 +148,9 @@ impl<const D: u64> EnumSizeOpt<D> {
134148
kind: StatementKind::Assign(box (
135149
size_place,
136150
Rvalue::Use(Operand::Copy(Place {
137-
local: discr_place.local,
151+
local: size_array_local,
138152
projection: tcx
139-
.intern_place_elems(&[PlaceElem::Index(size_place.local)]),
153+
.intern_place_elems(&[PlaceElem::Index(discr_place.local)]),
140154
})),
141155
)),
142156
};
@@ -187,8 +201,10 @@ impl<const D: u64> EnumSizeOpt<D> {
187201
}),
188202
};
189203

190-
let store_dead =
191-
Statement { source_info, kind: StatementKind::StorageDead(new_local) };
204+
let store_dead = Statement {
205+
source_info,
206+
kind: StatementKind::StorageDead(size_array_local),
207+
};
192208
let iter = std::array::IntoIter::new([
193209
store_live,
194210
const_assign,

0 commit comments

Comments
 (0)