Skip to content

Change codegen of LLVM intrinsics to be name-based, and add llvm linkage support for x86amx, bf16(xN) and i1xN #140763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
fn checked_binop(
&mut self,
oop: OverflowOp,
typ: Ty<'_>,
typ: Ty<'tcx>,
lhs: Self::Value,
rhs: Self::Value,
) -> (Self::Value, Self::Value) {
Expand Down
8 changes: 6 additions & 2 deletions compiler/rustc_codegen_gcc/src/type_of.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt::Write;

use gccjit::{Struct, Type};
use gccjit::{RValue, Struct, Type};
use rustc_abi as abi;
use rustc_abi::Primitive::*;
use rustc_abi::{
Expand Down Expand Up @@ -373,7 +373,11 @@ impl<'gcc, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
unimplemented!();
}

fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Type<'gcc> {
fn fn_decl_backend_type(
&self,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
_fn_ptr: RValue<'gcc>,
) -> Type<'gcc> {
// FIXME(antoyo): Should we do something with `FnAbiGcc::fn_attributes`?
let FnAbiGcc { return_type, arguments_type, is_c_variadic, .. } = fn_abi.gcc_type(self);
self.context.new_function_pointer_type(None, return_type, &arguments_type, is_c_variadic)
Expand Down
233 changes: 211 additions & 22 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::borrow::Borrow;
use std::cmp;
use std::{cmp, iter};

use libc::c_uint;
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
use rustc_codegen_ssa::MemFlags;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
use rustc_codegen_ssa::traits::*;
Expand All @@ -19,7 +20,7 @@ use smallvec::SmallVec;

use crate::attributes::{self, llfn_attrs_from_instance};
use crate::builder::Builder;
use crate::context::CodegenCx;
use crate::context::{CodegenCx, GenericCx, SCx};
use crate::llvm::{self, Attribute, AttributePlace};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
Expand Down Expand Up @@ -307,8 +308,39 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
}
}

/// How a function's LLVM type was obtained, so that callers know whether it reflects
/// the Rust-side signature or a signature dictated by LLVM itself.
pub(crate) enum FunctionSignature<'ll> {
    /// The signature is obtained directly from LLVM, and **may not match the Rust signature**
    Intrinsic(llvm::Intrinsic, &'ll Type),
    /// The name starts with `llvm.`, but the intrinsic ID can't be obtained.
    /// May be invalid, or may have been upgraded/removed; resolving that is deferred
    /// (it needs the `llfn`), so the Rust signature is carried in the meantime.
    MaybeInvalidIntrinsic(&'ll Type),
    /// Just the Rust signature
    Rust(&'ll Type),
}

impl<'ll> FunctionSignature<'ll> {
pub(crate) fn fn_ty(&self) -> &'ll Type {
match self {
FunctionSignature::Intrinsic(_, fn_ty)
| FunctionSignature::MaybeInvalidIntrinsic(fn_ty)
| FunctionSignature::Rust(fn_ty) => fn_ty,
}
}
}

pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_argument_types(&self, cx: &CodegenCx<'ll, 'tcx>) -> Vec<&'ll Type>;
/// When `do_verify` is set, this function performs checks for the signature of LLVM intrinsics
/// and emits a fatal error if it doesn't match. These checks are important, but somewhat
/// expensive, so they are only used at function definitions, not at call sites.
fn llvm_type(
&self,
cx: &CodegenCx<'ll, 'tcx>,
name: &[u8],
do_verify: bool,
) -> FunctionSignature<'ll>;
/// **If this function is an LLVM intrinsic** checks if the LLVM signature provided matches with this
fn verify_intrinsic_signature(&self, cx: &CodegenCx<'ll, 'tcx>, llvm_ty: &'ll Type) -> bool;
fn ptr_to_llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_cconv(&self, cx: &CodegenCx<'ll, 'tcx>) -> llvm::CallConv;

Expand All @@ -321,30 +353,107 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
);

/// Apply attributes to a function call.
fn apply_attrs_callsite(&self, bx: &mut Builder<'_, 'll, 'tcx>, callsite: &'ll Value);
fn apply_attrs_callsite(
&self,
bx: &mut Builder<'_, 'll, 'tcx>,
callsite: &'ll Value,
llfn: &'ll Value,
);
}

impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
    /// Attaches the LLVM-side attributes of `intrinsic` to either a function declaration
    /// or a call-site instruction (both are LLVM `Value`s, hence the shared entry point).
    pub(crate) fn set_intrinsic_attributes(
        &self,
        intrinsic: llvm::Intrinsic,
        llfn_or_callsite: &'ll Value,
    ) {
        // SAFETY: plain FFI call; the context and value are valid for 'll.
        unsafe {
            llvm::LLVMRustSetIntrinsicAttributes(self.llcx(), llfn_or_callsite, intrinsic.id());
        }
    }

    /// Returns `true` if the Rust-lowered type `rust_ty` is an acceptable stand-in for the
    /// LLVM-required type `llvm_ty` when linking against an intrinsic signature.
    ///
    /// Beyond exact equality, this accepts the special lowerings Rust uses for types it
    /// cannot express natively: `x86amx`, `bf16` (scalar and vector), `i1xN` vectors, and
    /// structs of the above (compared element-wise).
    pub(crate) fn equate_ty(&self, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
        // Fast path: identical LLVM types are trivially equal.
        if rust_ty == llvm_ty {
            return true;
        }

        match self.type_kind(llvm_ty) {
            // An `x86amx` value may be represented on the Rust side by any vector
            // whose total size is 8192 bits (the size this check enforces below).
            TypeKind::X86_AMX if self.type_kind(rust_ty) == TypeKind::Vector => {
                let element_count = self.vector_length(rust_ty);
                let element_ty = self.element_type(rust_ty);

                // Bit width of a single vector element; only int/float/pointer
                // elements are meaningful here, anything else is a compiler bug.
                let element_size_bits = match self.type_kind(element_ty) {
                    TypeKind::Half => 16,
                    TypeKind::Float => 32,
                    TypeKind::Double => 64,
                    TypeKind::FP128 => 128,
                    TypeKind::Integer => self.int_width(element_ty),
                    // Pointers are counted at target pointer width (isize).
                    TypeKind::Pointer => self.int_width(self.isize_ty()),
                    _ => bug!(
                        "Vector element type `{element_ty:?}` not one of integer, float or pointer"
                    ),
                };
                let vector_size_bits = element_size_bits * element_count as u64;

                vector_size_bits == 8192
            }
            // Scalar `bf16` is represented as `i16` on the Rust side.
            TypeKind::BFloat => rust_ty == self.type_i16(),
            TypeKind::Vector => {
                let llvm_element_count = self.vector_length(llvm_ty) as u64;
                let llvm_element_ty = self.element_type(llvm_ty);

                if llvm_element_ty == self.type_bf16() {
                    // `bf16 x N` is represented as `i16 x N` on the Rust side.
                    rust_ty == self.type_vector(self.type_i16(), llvm_element_count)
                } else if llvm_element_ty == self.type_i1() {
                    // `i1 x N` (a mask vector) is represented as a single integer wide
                    // enough to hold N bits, rounded up to a power of two, minimum i8.
                    let int_width = cmp::max(llvm_element_count.next_power_of_two(), 8);
                    rust_ty == self.type_ix(int_width)
                } else {
                    false
                }
            }
            // Structs are compared field-by-field, recursing so that fields may
            // themselves use any of the special lowerings above.
            TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
                let rust_element_tys = self.struct_element_types(rust_ty);
                let llvm_element_tys = self.struct_element_types(llvm_ty);

                if rust_element_tys.len() != llvm_element_tys.len() {
                    return false;
                }

                iter::zip(rust_element_tys, llvm_element_tys).all(
                    |(rust_element_ty, llvm_element_ty)| {
                        self.equate_ty(rust_element_ty, llvm_element_ty)
                    },
                )
            }
            _ => false,
        }
    }
}

impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
/// The LLVM type the function's return value is lowered to. Ignored returns and
/// indirect (sret-style) returns both lower to `void` — for the indirect case the
/// actual return slot is passed as a pointer argument instead.
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
    match &self.ret.mode {
        PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
        PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
        PassMode::Ignore | PassMode::Indirect { .. } => cx.type_void(),
    }
}

fn llvm_argument_types(&self, cx: &CodegenCx<'ll, 'tcx>) -> Vec<&'ll Type> {
let indirect_return = matches!(self.ret.mode, PassMode::Indirect { .. });

// Ignore "extra" args from the call site for C variadic functions.
// Only the "fixed" args are part of the LLVM function signature.
let args =
if self.c_variadic { &self.args[..self.fixed_count as usize] } else { &self.args };

// This capacity calculation is approximate.
let mut llargument_tys = Vec::with_capacity(
self.args.len() + if let PassMode::Indirect { .. } = self.ret.mode { 1 } else { 0 },
);
let mut llargument_tys =
Vec::with_capacity(args.len() + if indirect_return { 1 } else { 0 });

let llreturn_ty = match &self.ret.mode {
PassMode::Ignore => cx.type_void(),
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
PassMode::Indirect { .. } => {
llargument_tys.push(cx.type_ptr());
cx.type_void()
}
};
if indirect_return {
llargument_tys.push(cx.type_ptr());
}

for arg in args {
// Note that the exact number of arguments pushed here is carefully synchronized with
Expand Down Expand Up @@ -391,10 +500,73 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
llargument_tys.push(llarg_ty);
}

if self.c_variadic {
cx.type_variadic_func(&llargument_tys, llreturn_ty)
llargument_tys
}

/// Checks whether the Rust-lowered signature of this `FnAbi` is compatible with the
/// LLVM-declared intrinsic type `llvm_fn_ty`: same variadic-ness, same arity, and
/// every (return + argument) type pair accepted by `equate_ty`.
fn verify_intrinsic_signature(&self, cx: &CodegenCx<'ll, 'tcx>, llvm_fn_ty: &'ll Type) -> bool {
    // Rust-side lowering of the signature.
    let rust_return_ty = self.llvm_return_type(cx);
    let rust_argument_tys = self.llvm_argument_types(cx);

    // LLVM-side expected signature, queried from the function type itself.
    let llvm_return_ty = cx.get_return_type(llvm_fn_ty);
    let llvm_argument_tys = cx.func_params_types(llvm_fn_ty);
    let llvm_is_variadic = cx.func_is_variadic(llvm_fn_ty);

    // Shape must match before comparing types pairwise.
    if self.c_variadic != llvm_is_variadic || rust_argument_tys.len() != llvm_argument_tys.len()
    {
        return false;
    }

    // Compare the return type, then each argument pair, under `equate_ty`'s
    // relaxed equivalence (bf16/i1xN/x86amx lowerings allowed).
    iter::once((rust_return_ty, llvm_return_ty))
        .chain(iter::zip(rust_argument_tys, llvm_argument_tys))
        .all(|(rust_ty, llvm_ty)| cx.equate_ty(rust_ty, llvm_ty))
}

/// Computes the [`FunctionSignature`] for a function named `name`.
///
/// For a recognized, non-overloaded LLVM intrinsic the signature comes from LLVM
/// itself (optionally verified against the Rust signature when `do_verify` is set);
/// otherwise the Rust-lowered signature is used. Fixes stray diff residue: the
/// original `else` branch contained a leftover line referencing undefined
/// `llargument_tys`/`llreturn_ty` variables, which has been removed.
fn llvm_type(
    &self,
    cx: &CodegenCx<'ll, 'tcx>,
    name: &[u8],
    do_verify: bool,
) -> FunctionSignature<'ll> {
    let return_ty = self.llvm_return_type(cx);
    let argument_tys = self.llvm_argument_types(cx);

    // Set when the name has the `llvm.` prefix but no intrinsic ID could be found.
    let mut maybe_invalid = false;

    if name.starts_with(b"llvm.") {
        if let Some(intrinsic) = llvm::Intrinsic::lookup(name) {
            if !intrinsic.is_overloaded() {
                // FIXME: also do this for overloaded intrinsics
                let llvm_fn_ty = cx.intrinsic_type(intrinsic, &[]);
                // Signature checks are expensive, so they only run at definition
                // sites (`do_verify`), not at every call site.
                if do_verify && !self.verify_intrinsic_signature(cx, llvm_fn_ty) {
                    cx.tcx.dcx().fatal(format!(
                        "Intrinsic signature mismatch for `{}`: expected signature `{llvm_fn_ty:?}`",
                        str::from_utf8(name).unwrap()
                    ));
                }
                return FunctionSignature::Intrinsic(intrinsic, llvm_fn_ty);
            }
        } else {
            // It's one of 2 cases:
            // - either the base name is invalid,
            // - or it has been superseded by something else, so the intrinsic was
            //   removed entirely.
            // To check for upgrades we need the `llfn`, so we defer that for now.
            maybe_invalid = true;
        }
    }

    // Fall back to the Rust-lowered signature.
    let fn_ty = if self.c_variadic {
        cx.type_variadic_func(&argument_tys, return_ty)
    } else {
        cx.type_func(&argument_tys, return_ty)
    };

    if maybe_invalid {
        FunctionSignature::MaybeInvalidIntrinsic(fn_ty)
    } else {
        FunctionSignature::Rust(fn_ty)
    }
}

Expand Down Expand Up @@ -531,7 +703,24 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
}
}

fn apply_attrs_callsite(&self, bx: &mut Builder<'_, 'll, 'tcx>, callsite: &'ll Value) {
fn apply_attrs_callsite(
&self,
bx: &mut Builder<'_, 'll, 'tcx>,
callsite: &'ll Value,
llfn: &'ll Value,
) {
// if we are using the LLVM signature, use the LLVM attributes otherwise it might be problematic
let name = llvm::get_value_name(llfn);
if name.starts_with(b"llvm.")
&& let Some(intrinsic) = llvm::Intrinsic::lookup(name)
{
// FIXME: also do this for overloaded intrinsics
if !intrinsic.is_overloaded() {
bx.set_intrinsic_attributes(intrinsic, callsite);
return;
}
}

let mut func_attrs = SmallVec::<[_; 2]>::new();
if self.ret.layout.is_uninhabited() {
func_attrs.push(llvm::AttributeKind::NoReturn.create_attr(bx.cx.llcx));
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_codegen_llvm/src/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ fn create_wrapper_function(
let ty = cx.type_func(args, output.unwrap_or_else(|| cx.type_void()));
let llfn = declare_simple_fn(
&cx,
from_name,
from_name.as_bytes(),
llvm::CallConv::CCallConv,
llvm::UnnamedAddr::Global,
llvm::Visibility::from_generic(tcx.sess.default_visibility()),
Expand All @@ -130,7 +130,7 @@ fn create_wrapper_function(

let callee = declare_simple_fn(
&cx,
to_name,
to_name.as_bytes(),
llvm::CallConv::CCallConv,
llvm::UnnamedAddr::Global,
llvm::Visibility::Hidden,
Expand All @@ -150,7 +150,7 @@ fn create_wrapper_function(
.enumerate()
.map(|(i, _)| llvm::get_param(llfn, i as c_uint))
.collect::<Vec<_>>();
let ret = bx.call(ty, callee, &args, None);
let ret = bx.simple_call(ty, callee, &args, None);
llvm::LLVMSetTailCall(ret, True);
if output.is_some() {
bx.ret(ret);
Expand Down
Loading
Loading