Skip to content

Commit bab14b9

Browse files
committed
re-add flags after refactoring
1 parent 35d882f commit bab14b9

File tree

4 files changed

+103
-18
lines changed

4 files changed

+103
-18
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,29 @@ fn thin_lto(
586586
}
587587
}
588588

589+
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
590+
591+
for &val in ad {
592+
match val {
593+
config::AutoDiff::PrintModBefore => { unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; },
594+
config::AutoDiff::PrintPerf => {llvm::set_print_perf(true);},
595+
config::AutoDiff::PrintAA => {llvm::set_print_activity(true);},
596+
config::AutoDiff::PrintTA => {llvm::set_print_type(true);},
597+
config::AutoDiff::Inline => {llvm::set_inline(true);},
598+
config::AutoDiff::LooseTypes => {llvm::set_loose_types(false);},
599+
config::AutoDiff::PrintSteps => {llvm::set_print(true);},
600+
// We handle this below
601+
config::AutoDiff::PrintModAfter => {}
602+
// This is required and already checked
603+
config::AutoDiff::Enable => {}
604+
}
605+
}
606+
// This helps with handling enums for now.
607+
llvm::set_strict_aliasing(false);
608+
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
609+
llvm::set_rust_rules(true);
610+
}
611+
589612
pub(crate) fn run_pass_manager(
590613
cgcx: &CodegenContext<LlvmCodegenBackend>,
591614
dcx: DiagCtxtHandle<'_>,
@@ -604,10 +627,6 @@ pub(crate) fn run_pass_manager(
604627
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
605628
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
606629

607-
// If this rustc version was build with enzyme/autodiff enabled, and if users applied the
608-
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609-
debug!("running llvm pm opt pipeline");
610-
611630
// The PostAD behavior is the same that we would have if no autodiff was used.
612631
// It will run the default optimization pipeline. If AD is enabled we select
613632
// the DuringAD stage, which will disable vectorization and loop unrolling, and
@@ -618,25 +637,27 @@ pub(crate) fn run_pass_manager(
618637
let stage =
619638
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
620639

621-
// If Enzyme fails to differentiate the code, then this module will have everything that's
622-
// needed to reproduce the bug.
623-
if config.autodiff.contains(&config::AutoDiff::PrintModBefore) {
624-
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
640+
if enable_ad {
641+
enable_autodiff_settings(&config.autodiff, module);
625642
}
643+
626644
unsafe {
627645
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
628646
}
647+
629648
if cfg!(llvm_enzyme) && enable_ad {
630649
let opt_stage = llvm::OptStage::FatLTO;
631650
let stage = write::AutodiffStage::PostAD;
632651
unsafe {
633652
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
634653
}
654+
655+
// This is the final IR, so people should be able to inspect the optimized autodiff output.
656+
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
657+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
658+
}
635659
}
636-
// This is the final IR, so people should be able to inspect the optimized autodiff output.
637-
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
638-
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
639-
}
660+
640661
debug!("lto done");
641662
Ok(())
642663
}

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![allow(non_camel_case_types)]
22
#![expect(dead_code)]
33

4-
use libc::{c_char, c_uint};
4+
use libc::{c_char, c_uint, c_void};
55

66
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
77
use crate::llvm::Bool;
@@ -35,3 +35,71 @@ pub enum LLVMRustVerifierFailureAction {
3535
LLVMPrintMessageAction = 1,
3636
LLVMReturnStatusAction = 2,
3737
}
38+
39+
40+
#[cfg(not(llvm_enzyme))]
41+
pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8) {
42+
unimplemented!()
43+
}
44+
#[cfg(not(llvm_enzyme))]
45+
pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64) {
46+
unimplemented!()
47+
}
48+
49+
#[cfg(llvm_enzyme)]
50+
extern "C" {
51+
pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
52+
pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64);
53+
}
54+
55+
#[cfg(llvm_enzyme)]
56+
extern "C" {
57+
static mut EnzymePrintPerf: c_void;
58+
static mut EnzymePrintActivity: c_void;
59+
static mut EnzymePrintType: c_void;
60+
static mut EnzymePrint: c_void;
61+
static mut EnzymeStrictAliasing: c_void;
62+
static mut looseTypeAnalysis: c_void;
63+
static mut EnzymeInline: c_void;
64+
static mut RustTypeRules: c_void;
65+
}
66+
pub fn set_print_perf(print: bool) {
67+
unsafe {
68+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
69+
}
70+
}
71+
pub fn set_print_activity(print: bool) {
72+
unsafe {
73+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
74+
}
75+
}
76+
pub fn set_print_type(print: bool) {
77+
unsafe {
78+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
79+
}
80+
}
81+
pub fn set_print(print: bool) {
82+
unsafe {
83+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
84+
}
85+
}
86+
pub fn set_strict_aliasing(strict: bool) {
87+
unsafe {
88+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
89+
}
90+
}
91+
pub fn set_loose_types(loose: bool) {
92+
unsafe {
93+
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
94+
}
95+
}
96+
pub fn set_inline(val: bool) {
97+
unsafe {
98+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
99+
}
100+
}
101+
pub fn set_rust_rules(val: bool) {
102+
unsafe {
103+
EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
104+
}
105+
}

compiler/rustc_session/src/config.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,6 @@ pub enum AutoDiff {
217217
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
218218
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
219219
LooseTypes,
220-
/// See Enzyme core documentation. FIXME(ZuseZ4): Clarify usages
221-
RuntimeActivity,
222220
/// Runs Enzyme's aggressive inlining
223221
Inline,
224222
}

compiler/rustc_session/src/options.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ mod desc {
707707
pub(crate) const parse_list: &str = "a space-separated list of strings";
708708
pub(crate) const parse_list_with_polarity: &str =
709709
"a comma-separated list of strings, with elements beginning with + or -";
710-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `RuntimeActivity`, `Inline`";
710+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `Inline`";
711711
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
712712
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
713713
pub(crate) const parse_number: &str = "a number";
@@ -1356,7 +1356,6 @@ pub mod parse {
13561356
"PrintModBefore" => AutoDiff::PrintModBefore,
13571357
"PrintModAfter" => AutoDiff::PrintModAfter,
13581358
"LooseTypes" => AutoDiff::LooseTypes,
1359-
"RuntimeActivity" => AutoDiff::RuntimeActivity,
13601359
"Inline" => AutoDiff::Inline,
13611360
_ => {
13621361
// FIXME(ZuseZ4): print an error saying which value is not recognized
@@ -2090,7 +2089,6 @@ options! {
20902089
`=PrintModBefore`
20912090
`=PrintModAfter`
20922091
`=LooseTypes`
2093-
'=RuntimeActivity`
20942092
`=Inline`
20952093
Multiple options can be combined with commas."),
20962094
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]

0 commit comments

Comments
 (0)