Skip to content

Commit cdde6e1

Browse files
committed
fix: forwarddiff support
1 parent 693949b commit cdde6e1

File tree

11 files changed

+203
-161
lines changed

11 files changed

+203
-161
lines changed

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ function SciMLBase.__solve(
3535
)
3636
end
3737

38-
linsolve = alg.ls === :qr ? LSO.QR() :
39-
(alg.ls === :cholesky ? LSO.Cholesky() :
40-
(alg.ls === :lsmr ? LSO.LSMR() : nothing))
38+
linsolve = alg.linsolve === :qr ? LSO.QR() :
39+
(alg.linsolve === :cholesky ? LSO.Cholesky() :
40+
(alg.linsolve === :lsmr ? LSO.LSMR() : nothing))
4141

4242
lso_solver = if alg.alg === :lm
4343
LSO.LevenbergMarquardt(linsolve)

ext/NonlinearSolveSundialsExt.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
module NonlinearSolveSundialsExt
22

3+
using Sundials: KINSOL
4+
5+
using CommonSolve: CommonSolve
36
using NonlinearSolveBase: NonlinearSolveBase, nonlinearsolve_forwarddiff_solve,
47
nonlinearsolve_dual_solution
5-
using NonlinearSolve: DualNonlinearProblem
8+
using NonlinearSolve: NonlinearSolve, DualNonlinearProblem
69
using SciMLBase: SciMLBase
7-
using Sundials: KINSOL
810

911
function SciMLBase.__solve(prob::DualNonlinearProblem, alg::KINSOL, args...; kwargs...)
1012
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
1113
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
1214
return SciMLBase.build_solution(
13-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
15+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
16+
)
17+
end
18+
19+
function SciMLBase.__init(prob::DualNonlinearProblem, alg::KINSOL, args...; kwargs...)
20+
p = NonlinearSolveBase.nodual_value(prob.p)
21+
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
22+
cache = CommonSolve.init(newprob, alg, args...; kwargs...)
23+
return NonlinearSolveForwardDiffCache(
24+
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
25+
)
1426
end
1527

1628
end

lib/NonlinearSolveBase/src/linear_solve.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
6969
error("Default Julia Backsolve Operator `\\` doesn't support Preconditioners")
7070
return NativeJLLinearSolveCache(A, b, stats)
7171
elseif no_preconditioner && linsolve === nothing
72-
# Non-allocating linear solve exists in StaticArrays.jl
73-
if (A isa SMatrix || A isa WrappedArray{<:Any, <:SMatrix}) &&
74-
Core.Compiler.return_type(\, Tuple{typeof(A), typeof(b)}) <: SArray
72+
if (A isa SMatrix || A isa WrappedArray{<:Any, <:SMatrix})
7573
return NativeJLLinearSolveCache(A, b, stats)
7674
end
7775
end

lib/NonlinearSolveBase/src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ function maybe_pinv!!(workspace, A::StridedMatrix)
222222
!issingular && return LinearAlgebra.tril!(parent(inv(A_)))
223223
else
224224
F = LinearAlgebra.lu(A; check = false)
225-
if issuccess(F)
225+
if LinearAlgebra.issuccess(F)
226226
Ai = LinearAlgebra.inv!(F)
227227
return convert(typeof(parent(Ai)), Ai)
228228
end

lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Setfield: @set!
77

88
using ADTypes: ADTypes
99
using ArrayInterface: ArrayInterface
10-
using LinearAlgebra: LinearAlgebra, Diagonal, dot
10+
using LinearAlgebra: LinearAlgebra, Diagonal, dot, diagind
1111
using StaticArraysCore: SArray
1212

1313
using CommonSolve: CommonSolve

lib/NonlinearSolveQuasiNewton/test/core_tests.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,14 @@ end
170170
LiFukushimaLineSearch()
171171
)
172172
@testset "[OOP] u0: $(typeof(u0))" for u0 in (ones(32), @SVector(ones(2)), 1.0)
173+
broken = Sys.iswindows() && u0 isa Vector{Float64} &&
174+
linesearch isa BackTracking && ad isa AutoFiniteDiff
175+
173176
solver = LimitedMemoryBroyden(; linesearch)
174177
sol = solve_oop(quadratic_f, u0; solver)
175-
@test SciMLBase.successful_retcode(sol)
178+
@test SciMLBase.successful_retcode(sol) broken=broken
176179
err = maximum(abs, quadratic_f(sol.u, 2.0))
177-
@test err < 1e-9
180+
@test err<1e-9 broken=broken
178181

179182
cache = init(
180183
NonlinearProblem{false}(quadratic_f, u0, 2.0), solver, abstol = 1e-9
@@ -185,11 +188,14 @@ end
185188
@testset "[IIP] u0: $(typeof(u0))" for u0 in (ones(32),)
186189
ad isa AutoZygote && continue
187190

191+
broken = Sys.iswindows() && u0 isa Vector{Float64} &&
192+
linesearch isa BackTracking && ad isa AutoFiniteDiff
193+
188194
solver = LimitedMemoryBroyden(; linesearch)
189195
sol = solve_iip(quadratic_f!, u0; solver)
190196
@test SciMLBase.successful_retcode(sol)
191197
err = maximum(abs, quadratic_f(sol.u, 2.0))
192-
@test err < 1e-9
198+
@test err<1e-9 broken=broken
193199

194200
cache = init(
195201
NonlinearProblem{true}(quadratic_f!, u0, 2.0), solver, abstol = 1e-9

src/NonlinearSolve.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ using NonlinearSolveQuasiNewton: Broyden, Klement
2828
using SimpleNonlinearSolve: SimpleBroyden, SimpleKlement
2929

3030
# Default AD Support
31-
using FiniteDiff: FiniteDiff # Default Finite Difference Method
32-
using ForwardDiff: ForwardDiff # Default Forward Mode AD
31+
using FiniteDiff: FiniteDiff # Default Finite Difference Method
32+
using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD
3333

3434
# Sparse AD Support: Implemented via extensions
3535
using SparseArrays: SparseArrays
@@ -39,9 +39,9 @@ using SparseMatrixColorings: SparseMatrixColorings
3939
using BracketingNonlinearSolve: BracketingNonlinearSolve
4040
using LineSearch: LineSearch
4141
using LinearSolve: LinearSolve
42-
using NonlinearSolveFirstOrder: NonlinearSolveFirstOrder
43-
using NonlinearSolveQuasiNewton: NonlinearSolveQuasiNewton
44-
using NonlinearSolveSpectralMethods: NonlinearSolveSpectralMethods
42+
using NonlinearSolveFirstOrder: NonlinearSolveFirstOrder, GeneralizedFirstOrderAlgorithm
43+
using NonlinearSolveQuasiNewton: NonlinearSolveQuasiNewton, QuasiNewtonAlgorithm
44+
using NonlinearSolveSpectralMethods: NonlinearSolveSpectralMethods, GeneralizedDFSane
4545
using SimpleNonlinearSolve: SimpleNonlinearSolve
4646

4747
const SII = SymbolicIndexingInterface
@@ -53,16 +53,16 @@ include("extension_algs.jl")
5353

5454
include("default.jl")
5555

56-
# const ALL_SOLVER_TYPES = [
57-
# Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
58-
# GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
59-
# LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
60-
# SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
61-
# CMINPACK, PETScSNES,
62-
# NonlinearSolvePolyAlgorithm{:NLLS, <:Any}, NonlinearSolvePolyAlgorithm{:NLS, <:Any}
63-
# ]
56+
const ALL_SOLVER_TYPES = [
57+
Nothing, AbstractNonlinearSolveAlgorithm,
58+
GeneralizedDFSane, GeneralizedFirstOrderAlgorithm, QuasiNewtonAlgorithm,
59+
LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
60+
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
61+
CMINPACK, PETScSNES,
62+
NonlinearSolvePolyAlgorithm
63+
]
6464

65-
# include("internal/forward_diff.jl") # we need to define after the algorithms
65+
include("forward_diff.jl")
6666

6767
@setup_workload begin
6868
include("../common/nonlinear_problem_workloads.jl")

src/forward_diff.jl

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
1-
const DualNonlinearProblem = NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
2-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
1+
const DualNonlinearProblem = NonlinearProblem{
2+
<:Union{Number, <:AbstractArray}, iip,
3+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
4+
} where {iip, T, V, P}
35
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
46
<:Union{Number, <:AbstractArray}, iip,
5-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
7+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
8+
} where {iip, T, V, P}
69
const DualAbstractNonlinearProblem = Union{
7-
DualNonlinearProblem, DualNonlinearLeastSquaresProblem}
10+
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
11+
}
812

913
for algType in ALL_SOLVER_TYPES
1014
@eval function SciMLBase.__solve(
11-
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
12-
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
13-
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
15+
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
16+
)
17+
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
18+
prob, alg, args...; kwargs...
19+
)
20+
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
1421
return SciMLBase.build_solution(
15-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
22+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
23+
)
1624
end
1725
end
1826

19-
@concrete mutable struct NonlinearSolveForwardDiffCache
27+
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
2028
cache
2129
prob
2230
alg
@@ -25,36 +33,41 @@ end
2533
partials_p
2634
end
2735

28-
@internal_caches NonlinearSolveForwardDiffCache :cache
29-
30-
function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
31-
p = cache.p, u0 = get_u(cache.cache), kwargs...)
32-
inner_cache = reinit_cache!(cache.cache; p = __value(p), u0 = __value(u0), kwargs...)
36+
function InternalAPI.reinit!(
37+
cache::NonlinearSolveForwardDiffCache, args...;
38+
p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
39+
)
40+
inner_cache = InternalAPI.reinit!(
41+
cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs...
42+
)
3343
cache.cache = inner_cache
3444
cache.p = p
35-
cache.values_p = __value(p)
45+
cache.values_p = nodual_value(p)
3646
cache.partials_p = ForwardDiff.partials(p)
3747
return cache
3848
end
3949

4050
for algType in ALL_SOLVER_TYPES
51+
# XXX: Extend to DualNonlinearLeastSquaresProblem
4152
@eval function SciMLBase.__init(
42-
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
43-
p = __value(prob.p)
44-
newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
53+
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...
54+
)
55+
p = nodual_value(prob.p)
56+
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
4557
cache = init(newprob, alg, args...; kwargs...)
4658
return NonlinearSolveForwardDiffCache(
47-
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
59+
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
60+
)
4861
end
4962
end
5063

51-
function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
64+
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
5265
sol = solve!(cache.cache)
5366
prob = cache.prob
5467

5568
uu = sol.u
56-
Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
57-
Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
69+
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
70+
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
5871

5972
z_arr = -Jᵤ \ Jₚ
6073

@@ -65,11 +78,12 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
6578
partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
6679
end
6780

68-
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, cache.p)
81+
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p)
6982
return SciMLBase.build_solution(
70-
prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
83+
prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
84+
)
7185
end
7286

73-
@inline __value(x) = x
74-
@inline __value(x::Dual) = ForwardDiff.value(x)
75-
@inline __value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
87+
nodual_value(x) = x
88+
nodual_value(x::Dual) = ForwardDiff.value(x)
89+
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

test/23_test_problems_tests.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,13 @@ end
106106
test_on_library(problems, dicts, alg_ops, broken_tests)
107107
end
108108

109-
@testitem "Broyden" setup=[RobustnessTesting] tags=[:core] begin
110-
alg_ops = (Broyden(), Broyden(; init_jacobian = Val(:true_jacobian)),
109+
@testitem "Broyden" setup=[RobustnessTesting] tags=[:core] retries=3 begin
110+
alg_ops = (
111+
Broyden(),
112+
Broyden(; init_jacobian = Val(:true_jacobian)),
111113
Broyden(; update_rule = Val(:bad_broyden)),
112-
Broyden(; init_jacobian = Val(:true_jacobian), update_rule = Val(:bad_broyden)))
114+
Broyden(; init_jacobian = Val(:true_jacobian), update_rule = Val(:bad_broyden))
115+
)
113116

114117
broken_tests = Dict(alg => Int[] for alg in alg_ops)
115118
if Sys.isapple()

test/cuda_tests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
SOLVERS = (
1616
NewtonRaphson(),
1717
LevenbergMarquardt(; linsolve = QRFactorization()),
18-
# XXX: Fails currently
19-
# LevenbergMarquardt(; linsolve = KrylovJL_GMRES()),
18+
LevenbergMarquardt(; linsolve = KrylovJL_GMRES()),
2019
PseudoTransient(),
2120
Klement(),
2221
Broyden(; linesearch = LiFukushimaLineSearch()),

0 commit comments

Comments
 (0)