@@ -33,15 +33,25 @@ b1 = rand(n);
33
33
db1 = zeros (n);
34
34
db12 = zeros (n);
35
35
36
- #=
37
- # Batch test fails
38
- # Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075
36
+ # Batch test
37
+ n = 4
38
+ A = rand (n, n);
39
+ dA = zeros (n, n);
40
+ dA2 = zeros (n, n);
41
+ b1 = rand (n);
42
+ db1 = zeros (n);
43
+ db12 = zeros (n);
39
44
40
- function fbatch(y, A, b1; alg = LUFactorization())
45
+ function f ( A, b1; alg = LUFactorization ())
41
46
prob = LinearProblem (A, b1)
42
-
43
47
sol1 = solve (prob, alg)
48
+ s1 = sol1. u
49
+ norm (s1)
50
+ end
44
51
52
+ function fbatch (y, A, b1; alg = LUFactorization ())
53
+ prob = LinearProblem (A, b1)
54
+ sol1 = solve (prob, alg)
45
55
s1 = sol1. u
46
56
y[1 ] = norm (s1)
47
57
nothing
50
60
y = [0.0 ]
51
61
dy1 = [1.0 ]
52
62
dy2 = [1.0 ]
63
+ Enzyme. autodiff (Reverse, fbatch, Duplicated (y, dy1), Duplicated (copy (A), dA), Duplicated (copy (b1), db1))
64
+
65
+ @test y[1 ] ≈ f (copy (A),b1)
66
+ dA_2 = ForwardDiff. gradient (x-> f (x,eltype (x).(b1)), copy (A))
67
+ db1_2 = ForwardDiff. gradient (x-> f (eltype (x).(A),x), copy (b1))
68
+
69
+ @test dA ≈ dA_2
70
+ @test db1 ≈ db1_2
71
+
72
+ y .= 0
73
+ dy1 .= 1
74
+ dy2 .= 1
75
+ dA .= 0
76
+ dA2 .= 0
77
+ db1 .= 0
78
+ db12 .= 0
53
79
Enzyme. autodiff (Reverse, fbatch, BatchDuplicated (y, (dy1, dy2)), BatchDuplicated (copy (A), (dA, dA2)), BatchDuplicated (copy (b1), (db1, db12)))
54
80
55
- dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
56
- db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
57
-
58
- @test_broken dA ≈ dA_2
59
- @test_broken dA2 ≈ dA_2
60
- @test_broken db1 ≈ db1_2
61
- @test_broken db12 ≈ db1_2
62
- =#
81
+ @test dA ≈ dA_2
82
+ @test db1 ≈ db1_2
83
+ @test dA2 ≈ dA_2
84
+ @test db12 ≈ db1_2
63
85
64
86
function f (A, b1, b2; alg = LUFactorization ())
65
87
prob = LinearProblem (A, b1)
0 commit comments