Skip to content

Commit 95d66be

Browse files
committed
Add bypass for i1xN
1 parent 3b60dee commit 95d66be

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -375,6 +375,9 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
375375

376376
if llvm_element_ty == self.type_bf16() {
377377
rust_ty == self.type_vector(self.type_i16(), llvm_element_count)
378+
} else if llvm_element_ty == self.type_i1() {
379+
let int_width = cmp::max(llvm_element_count.next_power_of_two(), 8);
380+
rust_ty == self.type_ix(int_width)
378381
} else {
379382
false
380383
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 48 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
3-
use std::{iter, ptr};
3+
use std::{cmp, iter, ptr};
44

55
pub(crate) mod autodiff;
66

@@ -1670,6 +1670,46 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
16701670
self.call(self.type_func(&[src_ty], dest_ty), None, None, f, &[val], None, None)
16711671
}
16721672

1673+
/// Converts `val` into a vector of `i1` of type `dest_ty`.
///
/// `val` is first bitcast to a vector of `int_width` `i1` lanes, where
/// `int_width` is the lane count of `dest_ty` rounded up to the next power
/// of two, with a minimum of 8. If `dest_ty` has fewer lanes than that, a
/// shuffle with the mask `0..vector_length` keeps only the leading
/// `vector_length` lanes.
///
/// NOTE(review): presumably `val` is an integer exactly `int_width` bits
/// wide (the bitcast needs matching bit sizes) -- confirm against callers.
fn trunc_int_to_i1_vector(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
1674+
let vector_length = self.vector_length(dest_ty) as u64;
1675+
// Width of the widened mask: next power of two of the lane count, at least 8.
let int_width = cmp::max(vector_length.next_power_of_two(), 8);
1676+
1677+
let bitcasted = self.bitcast(val, self.type_vector(self.type_i1(), int_width));
1678+
if vector_length == int_width {
1679+
// The widened vector already has the requested lane count.
bitcasted
1680+
} else {
1681+
// Select lanes 0..vector_length; every index is below `int_width`, so
// (assuming LLVM `shufflevector` semantics) only the first operand is read.
let shuffle_mask =
1682+
(0..vector_length).map(|i| self.const_i32(i as i32)).collect::<Vec<_>>();
1683+
self.shuffle_vector(bitcasted, bitcasted, self.const_vector(&shuffle_mask))
1684+
}
1685+
}
1686+
1687+
/// Converts `val`, a vector of type `src_ty`, into a value of type `dest_ty`
/// by bitcast, widening the vector first when needed.
///
/// `int_width` is the lane count of `src_ty` rounded up to the next power of
/// two, with a minimum of 8. When the lane count differs from `int_width`,
/// `val` is shuffled together with `const_null(src_ty)` to produce a vector
/// of `int_width` lanes, and the result is then bitcast to `dest_ty`.
///
/// NOTE(review): presumably `dest_ty` is an integer type `int_width` bits
/// wide and `src_ty` has `i1` elements -- confirm against callers.
///
/// # Panics
///
/// Hits `unreachable!` if `src_ty` reports a vector length of 0.
fn zext_i1_vector_to_int(
1688+
&mut self,
1689+
mut val: &'ll Value,
1690+
src_ty: &'ll Type,
1691+
dest_ty: &'ll Type,
1692+
) -> &'ll Value {
1693+
let vector_length = self.vector_length(src_ty) as u64;
1694+
// Width of the widened mask: next power of two of the lane count, at least 8.
let int_width = cmp::max(vector_length.next_power_of_two(), 8);
1695+
1696+
if vector_length != int_width {
1697+
// Assuming LLVM `shufflevector` semantics, indices below `vector_length`
// read from `val` and indices in `vector_length..2 * vector_length` read
// from the null second operand. Lengths 1-3 use hand-written tables,
// presumably because `0..8` would exceed `2 * vector_length` there --
// TODO confirm.
let shuffle_indices = match vector_length {
1698+
0 => unreachable!("zero length vectors are not allowed"),
1699+
1 => vec![0, 1, 1, 1, 1, 1, 1, 1],
1700+
2 => vec![0, 1, 2, 3, 2, 3, 2, 3],
1701+
3 => vec![0, 1, 2, 3, 4, 5, 3, 4],
1702+
4.. => (0..int_width as i32).collect(),
1703+
};
1704+
let shuffle_mask =
1705+
shuffle_indices.into_iter().map(|i| self.const_i32(i)).collect::<Vec<_>>();
1706+
val =
1707+
self.shuffle_vector(val, self.const_null(src_ty), self.const_vector(&shuffle_mask));
1708+
}
1709+
1710+
self.bitcast(val, dest_ty)
1711+
}
1712+
16731713
fn autocast(
16741714
&mut self,
16751715
llfn: &'ll Value,
@@ -1685,6 +1725,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
16851725
}
16861726

16871727
match self.type_kind(llvm_ty) {
1728+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
1729+
if is_argument {
1730+
self.trunc_int_to_i1_vector(val, dest_ty)
1731+
} else {
1732+
self.zext_i1_vector_to_int(val, src_ty, dest_ty)
1733+
}
1734+
}
16881735
TypeKind::Struct => {
16891736
let mut ret = self.const_poison(dest_ty);
16901737
for (idx, (src_element_ty, dest_element_ty)) in

0 commit comments

Comments
 (0)