Skip to content

Commit b2d8a28

Browse files
committed
make preconditioners part of the solver rather than a random extra
1 parent 844253c commit b2d8a28

File tree

3 files changed

+56
-17
lines changed

3 files changed

+56
-17
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
133133
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
134134
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
135135
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
136+
KrylovPreconditioners = "45d422c2-293f-44ce-8315-2cb988662dec"
136137
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
137138
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
138139
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
@@ -146,4 +147,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
146147
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
147148

148149
[targets]
149-
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]
150+
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "KrylovPreconditioners", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck", "StableRNGs", "Zygote"]

src/common.jl

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ default_alias_b(::Any, ::Any, ::Any) = false
121121
default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
122122
default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
123123

124+
DEFAULT_PRECS(A, p) = IdentityOperator(size(A)[1]), IdentityOperator(size(A)[2])
125+
124126
function __init_u0_from_Ab(A, b)
125127
u0 = similar(b, size(A, 2))
126128
fill!(u0, false)
@@ -136,12 +138,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
136138
reltol = default_tol(real(eltype(prob.b))),
137139
maxiters::Int = length(prob.b),
138140
verbose::Bool = false,
139-
Pl = IdentityOperator(size(prob.A)[1]),
140-
Pr = IdentityOperator(size(prob.A)[2]),
141+
Pl = nothing,
142+
Pr = nothing,
141143
assumptions = OperatorAssumptions(issquare(prob.A)),
142144
sensealg = LinearSolveAdjoint(),
143145
kwargs...)
144-
@unpack A, b, u0, p = prob
146+
(;A, b, u0, p) = prob
145147

146148
A = if alias_A || A isa SMatrix
147149
A
@@ -167,6 +169,18 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
167169
reltol = real(eltype(prob.b))(reltol)
168170
abstol = real(eltype(prob.b))(abstol)
169171

172+
precs = hasproperty(alg, :precs) ? alg.precs : DEFAULT_PRECS
173+
_Pl, _Pr = precs(A, p)
174+
if isnothing(Pl)
175+
Pl = _Pl
176+
else
177+
@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
178+
end
179+
if isnothing(Pr)
180+
Pr = _Pr
181+
else
182+
@warn "passing Preconditioners at `init`/`solve` time is deprecated. Instead add a `precs` function to your algorithm."
183+
end
170184
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
171185
assumptions)
172186
isfresh = true
@@ -179,6 +193,33 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
179193
return cache
180194
end
181195

196+
197+
function SciMLBase.reinit!(cache::LinearCache;
198+
A = nothing,
199+
b = cache.b,
200+
u = cache.u,
201+
p = nothing,)
202+
(; alg, cacheval, isfresh, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache
203+
204+
precs = hasproperty(alg, :precs) ? alg.precs : DEFAULT_PRECS
205+
Pl, Pr = if isnothing(A) || isnothing(p)
206+
(cache.Pl, cache.Pr)
207+
else
208+
if isnothing(A)
209+
A = cache.A
210+
end
211+
if isnothing(p)
212+
p = cache.p
213+
end
214+
precs(A, p)
215+
end
216+
217+
return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval),
218+
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq),
219+
typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
220+
maxiters, verbose, assumptions, sensealg)
221+
end
222+
182223
function SciMLBase.solve(prob::LinearProblem, args...; kwargs...)
183224
return solve(prob, nothing, args...; kwargs...)
184225
end

src/iterative_wrappers.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,21 @@ KrylovJL(args...; KrylovAlg = Krylov.gmres!,
1010
1111
A generic wrapper over the Krylov.jl krylov-subspace iterative solvers.
1212
"""
13-
struct KrylovJL{F, I, A, K} <: AbstractKrylovSubspaceMethod
13+
struct KrylovJL{F, I, A, P, K} <: AbstractKrylovSubspaceMethod
1414
KrylovAlg::F
1515
gmres_restart::I
1616
window::I
17+
precs::P
1718
args::A
1819
kwargs::K
1920
end
2021

2122
function KrylovJL(args...; KrylovAlg = Krylov.gmres!,
2223
gmres_restart = 0, window = 0,
24+
precs = DEFAULT_PRECS,
2325
kwargs...)
2426
return KrylovJL(KrylovAlg, gmres_restart, window,
25-
args, kwargs)
27+
precs, args, kwargs)
2628
end
2729

2830
default_alias_A(::KrylovJL, ::Any, ::Any) = true
@@ -231,8 +233,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
231233
cache.isfresh = false
232234
end
233235

234-
M = cache.Pl
235-
N = cache.Pr
236+
M, N = cache.Pl, cache.Pr
236237

237238
# use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
238239
M = _isidentity_struct(M) ? I : M
@@ -258,25 +259,21 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
258259
end
259260

260261
args = (cacheval, cache.A, cache.b)
261-
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
262+
kwargs = (atol = atol, rtol, itmax, verbose,
262263
ldiv = true, history = true, alg.kwargs...)
263264

264265
if cache.cacheval isa Krylov.CgSolver
265266
N !== I &&
266267
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
267-
Krylov.solve!(args...; M = M,
268-
kwargs...)
268+
Krylov.solve!(args...; M, kwargs...)
269269
elseif cache.cacheval isa Krylov.GmresSolver
270-
Krylov.solve!(args...; M = M, N = N, restart = alg.gmres_restart > 0,
271-
kwargs...)
270+
Krylov.solve!(args...; M, N, restart = alg.gmres_restart > 0, kwargs...)
272271
elseif cache.cacheval isa Krylov.BicgstabSolver
273-
Krylov.solve!(args...; M = M, N = N,
274-
kwargs...)
272+
Krylov.solve!(args...; M, N, kwargs...)
275273
elseif cache.cacheval isa Krylov.MinresSolver
276274
N !== I &&
277275
@warn "$(alg.KrylovAlg) doesn't support right preconditioning."
278-
Krylov.solve!(args...; M = M,
279-
kwargs...)
276+
Krylov.solve!(args...; M, kwargs...)
280277
else
281278
Krylov.solve!(args...; kwargs...)
282279
end

0 commit comments

Comments
 (0)