Skip to content

Commit 3b60dee

Browse files
committed
Add bypass for bf16 and bf16xN
1 parent b5cb020 commit 3b60dee

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,17 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
368368
}
369369

370370
match self.type_kind(llvm_ty) {
371+
TypeKind::BFloat => rust_ty == self.type_i16(),
372+
TypeKind::Vector => {
373+
let llvm_element_count = self.vector_length(llvm_ty) as u64;
374+
let llvm_element_ty = self.element_type(llvm_ty);
375+
376+
if llvm_element_ty == self.type_bf16() {
377+
rust_ty == self.type_vector(self.type_i16(), llvm_element_count)
378+
} else {
379+
false
380+
}
381+
}
371382
TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
372383
let rust_element_tys = self.struct_element_types(rust_ty);
373384
let llvm_element_tys = self.struct_element_types(llvm_ty);

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1698,7 +1698,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
16981698
}
16991699
ret
17001700
}
1701-
_ => unreachable!(),
1701+
_ => self.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
17021702
}
17031703
}
17041704

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,9 @@ unsafe extern "C" {
10501050
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
10511051
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
10521052

1053+
// Operations on non-IEEE real types
1054+
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
1055+
10531056
// Operations on function types
10541057
pub(crate) fn LLVMFunctionType<'a>(
10551058
ReturnType: &'a Type,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
168168
)
169169
}
170170
}
171+
172+
pub(crate) fn type_bf16(&self) -> &'ll Type {
173+
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
174+
}
171175
}
172176

173177
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -241,7 +245,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
241245

242246
fn float_width(&self, ty: &'ll Type) -> usize {
243247
match self.type_kind(ty) {
244-
TypeKind::Half => 16,
248+
TypeKind::Half | TypeKind::BFloat => 16,
245249
TypeKind::Float => 32,
246250
TypeKind::Double => 64,
247251
TypeKind::X86_FP80 => 80,

0 commit comments

Comments
 (0)