1
+ ; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2
+
3
+
4
+ ; #include <stdio.h>
5
+
6
+ ; double __enzyme_autodiff(void*, double);
7
+
8
+ ; __attribute__((noinline))
9
+ ; void square_(const double* src, double* dest) {
10
+ ; *dest = *src * *src;
11
+ ; }
12
+
13
+ ; void* augment_square_(const double* src, const double *d_src, double* dest, double* d_dest) {
14
+ ; *dest = *src * *src;
15
+ ; return NULL;
16
+ ; }
17
+
18
+ ; void gradient_square_(const double* src, double *d_src, const double* dest, const double* d_dest, void* tape) {
19
+ ; *d_src = *d_dest * *src * 2;
20
+ ; }
21
+
22
+ ; void* __enzyme_register_gradient_square[] = {
23
+ ; (void*)square_,
24
+ ; (void*)augment_square_,
25
+ ; (void*)gradient_square_,
26
+ ; };
27
+
28
+
29
+ ; double square(double x) {
30
+ ; double y;
31
+ ; square_(&x, &y);
32
+ ; return y;
33
+ ; }
34
+
35
+ ; double dsquare(double x) {
36
+ ; return __enzyme_autodiff((void*)square, x);
37
+ ; }
38
+
39
+
40
+ ; int main() {
41
+ ; double res = dsquare(3.0);
42
+ ; printf("res=%f\n", res);
43
+ ; }
44
+
45
; Registration table: a three-element array of i8* holding, in order, the
; addresses of @square_, @augment_square_ and @gradient_square_ (each bitcast
; to i8*). The expected output at the bottom of this file shows the call to
; @square_ being replaced by a wrapper that calls the second and third entries.
+ @__enzyme_register_gradient_square = dso_local local_unnamed_addr global [3 x i8* ] [i8* bitcast (void (double *, double *)* @square_ to i8* ), i8* bitcast (i8* (double *, double *, double *, double *)* @augment_square_ to i8* ), i8* bitcast (void (double *, double *, double *, double *, i8* )* @gradient_square_ to i8* )], align 16
46
; Format string passed to printf in @main.
; NOTE(review): the space after `c` and before the closing quote looks like an
; extraction artifact (the array is declared as 8 bytes) - confirm against upstream.
+ @.str = private unnamed_addr constant [8 x i8 ] c "res=%f\0A\00 " , align 1
47
+
48
; square_(src, dest): loads the double at %src, multiplies it by itself and
; stores the product to %dest. Returns nothing.
; This is the first entry of @__enzyme_register_gradient_square; its attribute
; group (#0) includes `noinline`.
+ define dso_local void @square_ (double * nocapture readonly %src , double * nocapture %dest ) #0 {
49
+ entry:
50
+ %0 = load double , double * %src , align 8
51
+ %mul = fmul double %0 , %0
52
+ store double %mul , double * %dest , align 8
53
+ ret void
54
+ }
55
+
56
; augment_square_(src, d_src, dest, d_dest): performs the same computation as
; @square_ (stores (*src) * (*src) to %dest) and returns a null i8*.
; %d_src and %d_dest are marked readnone and are not accessed here.
; In the expected output below, the returned pointer is forwarded as the last
; argument of @gradient_square_.
; NOTE(review): the name suggests this is the user-supplied augmented forward
; pass and the return value is its tape - confirm against Enzyme's docs.
+ define dso_local noalias i8* @augment_square_ (double * nocapture readonly %src , double * nocapture readnone %d_src , double * nocapture %dest , double * nocapture readnone %d_dest ) #1 {
57
+ entry:
58
+ %0 = load double , double * %src , align 8
59
+ %mul = fmul double %0 , %0
60
+ store double %mul , double * %dest , align 8
61
+ ret i8* null
62
+ }
63
+
64
; gradient_square_(src, d_src, dest, d_dest, tape): loads *d_dest and *src,
; computes (*d_dest) * (*src) * 2.0 and stores the result to %d_src.
; The store overwrites *d_src; the previous value is not read.
; %dest and %tape are marked readnone and are not accessed here.
; NOTE(review): the constant is written `2 .000000e+00` with an embedded space,
; which looks like an extraction artifact - confirm against upstream.
+ define dso_local void @gradient_square_ (double * nocapture readonly %src , double * nocapture %d_src , double * nocapture readnone %dest , double * nocapture readonly %d_dest , i8* nocapture readnone %tape ) #1 {
65
+ entry:
66
+ %0 = load double , double * %d_dest , align 8
67
+ %1 = load double , double * %src , align 8
68
+ %mul = fmul double %0 , %1
69
+ %mul1 = fmul double %mul , 2 .000000e+00
70
+ store double %mul1 , double * %d_src , align 8
71
+ ret void
72
+ }
73
+
74
; square(x): spills %x to the stack slot %x.addr, calls @square_ with pointers
; to %x.addr (input) and %y (output), then loads and returns %y.
; The lifetime.start/lifetime.end intrinsics bracket the use of %y.
; This is the function handed to @__enzyme_autodiff; the value names %x.addr
; and %y are matched literally by the expected output at the end of the file,
; so they must not be renamed.
+ define dso_local double @square (double %x ) #2 {
75
+ entry:
76
+ %x.addr = alloca double , align 8
77
+ %y = alloca double , align 8
78
+ store double %x , double * %x.addr , align 8
79
+ %0 = bitcast double * %y to i8*
80
+ call void @llvm.lifetime.start.p0i8 (i64 8 , i8* nonnull %0 ) #6
81
+ call void @square_ (double * nonnull %x.addr , double * nonnull %y )
82
+ %1 = load double , double * %y , align 8
83
+ call void @llvm.lifetime.end.p0i8 (i64 8 , i8* nonnull %0 ) #6
84
+ ret double %1
85
+ }
86
+
87
+ declare void @llvm.lifetime.start.p0i8 (i64 , i8* nocapture ) #3
88
+
89
+ declare void @llvm.lifetime.end.p0i8 (i64 , i8* nocapture ) #3
90
+
91
; dsquare(x): calls @__enzyme_autodiff with @square (bitcast to i8*) and %x,
; and returns the call's double result unchanged.
+ define dso_local double @dsquare (double %x ) local_unnamed_addr #2 {
92
+ entry:
93
+ %call = tail call double @__enzyme_autodiff (i8* bitcast (double (double )* @square to i8* ), double %x ) #6
94
+ ret double %call
95
+ }
96
+
97
+ declare dso_local double @__enzyme_autodiff (i8* , double ) local_unnamed_addr #4
98
+
99
; main(): calls @__enzyme_autodiff on @square with the constant 3.0, prints the
; result through printf using @.str, and returns 0.
; NOTE(review): the direct @__enzyme_autodiff call (named %call.i) presumably
; comes from @dsquare having been inlined by the frontend - confirm.
; NOTE(review): `3 .000000e+00` contains an embedded space that looks like an
; extraction artifact - confirm against upstream.
+ define dso_local i32 @main () local_unnamed_addr #2 {
100
+ entry:
101
+ %call.i = tail call double @__enzyme_autodiff (i8* bitcast (double (double )* @square to i8* ), double 3 .000000e+00 ) #6
102
+ %call1 = tail call i32 (i8* , ...) @printf (i8* getelementptr inbounds ([8 x i8 ], [8 x i8 ]* @.str , i64 0 , i64 0 ), double %call.i )
103
+ ret i32 0
104
+ }
105
+
106
+ declare dso_local i32 @printf (i8* nocapture readonly , ...) local_unnamed_addr #5
107
+
108
+ attributes #0 = { noinline norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math" ="false" "disable-tail-calls" ="false" "less-precise-fpmad" ="false" "min-legal-vector-width" ="0" "no-frame-pointer-elim" ="false" "no-infs-fp-math" ="false" "no-jump-tables" ="false" "no-nans-fp-math" ="false" "no-signed-zeros-fp-math" ="false" "no-trapping-math" ="false" "stack-protector-buffer-size" ="8" "target-cpu" ="x86-64" "target-features" ="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math" ="false" "use-soft-float" ="false" }
109
+ attributes #1 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math" ="false" "disable-tail-calls" ="false" "less-precise-fpmad" ="false" "min-legal-vector-width" ="0" "no-frame-pointer-elim" ="false" "no-infs-fp-math" ="false" "no-jump-tables" ="false" "no-nans-fp-math" ="false" "no-signed-zeros-fp-math" ="false" "no-trapping-math" ="false" "stack-protector-buffer-size" ="8" "target-cpu" ="x86-64" "target-features" ="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math" ="false" "use-soft-float" ="false" }
110
+ attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math" ="false" "disable-tail-calls" ="false" "less-precise-fpmad" ="false" "min-legal-vector-width" ="0" "no-frame-pointer-elim" ="false" "no-infs-fp-math" ="false" "no-jump-tables" ="false" "no-nans-fp-math" ="false" "no-signed-zeros-fp-math" ="false" "no-trapping-math" ="false" "stack-protector-buffer-size" ="8" "target-cpu" ="x86-64" "target-features" ="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math" ="false" "use-soft-float" ="false" }
111
+ attributes #3 = { argmemonly nounwind }
112
+ attributes #4 = { "correctly-rounded-divide-sqrt-fp-math" ="false" "disable-tail-calls" ="false" "less-precise-fpmad" ="false" "no-frame-pointer-elim" ="false" "no-infs-fp-math" ="false" "no-nans-fp-math" ="false" "no-signed-zeros-fp-math" ="false" "no-trapping-math" ="false" "stack-protector-buffer-size" ="8" "target-cpu" ="x86-64" "target-features" ="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math" ="false" "use-soft-float" ="false" }
113
+ attributes #5 = { nounwind "correctly-rounded-divide-sqrt-fp-math" ="false" "disable-tail-calls" ="false" "less-precise-fpmad" ="false" "no-frame-pointer-elim" ="false" "no-infs-fp-math" ="false" "no-nans-fp-math" ="false" "no-signed-zeros-fp-math" ="false" "no-trapping-math" ="false" "stack-protector-buffer-size" ="8" "target-cpu" ="x86-64" "target-features" ="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math" ="false" "use-soft-float" ="false" }
114
+ attributes #6 = { nounwind }
115
+
116
+
117
; Expected output, part 1: the generated derivative of @square.
; It takes the original argument plus %differeturn and returns { double }.
; The expected body zero-initialises shadow slots for %x.addr and %y, adds
; %differeturn into the shadow of %y, calls @fixgradient_square_ with the
; primal/shadow pointer pairs, then reads the shadow of %x.addr, resets it to
; zero and returns the value wrapped in the struct.
+ ; CHECK: define internal { double } @diffesquare(double %x, double %differeturn)
118
+ ; CHECK-NEXT: entry:
119
+ ; CHECK-NEXT: %"x.addr'ipa" = alloca double, align 8
120
+ ; CHECK-NEXT: store double 0.000000e+00, double* %"x.addr'ipa", align 8
121
+ ; CHECK-NEXT: %x.addr = alloca double, align 8
122
+ ; CHECK-NEXT: %"y'ipa" = alloca double, align 8
123
+ ; CHECK-NEXT: store double 0.000000e+00, double* %"y'ipa", align 8
124
+ ; CHECK-NEXT: %y = alloca double, align 8
125
+ ; CHECK-NEXT: store double %x, double* %x.addr, align 8
126
+ ; CHECK-NEXT: %0 = load double, double* %"y'ipa", align 8
127
+ ; CHECK-NEXT: %1 = fadd fast double %0, %differeturn
128
+ ; CHECK-NEXT: store double %1, double* %"y'ipa", align 8
129
+ ; CHECK-NEXT: call void @fixgradient_square_(double* %x.addr, double* %"x.addr'ipa", double* %y, double* %"y'ipa")
130
+ ; CHECK-NEXT: %2 = load double, double* %"x.addr'ipa", align 8
131
+ ; CHECK-NEXT: store double 0.000000e+00, double* %"x.addr'ipa", align 8
132
+ ; CHECK-NEXT: %3 = insertvalue { double } undef, double %2, 0
133
+ ; CHECK-NEXT: ret { double } %3
134
+ ; CHECK-NEXT: }
135
+
136
; Expected output, part 2: the generated wrapper around the registered
; functions. It forwards its four pointer arguments to @augment_square_, then
; calls @gradient_square_ with the same four pointers plus the i8* returned by
; @augment_square_. The optional-name regexes in the signature accept the
; parameters being printed either with or without explicit %0..%3 names.
+ ; CHECK: define internal void @fixgradient_square_(double*{{( %0)?}}, double*{{( %1)?}}, double*{{( %2)?}}, double*{{( %3)?}})
137
+ ; CHECK-NEXT: entry:
138
+ ; CHECK-NEXT: %4 = call i8* @augment_square_(double* %0, double* %1, double* %2, double* %3)
139
+ ; CHECK-NEXT: call void @gradient_square_(double* %0, double* %1, double* %2, double* %3, i8* %4)
140
+ ; CHECK-NEXT: ret void
141
+ ; CHECK-NEXT: }
0 commit comments