Skip to content

Commit b934c11

Browse files
authored
Clean up fixgradient return type (rust-lang#341)
* clean up fixgradient return type
1 parent e850ca4 commit b934c11

File tree

2 files changed

+160
-7
lines changed

2 files changed

+160
-7
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2802,9 +2802,13 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
28022802

28032803
if (mode == DerivativeMode::ReverseModeCombined) {
28042804

2805-
FunctionType *FTy =
2806-
FunctionType::get(StructType::get(todiff->getContext(), {res.second}),
2807-
res.first, todiff->getFunctionType()->isVarArg());
2805+
Type *FRetTy = res.second.empty()
2806+
? Type::getVoidTy(todiff->getContext())
2807+
: StructType::get(todiff->getContext(), {res.second});
2808+
2809+
FunctionType *FTy = FunctionType::get(
2810+
FRetTy, res.first, todiff->getFunctionType()->isVarArg());
2811+
28082812
Function *NewF = Function::Create(
28092813
FTy, Function::LinkageTypes::InternalLinkage,
28102814
"fixgradient_" + todiff->getName(), todiff->getParent());
@@ -2862,10 +2866,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
28622866
}
28632867
auto revcal = bb.CreateCall(revfn, revargs);
28642868
revcal->setCallingConv(revfn->getCallingConv());
2865-
if (NewF->getReturnType()->isEmptyTy())
2869+
2870+
if (NewF->getReturnType()->isEmptyTy()) {
28662871
bb.CreateRet(UndefValue::get(NewF->getReturnType()));
2867-
else
2872+
} else if (NewF->getReturnType()->isVoidTy()) {
2873+
bb.CreateRetVoid();
2874+
} else {
28682875
bb.CreateRet(revcal);
2876+
}
28692877
assert(!returnUsed);
28702878

28712879
return insert_or_assign2<ReverseCacheKey, Function *>(
@@ -3008,9 +3016,13 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
30083016
st == nullptr && !foundcalled->getReturnType()->isVoidTy();
30093017
if (wrongRet) {
30103018
// if (wrongRet || !hasTape) {
3019+
Type *FRetTy =
3020+
res.second.empty()
3021+
? Type::getVoidTy(todiff->getContext())
3022+
: StructType::get(todiff->getContext(), {res.second});
3023+
30113024
FunctionType *FTy = FunctionType::get(
3012-
StructType::get(todiff->getContext(), {res.second}), res.first,
3013-
todiff->getFunctionType()->isVarArg());
3025+
FRetTy, res.first, todiff->getFunctionType()->isVarArg());
30143026
Function *NewF = Function::Create(
30153027
FTy, Function::LinkageTypes::InternalLinkage,
30163028
"fixgradient_" + todiff->getName(), todiff->getParent());
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
4+
; #include <stdio.h>
5+
6+
; double __enzyme_autodiff(void*, double);
7+
8+
; __attribute__((noinline))
9+
; void square_(const double* src, double* dest) {
10+
; *dest = *src * *src;
11+
; }
12+
13+
; void* augment_square_(const double* src, const double *d_src, double* dest, double* d_dest) {
14+
; *dest = *src * *src;
15+
; return NULL;
16+
; }
17+
18+
; void gradient_square_(const double* src, double *d_src, const double* dest, const double* d_dest, void* tape) {
19+
; *d_src = *d_dest * *src * 2;
20+
; }
21+
22+
; void* __enzyme_register_gradient_square[] = {
23+
; (void*)square_,
24+
; (void*)augment_square_,
25+
; (void*)gradient_square_,
26+
; };
27+
28+
29+
; double square(double x) {
30+
; double y;
31+
; square_(&x, &y);
32+
; return y;
33+
; }
34+
35+
; double dsquare(double x) {
36+
; return __enzyme_autodiff((void*)square, x);
37+
; }
38+
39+
40+
; int main() {
41+
; double res = dsquare(3.0);
42+
; printf("res=%f\n", res);
43+
; }
44+
45+
@__enzyme_register_gradient_square = dso_local local_unnamed_addr global [3 x i8*] [i8* bitcast (void (double*, double*)* @square_ to i8*), i8* bitcast (i8* (double*, double*, double*, double*)* @augment_square_ to i8*), i8* bitcast (void (double*, double*, double*, double*, i8*)* @gradient_square_ to i8*)], align 16
46+
@.str = private unnamed_addr constant [8 x i8] c"res=%f\0A\00", align 1
47+
48+
; Primal function under test: loads the double at %src, multiplies it by
; itself, stores the product to %dest, and returns void.
; It is the first entry of @__enzyme_register_gradient_square above.
define dso_local void @square_(double* nocapture readonly %src, double* nocapture %dest) #0 {
49+
entry:
50+
%0 = load double, double* %src, align 8
51+
%mul = fmul double %0, %0
52+
store double %mul, double* %dest, align 8
53+
ret void
54+
}
55+
56+
; Second entry of @__enzyme_register_gradient_square: performs the same
; load / multiply / store as @square_ on %src and %dest, and returns a null i8*.
; %d_src and %d_dest are marked readnone and are not touched by the body.
; In the expected @fixgradient_square_ output at the end of this file, the
; returned pointer is forwarded as the last argument of @gradient_square_.
define dso_local noalias i8* @augment_square_(double* nocapture readonly %src, double* nocapture readnone %d_src, double* nocapture %dest, double* nocapture readnone %d_dest) #1 {
57+
entry:
58+
%0 = load double, double* %src, align 8
59+
%mul = fmul double %0, %0
60+
store double %mul, double* %dest, align 8
61+
ret i8* null
62+
}
63+
64+
; Third entry of @__enzyme_register_gradient_square: computes
; (*%d_dest) * (*%src) * 2.0 and stores it to %d_src, then returns void.
; The store overwrites %d_src; the previous contents are not read.
; %dest and %tape are marked readnone and are not used by the body.
define dso_local void @gradient_square_(double* nocapture readonly %src, double* nocapture %d_src, double* nocapture readnone %dest, double* nocapture readonly %d_dest, i8* nocapture readnone %tape) #1 {
65+
entry:
66+
%0 = load double, double* %d_dest, align 8
67+
%1 = load double, double* %src, align 8
68+
%mul = fmul double %0, %1
69+
%mul1 = fmul double %mul, 2.000000e+00
70+
store double %mul1, double* %d_src, align 8
71+
ret void
72+
}
73+
74+
; Function handed to @__enzyme_autodiff: spills %x to the stack slot %x.addr,
; calls @square_ with %x.addr as input and the stack slot %y as output, then
; loads %y and returns it. The lifetime markers bracket the use of %y.
; The value names %x.addr and %y reappear in the expected @diffesquare output
; at the end of this file, so they should not be renamed.
define dso_local double @square(double %x) #2 {
75+
entry:
76+
%x.addr = alloca double, align 8
77+
%y = alloca double, align 8
78+
store double %x, double* %x.addr, align 8
79+
%0 = bitcast double* %y to i8*
80+
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0) #6
81+
call void @square_(double* nonnull %x.addr, double* nonnull %y)
82+
%1 = load double, double* %y, align 8
83+
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0) #6
84+
ret double %1
85+
}
86+
87+
declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #3
88+
89+
declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #3
90+
91+
; Calls @__enzyme_autodiff with @square (bitcast to i8*) and %x, and returns
; the double that the call produces.
define dso_local double @dsquare(double %x) local_unnamed_addr #2 {
92+
entry:
93+
%call = tail call double @__enzyme_autodiff(i8* bitcast (double (double)* @square to i8*), double %x) #6
94+
ret double %call
95+
}
96+
97+
declare dso_local double @__enzyme_autodiff(i8*, double) local_unnamed_addr #4
98+
99+
; Driver: calls @__enzyme_autodiff on @square with the constant 3.0, prints the
; result through @printf using the @.str format string, and returns 0.
; The C source in the comments above calls dsquare(3.0); here the
; @__enzyme_autodiff call appears directly (presumably dsquare was inlined
; when this IR was generated -- not verifiable from this file alone).
define dso_local i32 @main() local_unnamed_addr #2 {
100+
entry:
101+
%call.i = tail call double @__enzyme_autodiff(i8* bitcast (double (double)* @square to i8*), double 3.000000e+00) #6
102+
%call1 = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([8 x i8], [8 x i8]* @.str, i64 0, i64 0), double %call.i)
103+
ret i32 0
104+
}
105+
106+
declare dso_local i32 @printf(i8* nocapture readonly, ...) local_unnamed_addr #5
107+
108+
attributes #0 = { noinline norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
109+
attributes #1 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
110+
attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
111+
attributes #3 = { argmemonly nounwind }
112+
attributes #4 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
113+
attributes #5 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
114+
attributes #6 = { nounwind }
115+
116+
117+
; CHECK: define internal { double } @diffesquare(double %x, double %differeturn)
118+
; CHECK-NEXT: entry:
119+
; CHECK-NEXT: %"x.addr'ipa" = alloca double, align 8
120+
; CHECK-NEXT: store double 0.000000e+00, double* %"x.addr'ipa", align 8
121+
; CHECK-NEXT: %x.addr = alloca double, align 8
122+
; CHECK-NEXT: %"y'ipa" = alloca double, align 8
123+
; CHECK-NEXT: store double 0.000000e+00, double* %"y'ipa", align 8
124+
; CHECK-NEXT: %y = alloca double, align 8
125+
; CHECK-NEXT: store double %x, double* %x.addr, align 8
126+
; CHECK-NEXT: %0 = load double, double* %"y'ipa", align 8
127+
; CHECK-NEXT: %1 = fadd fast double %0, %differeturn
128+
; CHECK-NEXT: store double %1, double* %"y'ipa", align 8
129+
; CHECK-NEXT: call void @fixgradient_square_(double* %x.addr, double* %"x.addr'ipa", double* %y, double* %"y'ipa")
130+
; CHECK-NEXT: %2 = load double, double* %"x.addr'ipa", align 8
131+
; CHECK-NEXT: store double 0.000000e+00, double* %"x.addr'ipa", align 8
132+
; CHECK-NEXT: %3 = insertvalue { double } undef, double %2, 0
133+
; CHECK-NEXT: ret { double } %3
134+
; CHECK-NEXT: }
135+
136+
; CHECK: define internal void @fixgradient_square_(double*{{( %0)?}}, double*{{( %1)?}}, double*{{( %2)?}}, double*{{( %3)?}})
137+
; CHECK-NEXT: entry:
138+
; CHECK-NEXT: %4 = call i8* @augment_square_(double* %0, double* %1, double* %2, double* %3)
139+
; CHECK-NEXT: call void @gradient_square_(double* %0, double* %1, double* %2, double* %3, i8* %4)
140+
; CHECK-NEXT: ret void
141+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)