Skip to content

Commit e97c0d1

Browse files
authored
Void augmented (rust-lang#351)
1 parent f426b09 commit e97c0d1

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -980,10 +980,6 @@ class Enzyme : public ModulePass {
980980
tape = Builder.CreateLoad(
981981
Builder.CreatePointerCast(AL, PointerType::getUnqual(tapeType)));
982982
}
983-
llvm::errs() << *CI->getParent() << "\n";
984-
llvm::errs() << *CI->getParent() << "\n";
985-
llvm::errs() << *tape << "\n";
986-
llvm::errs() << *tapeType << "\n";
987983
assert(tape->getType() == tapeType);
988984
args.push_back(tape);
989985
}
@@ -1056,7 +1052,8 @@ class Enzyme : public ModulePass {
10561052
}
10571053
}
10581054

1059-
if (!diffret->getType()->isEmptyTy() && !diffret->getType()->isVoidTy()) {
1055+
if (!diffret->getType()->isEmptyTy() && !diffret->getType()->isVoidTy() &&
1056+
!CI->getType()->isEmptyTy() && !CI->getType()->isVoidTy()) {
10601057
if (diffret->getType() == CI->getType()) {
10611058
CI->replaceAllUsesWith(diffret);
10621059
} else if (mode == DerivativeMode::ReverseModePrimal) {
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline nounwind readnone uwtable
4+
define double @tester(double* %x) {
5+
entry:
6+
%gep = getelementptr double, double* %x, i32 1
7+
%y = load double, double* %x
8+
%z = load double, double* %gep
9+
%res = fmul fast double %y, %z
10+
ret double %res
11+
}
12+
13+
define void @test_derivative(double* %x, double* %dx) {
14+
entry:
15+
%size = call i64 (double (double*)*, ...) @__enzyme_augmentsize(double (double*)* nonnull @tester, metadata !"enzyme_dup")
16+
%cache = alloca i8, i64 %size, align 1
17+
call void (double (double*)*, ...) @__enzyme_augmentfwd(double (double*)* nonnull @tester, metadata !"enzyme_allocated", i64 %size, metadata !"enzyme_tape", i8* %cache, double* %x, double* %dx)
18+
tail call void (double (double*)*, ...) @__enzyme_reverse(double (double*)* nonnull @tester, metadata !"enzyme_allocated", i64 %size, metadata !"enzyme_tape", i8* %cache, double* %x, double* %dx)
19+
ret void
20+
}
21+
22+
; Function Attrs: nounwind
23+
declare void @__enzyme_augmentfwd(double (double*)*, ...)
24+
declare i64 @__enzyme_augmentsize(double (double*)*, ...)
25+
declare void @__enzyme_reverse(double (double*)*, ...)
26+
27+
; CHECK: define void @test_derivative(double* %x, double* %dx)
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %cache = alloca i8, i64 16
30+
; CHECK-NEXT: %0 = call { { double, double }, double } @augmented_tester(double* %x, double* %dx)
31+
; CHECK-NEXT: %1 = extractvalue { { double, double }, double } %0, 0
32+
; CHECK-NEXT: %2 = bitcast i8* %cache to { double, double }*
33+
; CHECK-NEXT: store { double, double } %1, { double, double }* %2
34+
; CHECK-NEXT: %3 = bitcast i8* %cache to { double, double }*
35+
; CHECK-NEXT: %4 = load { double, double }, { double, double }* %3
36+
; CHECK-NEXT: call void @diffetester(double* %x, double* %dx, double 1.000000e+00, { double, double } %4)
37+
; CHECK-NEXT: ret void
38+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)