Skip to content

Commit b87e031

Browse files
committed
add autodiff batching fix
1 parent 2898b90 commit b87e031

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,11 @@ fn match_args_from_caller_to_enzyme<'ll>(
125125
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
126126
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
127127

128-
for _ in 0..width {
129-
let next_outer_arg2 = outer_args[outer_pos + 2];
128+
for i in 0..(width as usize) {
129+
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
130130
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
131131
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
132-
let next_outer_arg3 = outer_args[outer_pos + 3];
132+
let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
133133
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
134134
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
135135
args.push(next_outer_arg2);

0 commit comments

Comments
 (0)