Skip to content

Commit

Permalink
fix fwd-mode autodiff case
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Feb 5, 2025
1 parent 335151f commit 70b9ba3
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>(
let mut activity_pos = 0;
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
while activity_pos < inputs.len() {
let activity = inputs[activity_pos as usize];
let diff_activity = inputs[activity_pos as usize];
// Duplicated arguments received a shadow argument, into which enzyme will write the
// gradient.
let (activity, duplicated): (&Metadata, bool) = match activity {
let (activity, duplicated): (&Metadata, bool) = match diff_activity {
DiffActivity::None => panic!("not a valid input activity"),
DiffActivity::Const => (enzyme_const, false),
DiffActivity::Active => (enzyme_out, false),
Expand Down Expand Up @@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>(
// A duplicated pointer will have the following two outer_fn arguments:
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
// (..., metadata! enzyme_dup, ptr, ptr, ...).
assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
assert!(
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
);
}
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
args.push(next_outer_arg);
outer_pos += 2;
activity_pos += 1;
Expand Down

0 comments on commit 70b9ba3

Please sign in to comment.