Skip to content

feat: static version of LiFukushimaLineSearch #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LineSearch"
uuid = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
authors = ["SciML"]
version = "0.1.2"
version = "0.1.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
Expand All @@ -38,6 +39,7 @@ ReTestItems = "1.28.0"
ReverseDiff = "1.15.3"
SciMLBase = "2.53.1"
SciMLJacobianOperators = "0.1"
StaticArraysCore = "1.4"
Test = "1.10"
Tracker = "0.2.35"
Zygote = "0.6.71"
Expand Down
1 change: 1 addition & 0 deletions src/LineSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using MaybeInplace: @bb
using SciMLBase: SciMLBase, AbstractSciMLProblem, AbstractNonlinearProblem, ReturnCode,
NonlinearProblem, NonlinearLeastSquaresProblem, NonlinearFunction
using SciMLJacobianOperators: VecJacOperator, JacVecOperator
using StaticArraysCore: SArray

abstract type AbstractLineSearchAlgorithm end
abstract type AbstractLineSearchCache end
Expand Down
85 changes: 73 additions & 12 deletions src/li_fukushima.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,30 @@

A derivative-free line search and global convergence of Broyden-like method for nonlinear
equations [li2000derivative](@cite).

!!! tip

For static arrays and numbers if `nan_maxiters` is either `nothing` or `missing`,
we provide a fully non-allocating implementation of the algorithm, that can be used
inside GPU kernels. However, this particular version doesn't support `stats` and
`reinit!` and those will be ignored. Additionally, we fix the initial alpha for the
search to be `1`.
"""
@kwdef @concrete struct LiFukushimaLineSearch <: AbstractLineSearchAlgorithm
lambda_0 = 1
beta = 1 // 2
sigma_1 = 1 // 1000
sigma_2 = 1 // 1000
eta = 1 // 10
rho = 9 // 10
nan_maxiters::Int = 5
beta = 0.5
sigma_1 = 0.001
sigma_2 = 0.001
eta = 0.1
rho = 0.9
nan_maxiters <: Union{Missing, Nothing, Int} = 5
maxiters::Int = 100
end

@concrete mutable struct LiFukushimaLineSearchCache <: AbstractLineSearchCache
ϕ
f
p
internalnorm
u_cache
fu_cache
λ₀
Expand All @@ -30,15 +37,45 @@ end
η
ρ
α
nan_maxiters::Int
nan_maxiters <: Union{Missing, Nothing, Int}
maxiters::Int
stats <: Union{SciMLBase.NLStats, Nothing}
alg <: LiFukushimaLineSearch
end

@concrete struct StaticLiFukushimaLineSearchCache <: AbstractLineSearchCache
f
p
λ₀
β
σ₁
σ₂
η
ρ
maxiters::Int
end

function CommonSolve.init(
prob::AbstractNonlinearProblem, alg::LiFukushimaLineSearch, fu, u;
prob::AbstractNonlinearProblem, alg::LiFukushimaLineSearch,
fu::Union{SArray, Number}, u::Union{SArray, Number};
stats::Union{SciMLBase.NLStats, Nothing} = nothing, kwargs...)
if (alg.nan_maxiters === nothing || alg.nan_maxiters === missing) && stats === nothing
T = promote_type(eltype(fu), eltype(u))
return StaticLiFukushimaLineSearchCache(prob.f, prob.p, T(alg.lambda_0),
T(alg.beta), T(alg.sigma_1), T(alg.sigma_2), T(alg.eta), T(alg.rho),
alg.maxiters)
end
return generic_lifukushima_init(prob, alg, fu, u; stats, kwargs...)
end

function CommonSolve.init(
prob::AbstractNonlinearProblem, alg::LiFukushimaLineSearch, fu, u; kwargs...)
return generic_lifukushima_init(prob, alg, fu, u; kwargs...)
end

function generic_lifukushima_init(
prob::AbstractNonlinearProblem, alg::LiFukushimaLineSearch,
fu, u; stats::Union{SciMLBase.NLStats, Nothing} = nothing, kwargs...)
@bb u_cache = similar(u)
@bb fu_cache = similar(fu)
T = promote_type(eltype(fu), eltype(u))
Expand All @@ -51,7 +88,7 @@ function CommonSolve.init(
end

return LiFukushimaLineSearchCache(
ϕ, prob.f, prob.p, T(1), u_cache, fu_cache, T(alg.lambda_0), T(alg.beta),
ϕ, prob.f, prob.p, u_cache, fu_cache, T(alg.lambda_0), T(alg.beta),
T(alg.sigma_1), T(alg.sigma_2), T(alg.eta), T(alg.rho), T(1), alg.nan_maxiters,
alg.maxiters, stats, alg)
end
Expand All @@ -74,7 +111,8 @@ function CommonSolve.solve!(cache::LiFukushimaLineSearchCache, u, du)
λ₂, λ₁ = cache.λ₀, cache.λ₀
fxλp_norm = ϕ(λ₂)

if !isfinite(fxλp_norm)
if !isfinite(fxλp_norm) && cache.nan_maxiters !== nothing &&
cache.nan_maxiters !== missing
nan_converged = false
for _ in 1:(cache.nan_maxiters)
λ₁, λ₂ = λ₂, cache.β * λ₂
Expand All @@ -85,7 +123,7 @@ function CommonSolve.solve!(cache::LiFukushimaLineSearchCache, u, du)
nan_converged || return LineSearchSolution(cache.α, ReturnCode.Failure)
end

for i in 1:(cache.maxiters)
for _ in 1:(cache.maxiters)
fxλp_norm = ϕ(λ₂)
converged = fxλp_norm ≤ (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
converged && return LineSearchSolution(λ₂, ReturnCode.Success)
Expand All @@ -95,6 +133,29 @@ function CommonSolve.solve!(cache::LiFukushimaLineSearchCache, u, du)
return LineSearchSolution(cache.α, ReturnCode.Failure)
end

function CommonSolve.solve!(cache::StaticLiFukushimaLineSearchCache, u, du)
T = promote_type(eltype(du), eltype(u))

fx_norm = norm(cache.f(u, cache.p))
du_norm = norm(du)
fxλ_norm = norm(cache.f(u .+ du, cache.p))

if fxλ_norm ≤ cache.ρ * fx_norm - cache.σ₂ * du_norm^2
return LineSearchSolution(T(true), ReturnCode.Success)
end

λ₂, λ₁ = cache.λ₀, cache.λ₀

for _ in 1:(cache.maxiters)
fxλp_norm = norm(cache.f(u .+ λ₂ .* du, cache.p))
converged = fxλp_norm ≤ (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
converged && return LineSearchSolution(λ₂, ReturnCode.Success)
λ₁, λ₂ = λ₂, cache.β * λ₂
end

return LineSearchSolution(T(true), ReturnCode.Failure)
end

function SciMLBase.reinit!(
cache::LiFukushimaLineSearchCache; p = missing, stats = missing, kwargs...)
p !== missing && (cache.p = p)
Expand Down
Loading