Skip to content

Commit d69340d

Browse files
authored
Rollup merge of rust-lang#140697 - Sa4dUs:split-autodiff, r=ZuseZ4
Split `autodiff` into `autodiff_forward` and `autodiff_reverse` This PR splits `#[autodiff]` macro so `#[autodiff(df, Reverse, args)]` would become `#[autodiff_reverse(df, args)]` and `#[autodiff(df, Forward, args)]` would become `#[autodiff_forwad(df, args)]`.
2 parents 66772ca + c6c2fde commit d69340d

32 files changed

+234
-217
lines changed

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ builtin_macros_assert_requires_expression = macro requires an expression as an a
7171
7272
builtin_macros_autodiff = autodiff must be applied to function
7373
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
74-
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
7574
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
7675
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
7776
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -86,27 +86,23 @@ mod llvm_enzyme {
8686
ecx: &mut ExtCtxt<'_>,
8787
meta_item: &ThinVec<MetaItemInner>,
8888
has_ret: bool,
89+
mode: DiffMode,
8990
) -> AutoDiffAttrs {
9091
let dcx = ecx.sess.dcx();
91-
let mode = name(&meta_item[1]);
92-
let Ok(mode) = DiffMode::from_str(&mode) else {
93-
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
94-
return AutoDiffAttrs::error();
95-
};
9692

9793
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
9894
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
99-
let mut first_activity = 2;
95+
let mut first_activity = 1;
10096

101-
let width = if let [_, _, x, ..] = &meta_item[..]
97+
let width = if let [_, x, ..] = &meta_item[..]
10298
&& let Some(x) = width(x)
10399
{
104-
first_activity = 3;
100+
first_activity = 2;
105101
match x.try_into() {
106102
Ok(x) => x,
107103
Err(_) => {
108104
dcx.emit_err(errors::AutoDiffInvalidWidth {
109-
span: meta_item[2].span(),
105+
span: meta_item[1].span(),
110106
width: x,
111107
});
112108
return AutoDiffAttrs::error();
@@ -165,6 +161,24 @@ mod llvm_enzyme {
165161
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
166162
}
167163

164+
pub(crate) fn expand_forward(
165+
ecx: &mut ExtCtxt<'_>,
166+
expand_span: Span,
167+
meta_item: &ast::MetaItem,
168+
item: Annotatable,
169+
) -> Vec<Annotatable> {
170+
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
171+
}
172+
173+
pub(crate) fn expand_reverse(
174+
ecx: &mut ExtCtxt<'_>,
175+
expand_span: Span,
176+
meta_item: &ast::MetaItem,
177+
item: Annotatable,
178+
) -> Vec<Annotatable> {
179+
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
180+
}
181+
168182
/// We expand the autodiff macro to generate a new placeholder function which passes
169183
/// type-checking and can be called by users. The function body of the placeholder function will
170184
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +212,12 @@ mod llvm_enzyme {
198212
/// ```
199213
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
200214
/// in CI.
201-
pub(crate) fn expand(
215+
pub(crate) fn expand_with_mode(
202216
ecx: &mut ExtCtxt<'_>,
203217
expand_span: Span,
204218
meta_item: &ast::MetaItem,
205219
mut item: Annotatable,
220+
mode: DiffMode,
206221
) -> Vec<Annotatable> {
207222
if cfg!(not(llvm_enzyme)) {
208223
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
@@ -245,29 +260,41 @@ mod llvm_enzyme {
245260
// create TokenStream from vec elemtents:
246261
// meta_item doesn't have a .tokens field
247262
let mut ts: Vec<TokenTree> = vec![];
248-
if meta_item_vec.len() < 2 {
249-
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
250-
// input and output args.
263+
if meta_item_vec.len() < 1 {
264+
// At the bare minimum, we need a fnc name.
251265
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
252266
return vec![item];
253267
}
254268

255-
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
269+
let mode_symbol = match mode {
270+
DiffMode::Forward => sym::Forward,
271+
DiffMode::Reverse => sym::Reverse,
272+
_ => unreachable!("Unsupported mode: {:?}", mode),
273+
};
274+
275+
// Insert mode token
276+
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
277+
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
278+
ts.insert(
279+
1,
280+
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
281+
);
256282

257283
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
258284
// If it is not given, we default to 1 (scalar mode).
259285
let start_position;
260286
let kind: LitKind = LitKind::Integer;
261287
let symbol;
262-
if meta_item_vec.len() >= 3
263-
&& let Some(width) = width(&meta_item_vec[2])
288+
if meta_item_vec.len() >= 2
289+
&& let Some(width) = width(&meta_item_vec[1])
264290
{
265-
start_position = 3;
291+
start_position = 2;
266292
symbol = Symbol::intern(&width.to_string());
267293
} else {
268-
start_position = 2;
294+
start_position = 1;
269295
symbol = sym::integer(1);
270296
}
297+
271298
let l: Lit = Lit { kind, symbol, suffix: None };
272299
let t = Token::new(TokenKind::Literal(l), Span::default());
273300
let comma = Token::new(TokenKind::Comma, Span::default());
@@ -289,7 +316,7 @@ mod llvm_enzyme {
289316
ts.pop();
290317
let ts: TokenStream = TokenStream::from_iter(ts);
291318

292-
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
319+
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
293320
if !x.is_active() {
294321
// We encountered an error, so we return the original item.
295322
// This allows us to potentially parse other attributes.
@@ -1017,4 +1044,4 @@ mod llvm_enzyme {
10171044
}
10181045
}
10191046

1020-
pub(crate) use llvm_enzyme::expand;
1047+
pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};

compiler/rustc_builtin_macros/src/errors.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,6 @@ mod autodiff {
187187
pub(crate) act: String,
188188
}
189189

190-
#[derive(Diagnostic)]
191-
#[diag(builtin_macros_autodiff_mode)]
192-
pub(crate) struct AutoDiffInvalidMode {
193-
#[primary_span]
194-
pub(crate) span: Span,
195-
pub(crate) mode: String,
196-
}
197-
198190
#[derive(Diagnostic)]
199191
#[diag(builtin_macros_autodiff_width)]
200192
pub(crate) struct AutoDiffInvalidWidth {

compiler/rustc_builtin_macros/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
#![allow(internal_features)]
66
#![allow(rustc::diagnostic_outside_of_impl)]
77
#![allow(rustc::untranslatable_diagnostic)]
8+
#![cfg_attr(not(bootstrap), feature(autodiff))]
89
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
910
#![doc(rust_logo)]
1011
#![feature(assert_matches)]
11-
#![feature(autodiff)]
1212
#![feature(box_patterns)]
1313
#![feature(decl_macro)]
1414
#![feature(if_let_guard)]
@@ -112,7 +112,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
112112

113113
register_attr! {
114114
alloc_error_handler: alloc_error_handler::expand,
115-
autodiff: autodiff::expand,
115+
autodiff_forward: autodiff::expand_forward,
116+
autodiff_reverse: autodiff::expand_reverse,
116117
bench: test::expand_bench,
117118
cfg_accessible: cfg_accessible::Expander,
118119
cfg_eval: cfg_eval::expand,

compiler/rustc_passes/src/check_attr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
255255
self.check_generic_attr(hir_id, attr, target, Target::Fn);
256256
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
257257
}
258-
[sym::autodiff, ..] => {
258+
[sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
259259
self.check_autodiff(hir_id, attr, span, target)
260260
}
261261
[sym::coroutine, ..] => {

compiler/rustc_span/src/symbol.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ symbols! {
244244
FnMut,
245245
FnOnce,
246246
Formatter,
247+
Forward,
247248
From,
248249
FromIterator,
249250
FromResidual,
@@ -339,6 +340,7 @@ symbols! {
339340
Result,
340341
ResumeTy,
341342
Return,
343+
Reverse,
342344
Right,
343345
Rust,
344346
RustaceansAreAwesome,
@@ -521,7 +523,8 @@ symbols! {
521523
audit_that,
522524
augmented_assignments,
523525
auto_traits,
524-
autodiff,
526+
autodiff_forward,
527+
autodiff_reverse,
525528
automatically_derived,
526529
avx,
527530
avx10_target_feature,

library/core/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,11 @@ pub mod assert_matches {
225225

226226
// We don't export this through #[macro_export] for now, to avoid breakage.
227227
#[unstable(feature = "autodiff", issue = "124509")]
228+
#[cfg(not(bootstrap))]
228229
/// Unstable module containing the unstable `autodiff` macro.
229230
pub mod autodiff {
230231
#[unstable(feature = "autodiff", issue = "124509")]
231-
pub use crate::macros::builtin::autodiff;
232+
pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};
232233
}
233234

234235
#[unstable(feature = "contracts", issue = "128044")]

library/core/src/macros/mod.rs

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,20 +1519,41 @@ pub(crate) mod builtin {
15191519
($file:expr $(,)?) => {{ /* compiler built-in */ }};
15201520
}
15211521

1522-
/// Automatic Differentiation macro which allows generating a new function to compute
1523-
/// the derivative of a given function. It may only be applied to a function.
1524-
/// The expected usage syntax is
1525-
/// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
1526-
/// where:
1527-
/// NAME is a string that represents a valid function name.
1528-
/// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst.
1529-
/// INPUT_ACTIVITIES consists of one valid activity for each input parameter.
1530-
/// OUTPUT_ACTIVITY must not be set if we implicitly return nothing (or explicitly return
1531-
/// `-> ()`). Otherwise it must be set to one of the allowed activities.
1522+
/// This macro uses forward-mode automatic differentiation to generate a new function.
1523+
/// It may only be applied to a function. The new function will compute the derivative
1524+
/// of the function to which the macro was applied.
1525+
///
1526+
/// The expected usage syntax is:
1527+
/// `#[autodiff_forward(NAME, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
1528+
///
1529+
/// - `NAME`: A string that represents a valid function name.
1530+
/// - `INPUT_ACTIVITIES`: Specifies one valid activity for each input parameter.
1531+
/// - `OUTPUT_ACTIVITY`: Must not be set if the function implicitly returns nothing
1532+
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
1533+
#[unstable(feature = "autodiff", issue = "124509")]
1534+
#[allow_internal_unstable(rustc_attrs)]
1535+
#[rustc_builtin_macro]
1536+
#[cfg(not(bootstrap))]
1537+
pub macro autodiff_forward($item:item) {
1538+
/* compiler built-in */
1539+
}
1540+
1541+
/// This macro uses reverse-mode automatic differentiation to generate a new function.
1542+
/// It may only be applied to a function. The new function will compute the derivative
1543+
/// of the function to which the macro was applied.
1544+
///
1545+
/// The expected usage syntax is:
1546+
/// `#[autodiff_reverse(NAME, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
1547+
///
1548+
/// - `NAME`: A string that represents a valid function name.
1549+
/// - `INPUT_ACTIVITIES`: Specifies one valid activity for each input parameter.
1550+
/// - `OUTPUT_ACTIVITY`: Must not be set if the function implicitly returns nothing
1551+
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
15321552
#[unstable(feature = "autodiff", issue = "124509")]
15331553
#[allow_internal_unstable(rustc_attrs)]
15341554
#[rustc_builtin_macro]
1535-
pub macro autodiff($item:item) {
1555+
#[cfg(not(bootstrap))]
1556+
pub macro autodiff_reverse($item:item) {
15361557
/* compiler built-in */
15371558
}
15381559

library/std/src/lib.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,12 @@
276276
// tidy-alphabetical-start
277277

278278
// stabilization was reverted after it hit beta
279+
#![cfg_attr(not(bootstrap), feature(autodiff))]
279280
#![feature(alloc_error_handler)]
280281
#![feature(allocator_internals)]
281282
#![feature(allow_internal_unsafe)]
282283
#![feature(allow_internal_unstable)]
283284
#![feature(asm_experimental_arch)]
284-
#![feature(autodiff)]
285285
#![feature(cfg_sanitizer_cfi)]
286286
#![feature(cfg_target_thread_local)]
287287
#![feature(cfi_encoding)]
@@ -636,12 +636,15 @@ pub mod simd {
636636
#[doc(inline)]
637637
pub use crate::std_float::StdFloat;
638638
}
639+
639640
#[unstable(feature = "autodiff", issue = "124509")]
641+
#[cfg(not(bootstrap))]
640642
/// This module provides support for automatic differentiation.
641643
pub mod autodiff {
642644
/// This macro handles automatic differentiation.
643-
pub use core::autodiff::autodiff;
645+
pub use core::autodiff::{autodiff_forward, autodiff_reverse};
644646
}
647+
645648
#[stable(feature = "futures_api", since = "1.36.0")]
646649
pub mod task {
647650
//! Types and Traits for working with asynchronous tasks.

tests/codegen/autodiff/batched.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
#![feature(autodiff)]
1313

14-
use std::autodiff::autodiff;
14+
use std::autodiff::autodiff_forward;
1515

16-
#[autodiff(d_square3, Forward, Dual, DualOnly)]
17-
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
18-
#[autodiff(d_square1, Forward, 4, Dual, Dual)]
16+
#[autodiff_forward(d_square3, Dual, DualOnly)]
17+
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
18+
#[autodiff_forward(d_square1, 4, Dual, Dual)]
1919
#[no_mangle]
2020
fn square(x: &f32) -> f32 {
2121
x * x

tests/codegen/autodiff/generic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
//@ needs-enzyme
44
#![feature(autodiff)]
55

6-
use std::autodiff::autodiff;
6+
use std::autodiff::autodiff_reverse;
77

8-
#[autodiff(d_square, Reverse, Duplicated, Active)]
8+
#[autodiff_reverse(d_square, Duplicated, Active)]
99
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
1010
*x * *x
1111
}

tests/codegen/autodiff/identical_fnc.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
1212
#![feature(autodiff)]
1313

14-
use std::autodiff::autodiff;
14+
use std::autodiff::autodiff_reverse;
1515

16-
#[autodiff(d_square, Reverse, Duplicated, Active)]
16+
#[autodiff_reverse(d_square, Duplicated, Active)]
1717
fn square(x: &f64) -> f64 {
1818
x * x
1919
}
2020

21-
#[autodiff(d_square2, Reverse, Duplicated, Active)]
21+
#[autodiff_reverse(d_square2, Duplicated, Active)]
2222
fn square2(x: &f64) -> f64 {
2323
x * x
2424
}

tests/codegen/autodiff/inline.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
#![feature(autodiff)]
66

7-
use std::autodiff::autodiff;
7+
use std::autodiff::autodiff_reverse;
88

9-
#[autodiff(d_square, Reverse, Duplicated, Active)]
9+
#[autodiff_reverse(d_square, Duplicated, Active)]
1010
fn square(x: &f64) -> f64 {
1111
x * x
1212
}

tests/codegen/autodiff/scalar.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
//@ needs-enzyme
44
#![feature(autodiff)]
55

6-
use std::autodiff::autodiff;
6+
use std::autodiff::autodiff_reverse;
77

8-
#[autodiff(d_square, Reverse, Duplicated, Active)]
8+
#[autodiff_reverse(d_square, Duplicated, Active)]
99
#[no_mangle]
1010
fn square(x: &f64) -> f64 {
1111
x * x

tests/codegen/autodiff/sret.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
#![feature(autodiff)]
1111

12-
use std::autodiff::autodiff;
12+
use std::autodiff::autodiff_reverse;
1313

1414
#[no_mangle]
15-
#[autodiff(df, Reverse, Active, Active, Active)]
15+
#[autodiff_reverse(df, Active, Active, Active)]
1616
fn primal(x: f32, y: f32) -> f64 {
1717
(x * x * y) as f64
1818
}

0 commit comments

Comments
 (0)