From d6467d34ae4057f493e2706a5625e0784f2a68bf Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 7 Apr 2025 07:07:16 -0400 Subject: [PATCH 1/2] handle sret for scalar autodiff --- .../rustc_ast/src/expand/autodiff_attrs.rs | 6 +++++ .../src/builder/autodiff.rs | 24 +++++++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index f01c781f46c65..13a7c5a180576 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -92,6 +92,12 @@ pub struct AutoDiffAttrs { pub input_activity: Vec, } +impl AutoDiffAttrs { + pub fn has_primal_ret(&self) -> bool { + matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual) + } +} + impl DiffMode { pub fn is_rev(&self) -> bool { matches!(self, DiffMode::Reverse) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 7d264ba4d00c8..5e7ef27143b14 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>( } if attrs.width == 1 { - todo!("Handle sret for scalar ad"); + // Enzyme returns a struct of style: + // `{ original_ret(if requested), float, float, ... }` + let mut struct_elements = vec![]; + if attrs.has_primal_ret() { + struct_elements.push(inner_ret_ty); + } + // Next, we push the list of active floats, since they will be lowered to `enzyme_out`, + // and therefore part of the return struct. + let param_tys = cx.func_params_types(fn_ty); + for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) { + if matches!(act, DiffActivity::Active) { + // Now find the float type at position i based on the fn_ty, + // to know what (f16/f32/f64/...) to add to the struct. + struct_elements.push(param_ty); + } + } + ret_ty = cx.type_struct(&struct_elements, false); } else { // First we check if we also have to deal with the primal return. match attrs.mode { @@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>( // now store the result of the enzyme call into the sret pointer. let sret_ptr = outer_args[0]; let call_ty = cx.val_ty(call); - assert_eq!(cx.type_kind(call_ty), TypeKind::Array); + if attrs.width == 1 { + assert_eq!(cx.type_kind(call_ty), TypeKind::Struct); + } else { + assert_eq!(cx.type_kind(call_ty), TypeKind::Array); + } llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); } builder.ret_void(); From ca5bea3ebbc4725c187abf4eac68f6c57fa938c1 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 7 Apr 2025 07:11:52 -0400 Subject: [PATCH 2/2] move old tests, add sret test --- .../{autodiffv.rs => autodiff/batched.rs} | 0 .../{autodiff.rs => autodiff/scalar.rs} | 0 tests/codegen/autodiff/sret.rs | 45 +++++++++++++++++++ 3 files changed, 45 insertions(+) rename tests/codegen/{autodiffv.rs => autodiff/batched.rs} (100%) rename tests/codegen/{autodiff.rs => autodiff/scalar.rs} (100%) create mode 100644 tests/codegen/autodiff/sret.rs diff --git a/tests/codegen/autodiffv.rs b/tests/codegen/autodiff/batched.rs similarity index 100% rename from tests/codegen/autodiffv.rs rename to tests/codegen/autodiff/batched.rs diff --git a/tests/codegen/autodiff.rs b/tests/codegen/autodiff/scalar.rs similarity index 100% rename from tests/codegen/autodiff.rs rename to tests/codegen/autodiff/scalar.rs diff --git a/tests/codegen/autodiff/sret.rs b/tests/codegen/autodiff/sret.rs new file mode 100644 index 0000000000000..5ead90041edc3 --- /dev/null +++ b/tests/codegen/autodiff/sret.rs @@ -0,0 +1,45 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This test is almost identical to the scalar.rs one, +// but we intentionally add a few more floats. +// `df` would ret `{ f64, f32, f32 }`, but is lowered as an sret. +// We therefore use this test to verify some of our sret handling. + +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[no_mangle] +#[autodiff(df, Reverse, Active, Active, Active)] +fn primal(x: f32, y: f32) -> f64 { + (x * x * y) as f64 +} + +// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y) +// CHECK-NEXT:start: +// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y) +// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0 +// CHECK-NEXT: store double %.elt, ptr %_0, align 8 +// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8 +// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1 +// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8 +// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12 +// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2 +// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4 +// CHECK-NEXT: ret void +// CHECK-NEXT:} + +fn main() { + let x = std::hint::black_box(3.0); + let y = std::hint::black_box(2.5); + let scalar = std::hint::black_box(1.0); + let (r1, r2, r3) = df(x, y, scalar); + // 3*3*1.5 = 22.5 + assert_eq!(r1, 22.5); + // 2*x*y = 2*3*2.5 = 15.0 + assert_eq!(r2, 15.0); + // x*x*1 = 3*3 = 9 + assert_eq!(r3, 9.0); +}