Skip to content

Commit ef14bab

Browse files
committed
test: separate out the enzyme testing
1 parent 574b0d8 commit ef14bab

File tree

5 files changed

+24
-25
lines changed

5 files changed

+24
-25
lines changed

.github/workflows/Downgrade.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
version: ['1']
1919
group:
2020
- Core
21+
- Enzyme
2122
steps:
2223
- uses: actions/checkout@v4
2324
- uses: julia-actions/setup-julia@v2

.github/workflows/Tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ jobs:
3131
- "LinearSolveHYPRE"
3232
- "LinearSolvePardiso"
3333
- "LinearSolveBandedMatrices"
34+
- "Enzyme"
3435
uses: "SciML/.github/.github/workflows/tests.yml@v1"
3536
with:
3637
group: "${{ matrix.group }}"

ext/LinearSolveEnzymeExt.jl

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module LinearSolveEnzymeExt
33
using LinearSolve
44
using LinearSolve.LinearAlgebra
55
using EnzymeCore
6+
using EnzymeCore: EnzymeRules
67

7-
function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfigWidth{1},
8+
function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
89
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
910
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
1011
@assert !(prob isa Const)
@@ -19,26 +20,20 @@ function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfig
1920
dres = func.val(prob.dval, alg.val; kwargs...)
2021
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
2122
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
22-
if RT <: DuplicatedNoNeed
23-
return dres
24-
elseif RT <: Duplicated
25-
return Duplicated(res, dres)
26-
end
27-
error("Unsupported return type $RT")
2823

2924
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
30-
Duplicated(res, dres)
25+
return Duplicated(res, dres)
3126
elseif EnzymeRules.needs_shadow(config)
32-
dres
27+
return dres
3328
elseif EnzymeRules.needs_primal(config)
34-
res
29+
return res
3530
else
36-
nothing
31+
return nothing
3732
end
3833
end
3934

40-
function EnzymeCore.EnzymeRules.forward(
41-
config::EnzymeCore.EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
35+
function EnzymeRules.forward(
36+
config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
4237
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
4338
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
4439
@assert !(linsolve isa Const)
@@ -66,17 +61,17 @@ function EnzymeCore.EnzymeRules.forward(
6661
linsolve.val.b = b
6762

6863
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
69-
Duplicated(res, dres)
64+
return Duplicated(res, dres)
7065
elseif EnzymeRules.needs_shadow(config)
71-
dres
66+
return dres
7267
elseif EnzymeRules.needs_primal(config)
73-
res
68+
return res
7469
else
75-
nothing
70+
return nothing
7671
end
7772
end
7873

79-
function EnzymeCore.EnzymeRules.augmented_primal(
74+
function EnzymeRules.augmented_primal(
8075
config, func::Const{typeof(LinearSolve.init)},
8176
::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const;
8277
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@@ -111,10 +106,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
111106
(dval.b for dval in prob.dval)
112107
end
113108

114-
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
109+
return EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
115110
end
116111

117-
function EnzymeCore.EnzymeRules.reverse(
112+
function EnzymeRules.reverse(
118113
config, func::Const{typeof(LinearSolve.init)}, ::Type{RT},
119114
cache, prob::EnzymeCore.Annotation{LP}, alg::Const;
120115
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@@ -148,7 +143,7 @@ end
148143
# y=inv(A) B
149144
# dA −= z y^T
150145
# dB += z, where z = inv(A^T) dy
151-
function EnzymeCore.EnzymeRules.augmented_primal(
146+
function EnzymeRules.augmented_primal(
152147
config, func::Const{typeof(LinearSolve.solve!)},
153148
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
154149
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
@@ -201,10 +196,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
201196
cachesolve = deepcopy(linsolve.val)
202197

203198
cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
204-
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
199+
return EnzymeRules.AugmentedReturn(res, dres, cache)
205200
end
206201

207-
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
202+
function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
208203
::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP};
209204
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
210205
y, dys, _linsolve, dAs, dbs = cache

test/enzyme.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using Enzyme, ForwardDiff
22
using LinearSolve, LinearAlgebra, Test
33
using FiniteDiff
4-
using SafeTestsets
54

65
n = 4
76
A = rand(n, n);

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@ if GROUP == "All" || GROUP == "Core"
1515
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
1616
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
1717
@time @safetestset "Default Alg Tests" include("default_algs.jl")
18-
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
1918
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
2019
@time @safetestset "Traits" include("traits.jl")
2120
@time @safetestset "BandedMatrices" include("banded.jl")
2221
@time @safetestset "Static Arrays" include("static_arrays.jl")
2322
end
2423

24+
if GROUP == "All" || GROUP == "Enzyme"
25+
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
26+
end
27+
2528
if GROUP == "LinearSolveCUDA"
2629
Pkg.activate("gpu")
2730
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))

0 commit comments

Comments
 (0)