Skip to content

Pure batching without autodiff #141637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub enum DiffMode {
Forward,
/// The target function, to be created using reverse mode AD.
Reverse,
/// The target function, to be created using batching.
Batch,
}

/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
Expand Down Expand Up @@ -69,6 +71,12 @@ pub enum DiffActivity {
/// length of a slice/vec. This is used for safety checks on slices.
/// The integer (if given) specifies the size of the slice element in bytes.
FakeActivitySize(Option<u32>),
/// Batching mode A
Vector,
/// Batching mode B, equivalent to *v modes above
Buffer,
/// "Batching" mode C, scalar. Not batched.
Scalar,
}

impl DiffActivity {
Expand Down Expand Up @@ -130,6 +138,7 @@ impl Display for DiffMode {
DiffMode::Source => write!(f, "Source"),
DiffMode::Forward => write!(f, "Forward"),
DiffMode::Reverse => write!(f, "Reverse"),
DiffMode::Batch => write!(f, "Batch"),
}
}
}
Expand All @@ -153,6 +162,14 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
|| activity == DiffActivity::Active
|| activity == DiffActivity::ActiveOnly
}
DiffMode::Batch => {
// Batching is a special case: it does not compute derivatives, so the autodiff return
// activities do not apply. Only `Const` and the batching activities are accepted here.
activity == DiffActivity::Const
|| activity == DiffActivity::Vector
|| activity == DiffActivity::Buffer
|| activity == DiffActivity::Scalar
}
}
}

Expand Down Expand Up @@ -186,6 +203,11 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
DiffMode::Reverse => {
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
}
DiffMode::Batch => {
// Batching is a special case: it does not compute derivatives, so the autodiff input
// activities do not apply. Inputs may be `Const`, `Vector`, or `Buffer` (not `Scalar`).
matches!(activity, Const | Vector | Buffer)
}
};
}

Expand All @@ -203,6 +225,9 @@ impl Display for DiffActivity {
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
DiffActivity::Vector => write!(f, "Vector"),
DiffActivity::Buffer => write!(f, "Buffer"),
DiffActivity::Scalar => write!(f, "Scalar"),
}
}
}
Expand All @@ -216,6 +241,7 @@ impl FromStr for DiffMode {
"Source" => Ok(DiffMode::Source),
"Forward" => Ok(DiffMode::Forward),
"Reverse" => Ok(DiffMode::Reverse),
"Batch" => Ok(DiffMode::Batch),
_ => Err(()),
}
}
Expand All @@ -235,6 +261,9 @@ impl FromStr for DiffActivity {
"DualvOnly" => Ok(DiffActivity::DualvOnly),
"Duplicated" => Ok(DiffActivity::Duplicated),
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
"Scalar" => Ok(DiffActivity::Scalar),
"Vector" => Ok(DiffActivity::Vector),
"Buffer" => Ok(DiffActivity::Buffer),
_ => Err(()),
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,7 @@ mod llvm_enzyme {
DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
panic!("Should not happen");
}
DiffActivity::Vector | DiffActivity::Scalar | DiffActivity::Buffer => todo!()
}
if let PatKind::Ident(_, ident, _) = arg.pat.kind {
idents.push(ident.clone());
Expand Down
106 changes: 93 additions & 13 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
outer_pos = 1;
}

// Autodiff activities
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();

// Batching activities
let enzyme_scalar = cx.create_metadata("enzyme_scalar".to_string()).unwrap();
let enzyme_vector = cx.create_metadata("enzyme_vector".to_string()).unwrap();
let enzyme_buffer = cx.create_metadata("enzyme_buffer".to_string()).unwrap();

while activity_pos < inputs.len() {
let diff_activity = inputs[activity_pos as usize];
// Duplicated arguments received a shadow argument, into which enzyme will write the
Expand All @@ -99,14 +105,18 @@ fn match_args_from_caller_to_enzyme<'ll>(
DiffActivity::Duplicated => (enzyme_dup, true),
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
DiffActivity::Vector => (enzyme_vector, true),
DiffActivity::Buffer => (enzyme_buffer, false),
DiffActivity::Scalar => (enzyme_scalar, true),
};
let no_autodiff_only_batching = matches!(diff_activity, DiffActivity::Scalar | DiffActivity::Vector | DiffActivity::Buffer);
let outer_arg = outer_args[outer_pos];
args.push(cx.get_metadata_value(activity));
if matches!(diff_activity, DiffActivity::Dualv) {
let next_outer_arg = outer_args[outer_pos + 1];
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
DiffActivity::FakeActivitySize(Some(s)) => s.into(),
_ => bug!("incorrect Dualv handling recognized."),
_ => bug!("incorrect Dualv/Batching handling recognized."),
};
// stride: sizeof(T) * n_elems.
// n_elems is the next integer.
Expand All @@ -121,7 +131,53 @@ fn match_args_from_caller_to_enzyme<'ll>(
};
args.push(mul);
}
if matches!(diff_activity, DiffActivity::Buffer) {
// There are various cases.
// A) We look at a scalar float.
// B) We look at a Vector/Array of floats (byVal). Not sure if this is valid.
// C) We look at a ptr as part of a slice.
// D) We look at a ptr as part of a raw pointer or reference.

let mut elem_offset = cx.get_const_i64(width.into());
let outer_ty = cx.val_ty(outer_arg);
dbg!(&outer_ty);
let bit_width = if cx.is_float_type(outer_ty) {
cx.float_width(outer_ty)
} else if cx.is_vec_or_array_type(outer_ty) {
let elem_ty = cx.element_type(outer_ty);
assert!(cx.is_float_type(elem_ty));
let num_vec_elements = cx.vector_length(outer_ty);
assert!(num_vec_elements == width as usize);
dbg!(&num_vec_elements);
cx.float_width(elem_ty)
} else if cx.is_ptr_type(outer_ty) {
if is_slice(activity_pos, inputs) {
elem_offset = outer_args[outer_pos + 1];
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
DiffActivity::FakeActivitySize(Some(s)) => s.into(),
_ => bug!("incorrect Dualv/Buffer handling recognized."),
};
elem_bytes_size as usize * 8
} else {
// raw pointer or ref, hence `num_elem` = 1
unimplemented!()
}
} else {
bug!("expected float or vector type, found {:?}", outer_ty);
};
let elem_bytes_size = bit_width as u64 / 8;
let mul = unsafe {
llvm::LLVMBuildMul(
builder.llbuilder,
cx.get_const_i64(elem_bytes_size),
elem_offset,
UNNAMED,
)
};
args.push(mul);
}
args.push(outer_arg);
dbg!(&args);
if duplicated {
// We know that duplicated args by construction have a following argument,
// so this can not be out of bounds.
Expand All @@ -130,17 +186,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
// FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
// vectors behind references (&Vec<T>) are already supported. Users can not pass a
// Vec by value for reverse mode, so this would only help forward mode autodiff.
let slice = {
if activity_pos + 1 >= inputs.len() {
// If there is no arg following our ptr, it also can't be a slice,
// since that would lead to a ptr, int pair.
false
} else {
let next_activity = inputs[activity_pos + 1];
// We analyze the MIR types and add this dummy activity if we visit a slice.
matches!(next_activity, DiffActivity::FakeActivitySize(_))
}
};
let slice = is_slice(activity_pos, &inputs);
if slice {
// A duplicated slice will have the following two outer_fn arguments:
// (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
Expand Down Expand Up @@ -178,8 +224,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
outer_pos += 2;
activity_pos += 1;

dbg!(&width);
dbg!(&outer_pos);
dbg!(&activity_pos);
dbg!(&args);
let limit = if no_autodiff_only_batching {
// Usually we have one primal arg + width shadow args.
// Here we have `width` primal args, so one less than normal.
width as usize - 1
} else {
width as usize
};
// Now, if width > 1, we need to account for that
for _ in 1..width {
for _ in 1..limit {
let next_outer_arg = outer_args[outer_pos];
args.push(next_outer_arg);
outer_pos += 1;
Expand All @@ -192,6 +249,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
activity_pos += 1;
}
}
dbg!("ending");
}

/// Reports whether the activity at `activity_pos` belongs to a slice argument.
///
/// A slice is recognised purely by the entry that follows it in `inputs`: the answer is `true`
/// only when such an entry exists and is a `DiffActivity::FakeActivitySize(_)` marker.
fn is_slice(activity_pos: usize, inputs: &[DiffActivity]) -> bool {
    // `get` yields `None` when no arg follows our ptr. In that case it can't be a slice,
    // since a slice would show up as a (ptr, int) pair.
    // The dummy `FakeActivitySize` activity is added while analyzing the MIR types whenever
    // a slice is visited.
    matches!(inputs.get(activity_pos + 1), Some(DiffActivity::FakeActivitySize(_)))
}

// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
Expand Down Expand Up @@ -269,6 +339,12 @@ fn compute_enzyme_fn_ty<'ll>(
DiffMode::Reverse => {
todo!("Handle sret for reverse mode");
}
DiffMode::Batch => {
let arr_ty = unsafe {
llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64)
};
ret_ty = arr_ty;
}
_ => {
bug!("unreachable");
}
Expand Down Expand Up @@ -299,6 +375,7 @@ fn generate_enzyme_call<'ll>(
let mut ad_name: String = match attrs.mode {
DiffMode::Forward => "__enzyme_fwddiff",
DiffMode::Reverse => "__enzyme_autodiff",
DiffMode::Batch => "__enzyme_batch",
_ => panic!("logic bug in autodiff, unrecognized mode"),
}
.to_string();
Expand Down Expand Up @@ -402,6 +479,7 @@ fn generate_enzyme_call<'ll>(

let call = builder.call(enzyme_ty, ad_fn, &args, None);

dbg!(&call);
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
// metadata attached to it, but we just created this code out of thin air. Given that the
// differentiated function already has partly confusing metadata, and given that this
Expand Down Expand Up @@ -448,6 +526,7 @@ fn generate_enzyme_call<'ll>(
} else {
builder.ret(call);
}
dbg!("Still alive");

// Let's crash in case that we messed something up above and generated invalid IR.
llvm::LLVMRustVerifyFunction(
Expand Down Expand Up @@ -507,6 +586,7 @@ pub(crate) fn differentiate<'ll>(

generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
}
dbg!("lowered all");

// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts

Expand Down
17 changes: 17 additions & 0 deletions compiler/rustc_codegen_llvm/src/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
)
}
}

/// Returns `true` when the kind of `ty` is one of the floating-point kinds
/// `Half`, `Float`, `Double`, `X86_FP80`, `FP128`, or `PPC_FP128`.
pub(crate) fn is_float_type(&self, ty: &'ll Type) -> bool {
    let kind = self.type_kind(ty);
    matches!(
        kind,
        TypeKind::Half
            | TypeKind::Float
            | TypeKind::Double
            | TypeKind::X86_FP80
            | TypeKind::FP128
            | TypeKind::PPC_FP128
    )
}

/// Returns `true` when the kind of `ty` is `Array`, `Vector`, or `ScalableVector`.
pub(crate) fn is_vec_or_array_type(&self, ty: &'ll Type) -> bool {
    let kind = self.type_kind(ty);
    matches!(kind, TypeKind::Array | TypeKind::Vector | TypeKind::ScalableVector)
}

/// Returns `true` when the kind of `ty` is `Pointer`.
pub(crate) fn is_ptr_type(&self, ty: &'ll Type) -> bool {
    let kind = self.type_kind(ty);
    matches!(kind, TypeKind::Pointer)
}
}

impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_ssa/src/codegen_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
let mode = match mode.as_str() {
"Forward" => DiffMode::Forward,
"Reverse" => DiffMode::Reverse,
"Batch" => DiffMode::Batch,
_ => {
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
}
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
| DiffActivity::Duplicated => {
DiffActivity::FakeActivitySize(Some(elem_size))
}
DiffActivity::Buffer => {
DiffActivity::FakeActivitySize(Some(elem_size))
}
DiffActivity::Const => DiffActivity::Const,
_ => bug!("unexpected activity for ptr/ref"),
};
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ symbols! {
BTreeEntry,
BTreeMap,
BTreeSet,
Batching,
BinaryHeap,
Borrow,
BorrowMut,
Expand Down
Loading
Loading