Skip to content

Add column-pivoted QR factorization fallback on failed LU factorization #617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/solvers/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ use `Krylov_GMRES()`.

## Full List of Methods

### Polyalgorithms

```@docs
LinearSolve.DefaultLinearSolver
```

### RecursiveFactorization.jl

!!! note
Expand Down
15 changes: 15 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,23 @@ EnumX.@enumx DefaultAlgorithmChoice begin
KrylovJL_LSMR
end

"""
DefaultLinearSolver(;safetyfallback=true)

The default linear solver. This is the algorithm chosen when `solve(prob)`
is called. It's a polyalgorithm that detects the optimal method for a given
`A, b` and hardware (Intel, AMD, GPU, etc.).

## Keyword Arguments

* `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization
when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults
to true.
Comment on lines +131 to +133
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
* `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization
when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults
to true.
- `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization
when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults
to true.

"""
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
alg::DefaultAlgorithmChoice.T
safetyfallback::Bool
DefaultLinearSolver(alg; safetyfallback=true) = new(alg,safetyfallback)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
DefaultLinearSolver(alg; safetyfallback=true) = new(alg,safetyfallback)
DefaultLinearSolver(alg; safetyfallback = true) = new(alg, safetyfallback)

end

const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64}
Expand Down
36 changes: 31 additions & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ end
ex = Expr(:if, ex.args...)
end

# Handle special case of Column-pivoted QR fallback for LU
function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted)
function __setfield!(cache::DefaultLinearSolverInit,
alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted)

setfield!(cache, :QRFactorizationPivoted, v)
end

# Legacy fallback
# For SciML algorithms already using `defaultalg`, all assume square matrix.
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true))
Expand Down Expand Up @@ -352,11 +357,32 @@ end
kwargs...)
ex = :()
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
newex = quote
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
if alg in Symbol.((DefaultAlgorithmChoice.LUFactorization,
DefaultAlgorithmChoice.RFLUFactorization,
DefaultAlgorithmChoice.MKLLUFactorization,
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
DefaultAlgorithmChoice.GenericLUFactorization))
newex = quote
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
## TODO: Add verbosity logging here about using the fallback
sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...)
sol = SciMLBase.solve!(
cache, QRFactorization(ColumnNorm()), args...; kwargs...)

SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
else
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
end
end
else
newex = quote
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
retcode = sol.retcode,
iters = sol.iters, stats = sol.stats)
end
end
alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg)
ex = if ex == :()
Expand Down
16 changes: 15 additions & 1 deletion test/default_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,18 @@ prob = LinearProblem(A, b)
@test_broken SciMLBase.successful_retcode(solve(prob))

prob2 = LinearProblem(A2, b)
@test SciMLBase.successful_retcode(solve(prob2))
@test SciMLBase.successful_retcode(solve(prob2))

# Column-Pivoted QR fallback on failed LU
A = [1.0 0 0 0
0 1 0 0
0 0 1 0
Comment on lines +165 to +167
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
0 1 0 0
0 0 1 0
0 0 0 0]
0 1 0 0
0 0 1 0
0 0 0 0]

0 0 0 0]
b = rand(4)
prob = LinearProblem(A, b)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sol = solve(prob, LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false))
sol = solve(prob,
LinearSolve.DefaultLinearSolver(
LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback = false))

sol = solve(prob, LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false))
@test sol.retcode === ReturnCode.Failure
@test sol.u == zeros(4)

sol = solve(prob)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test sol.u svd(A)\b
@test sol.u svd(A) \ b

@test sol.u ≈ svd(A)\b
47 changes: 33 additions & 14 deletions test/retcodes.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
using LinearSolve, RecursiveFactorization
using LinearSolve, LinearAlgebra, RecursiveFactorization

alglist = (
LUFactorization,
QRFactorization,
DiagonalFactorization,
DirectLdiv!,
SparspakFactorization,
KLUFactorization,
UMFPACKFactorization,
KrylovJL_GMRES,
GenericLUFactorization,
RFLUFactorization,
LDLtFactorization,
BunchKaufmanFactorization,
CHOLMODFactorization,
SVDFactorization,
CholeskyFactorization,
NormalCholeskyFactorization,
AppleAccelerateLUFactorization,
MKLLUFactorization,
KrylovJL_CRAIGMR,
KrylovJL_LSMR
)
Expand All @@ -28,14 +17,23 @@ alglist = (
A = [2.0 1.0; -1.0 1.0]
b = [-1.0, 1.0]
prob = LinearProblem(A, b)
linsolve = init(prob, alg)
linsolve = init(prob, alg())
sol = solve!(linsolve)
@test SciMLBase.successful_retcode(sol.retcode) || sol.retcode == ReturnCode.Default # The latter seems off...
end
end

lualgs = (
LUFactorization(),
QRFactorization(),
GenericLUFactorization(),
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false),
RFLUFactorization(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false),
LinearSolve.DefaultLinearSolver(
LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback = false),

NormalCholeskyFactorization(),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
NormalCholeskyFactorization(),
NormalCholeskyFactorization()

@testset "Failure" begin
for alg in alglist
for alg in lualgs
@show alg
A = [1.0 1.0; 1.0 1.0]
b = [-1.0, 1.0]
prob = LinearProblem(A, b)
Expand All @@ -44,3 +42,24 @@ end
@test !SciMLBase.successful_retcode(sol.retcode)
end
end

rankdeficientalgs = (
QRFactorization(LinearAlgebra.ColumnNorm()),
KrylovJL_GMRES(),
SVDFactorization(),
KrylovJL_CRAIGMR(),
KrylovJL_LSMR(),
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
)

@testset "Rank Deficient Success" begin
for alg in rankdeficientalgs
@show alg
A = [1.0 1.0; 1.0 1.0]
b = [-1.0, 1.0]
prob = LinearProblem(A, b)
linsolve = init(prob, alg)
sol = solve!(linsolve)
@test SciMLBase.successful_retcode(sol.retcode)
end
end
Loading