Skip to content

Commit 1e6150e

Browse files
Merge pull request #377 from wsmoses/master
Add Enzyme extension
2 parents 37e5328 + 89e10df commit 1e6150e

File tree

6 files changed

+303
-3
lines changed

6 files changed

+303
-3
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1010
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
11+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1112
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
1213
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1314
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -30,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3031

3132
[weakdeps]
3233
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
34+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3335
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3436
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
3537
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
@@ -42,6 +44,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
4244
[extensions]
4345
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
4446
LinearSolveCUDAExt = "CUDA"
47+
LinearSolveEnzymeExt = "Enzyme"
4548
LinearSolveHYPREExt = "HYPRE"
4649
LinearSolveIterativeSolversExt = "IterativeSolvers"
4750
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
@@ -78,6 +81,8 @@ julia = "1.6"
7881

7982
[extras]
8083
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
84+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
85+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
8186
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8287
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
8388
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -95,4 +100,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
95100
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
96101

97102
[targets]
98-
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
103+
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]

ext/LinearSolveEnzymeExt.jl

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
module LinearSolveEnzymeExt
2+
3+
using LinearSolve
4+
using LinearSolve.LinearAlgebra
5+
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
6+
7+
8+
using Enzyme
9+
10+
using EnzymeCore
11+
12+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
13+
res = func.val(prob.val, alg.val; kwargs...)
14+
dres = if EnzymeRules.width(config) == 1
15+
func.val(prob.dval, alg.val; kwargs...)
16+
else
17+
ntuple(Val(EnzymeRules.width(config))) do i
18+
Base.@_inline_meta
19+
func.val(prob.dval[i], alg.val; kwargs...)
20+
end
21+
end
22+
d_A = if EnzymeRules.width(config) == 1
23+
dres.A
24+
else
25+
(dval.A for dval in dres)
26+
end
27+
d_b = if EnzymeRules.width(config) == 1
28+
dres.b
29+
else
30+
(dval.b for dval in dres)
31+
end
32+
33+
34+
prob_d_A = if EnzymeRules.width(config) == 1
35+
prob.dval.A
36+
else
37+
(dval.A for dval in prob.dval)
38+
end
39+
prob_d_b = if EnzymeRules.width(config) == 1
40+
prob.dval.b
41+
else
42+
(dval.b for dval in prob.dval)
43+
end
44+
45+
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
46+
end
47+
48+
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
49+
d_A, d_b, prob_d_A, prob_d_b = cache
50+
51+
if EnzymeRules.width(config) == 1
52+
if d_A !== prob_d_A
53+
prob_d_A .+= d_A
54+
d_A .= 0
55+
end
56+
if d_b !== prob_d_b
57+
prob_d_b .+= d_b
58+
d_b .= 0
59+
end
60+
else
61+
for i in 1:EnzymeRules.width(config)
62+
if d_A !== prob_d_A[i]
63+
prob_d_A[i] .+= d_A[i]
64+
d_A[i] .= 0
65+
end
66+
if d_b !== prob_d_b[i]
67+
prob_d_b[i] .+= d_b[i]
68+
d_b[i] .= 0
69+
end
70+
end
71+
end
72+
73+
return (nothing, nothing)
74+
end
75+
76+
# y=inv(A) B
77+
# dA −= z y^T
78+
# dB += z, where z = inv(A^T) dy
79+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
80+
res = func.val(linsolve.val; kwargs...)
81+
82+
dres = if EnzymeRules.width(config) == 1
83+
deepcopy(res)
84+
else
85+
ntuple(Val(EnzymeRules.width(config))) do i
86+
Base.@_inline_meta
87+
deepcopy(res)
88+
end
89+
end
90+
91+
if EnzymeRules.width(config) == 1
92+
dres.u .= 0
93+
else
94+
for dr in dres
95+
dr.u .= 0
96+
end
97+
end
98+
99+
resvals = if EnzymeRules.width(config) == 1
100+
dres.u
101+
else
102+
ntuple(Val(EnzymeRules.width(config))) do i
103+
Base.@_inline_meta
104+
dres[i].u
105+
end
106+
end
107+
108+
dAs = if EnzymeRules.width(config) == 1
109+
(linsolve.dval.A,)
110+
else
111+
ntuple(Val(EnzymeRules.width(config))) do i
112+
Base.@_inline_meta
113+
linsolve.dval[i].A
114+
end
115+
end
116+
117+
dbs = if EnzymeRules.width(config) == 1
118+
(linsolve.dval.b,)
119+
else
120+
ntuple(Val(EnzymeRules.width(config))) do i
121+
Base.@_inline_meta
122+
linsolve.dval[i].b
123+
end
124+
end
125+
126+
cachesolve = deepcopy(linsolve.val)
127+
128+
cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
129+
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
130+
end
131+
132+
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
133+
y, dys, _linsolve, dAs, dbs = cache
134+
135+
@assert !(typeof(linsolve) <: Const)
136+
@assert !(typeof(linsolve) <: Active)
137+
138+
if EnzymeRules.width(config) == 1
139+
dys = (dys,)
140+
end
141+
142+
for (dA, db, dy) in zip(dAs, dbs, dys)
143+
z = if _linsolve.cacheval isa Factorization
144+
_linsolve.cacheval' \ dy
145+
elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
146+
_linsolve.cacheval[1]' \ dy
147+
elseif _linsolve.alg isa AbstractKrylovSubspaceMethod
148+
# Doesn't modify `A`, so it's safe to just reuse it
149+
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
150+
solve(invprob;
151+
abstol = _linsolve.val.abstol,
152+
reltol = _linsolve.val.reltol,
153+
verbose = _linsolve.val.verbose)
154+
else
155+
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
156+
end
157+
158+
dA .-= z * transpose(y)
159+
db .+= z
160+
dy .= eltype(dy)(0)
161+
end
162+
163+
return (nothing,)
164+
end
165+
166+
end

src/init.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@ function __init__()
1515
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
1616
include("../ext/LinearSolveMKLExt.jl")
1717
end
18+
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
19+
include("../ext/LinearSolveEnzymeExt.jl")
20+
end
1821
end
1922
end

test/basictests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ end
202202
end
203203
end
204204

205-
test_algs = if VERISON >= v"1.9"
205+
test_algs = if VERSION >= v"1.9"
206206
(LUFactorization(),
207207
QRFactorization(),
208208
SVDFactorization(),

test/enzyme.jl

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
using Enzyme, ForwardDiff
2+
using LinearSolve, LinearAlgebra, Test
3+
4+
n = 4
5+
A = rand(n, n);
6+
dA = zeros(n, n);
7+
b1 = rand(n);
8+
db1 = zeros(n);
9+
10+
function f(A, b1; alg = LUFactorization())
11+
prob = LinearProblem(A, b1)
12+
13+
sol1 = solve(prob, alg)
14+
15+
s1 = sol1.u
16+
norm(s1)
17+
end
18+
19+
f(A, b1) # Uses BLAS
20+
21+
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))
22+
23+
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
24+
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
25+
26+
@test dA dA2
27+
@test db1 db12
28+
29+
A = rand(n, n);
30+
dA = zeros(n, n);
31+
dA2 = zeros(n, n);
32+
b1 = rand(n);
33+
db1 = zeros(n);
34+
db12 = zeros(n);
35+
36+
#=
37+
# Batch test fails
38+
# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075
39+
40+
function fbatch(y, A, b1; alg = LUFactorization())
41+
prob = LinearProblem(A, b1)
42+
43+
sol1 = solve(prob, alg)
44+
45+
s1 = sol1.u
46+
y[1] = norm(s1)
47+
nothing
48+
end
49+
50+
y = [0.0]
51+
dy1 = [1.0]
52+
dy2 = [1.0]
53+
Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
54+
55+
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
56+
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
57+
58+
@test_broken dA ≈ dA_2
59+
@test_broken dA2 ≈ dA_2
60+
@test_broken db1 ≈ db1_2
61+
@test_broken db12 ≈ db1_2
62+
=#
63+
64+
function f(A, b1, b2; alg = LUFactorization())
65+
prob = LinearProblem(A, b1)
66+
cache = init(prob, alg)
67+
s1 = copy(solve!(cache).u)
68+
cache.b = b2
69+
s2 = solve!(cache).u
70+
norm(s1 + s2)
71+
end
72+
73+
A = rand(n, n);
74+
dA = zeros(n, n);
75+
b1 = rand(n);
76+
db1 = zeros(n);
77+
b2 = rand(n);
78+
db2 = zeros(n);
79+
80+
f(A, b1, b2)
81+
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
82+
83+
dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1),eltype(x).(b2)), copy(A))
84+
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x,eltype(x).(b2)), copy(b1))
85+
db22 = ForwardDiff.gradient(x->f(eltype(x).(A),eltype(x).(b1),x), copy(b2))
86+
87+
@test dA dA2
88+
@test db1 db12
89+
@test db2 db22
90+
91+
function f2(A, b1, b2; alg = RFLUFactorization())
92+
prob = LinearProblem(A, b1)
93+
cache = init(prob, alg)
94+
s1 = copy(solve!(cache).u)
95+
cache.b = b2
96+
s2 = solve!(cache).u
97+
norm(s1 + s2)
98+
end
99+
100+
f2(A, b1, b2)
101+
dA = zeros(n, n);
102+
db1 = zeros(n);
103+
db2 = zeros(n);
104+
Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
105+
106+
@test dA dA2
107+
@test db1 db12
108+
@test db2 db22
109+
110+
#=
111+
function f3(A, b1, b2; alg = KrylovJL_GMRES())
112+
prob = LinearProblem(A, b1)
113+
cache = init(prob, alg)
114+
s1 = copy(solve!(cache).u)
115+
cache.b = b2
116+
s2 = solve!(cache).u
117+
norm(s1 + s2)
118+
end
119+
120+
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
121+
122+
@test dA ≈ dA2 atol=5e-5
123+
@test db1 ≈ db12
124+
@test db2 ≈ db22
125+
=#

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ const HAS_EXTENSIONS = isdefined(Base, :get_extension)
88

99
if GROUP == "All" || GROUP == "Core"
1010
@time @safetestset "Basic Tests" include("basictests.jl")
11-
@time @safetestset "Re-solve" include("resolve.jl")
11+
VERSION >= v"1.9" && @time @safetestset "Re-solve" include("resolve.jl")
1212
@time @safetestset "Zero Initialization Tests" include("zeroinittests.jl")
1313
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
1414
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
1515
@time @safetestset "Default Alg Tests" include("default_algs.jl")
16+
VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
1617
@time @safetestset "Traits" include("traits.jl")
1718
end
1819

0 commit comments

Comments
 (0)