-
-
Notifications
You must be signed in to change notification settings - Fork 61
Add column-pivoted QR factorization fallback on failed LU factorization #617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -119,8 +119,23 @@ EnumX.@enumx DefaultAlgorithmChoice begin | |||||
KrylovJL_LSMR | ||||||
end | ||||||
|
||||||
""" | ||||||
DefaultLinearSolver(;safetyfallback=true) | ||||||
|
||||||
The default linear solver. This is the algorithm chosen when `solve(prob)` | ||||||
is called. It's a polyalgorithm that detects the optimal method for a given | ||||||
`A, b` and hardware (Intel, AMD, GPU, etc.). | ||||||
|
||||||
## Keyword Arguments | ||||||
|
||||||
* `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization | ||||||
when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults | ||||||
to true. | ||||||
""" | ||||||
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm | ||||||
alg::DefaultAlgorithmChoice.T | ||||||
safetyfallback::Bool | ||||||
DefaultLinearSolver(alg; safetyfallback=true) = new(alg,safetyfallback) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
end | ||||||
|
||||||
const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64} | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -41,6 +41,11 @@ end | |||||||
ex = Expr(:if, ex.args...) | ||||||||
end | ||||||||
|
||||||||
# Handle special case of Column-pivoted QR fallback for LU | ||||||||
function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
setfield!(cache, :QRFactorizationPivoted, v) | ||||||||
end | ||||||||
|
||||||||
# Legacy fallback | ||||||||
# For SciML algorithms already using `defaultalg`, all assume square matrix. | ||||||||
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true)) | ||||||||
|
@@ -352,11 +357,32 @@ end | |||||||
kwargs...) | ||||||||
ex = :() | ||||||||
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) | ||||||||
newex = quote | ||||||||
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) | ||||||||
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; | ||||||||
retcode = sol.retcode, | ||||||||
iters = sol.iters, stats = sol.stats) | ||||||||
if alg in Symbol.((DefaultAlgorithmChoice.LUFactorization, | ||||||||
DefaultAlgorithmChoice.RFLUFactorization, | ||||||||
DefaultAlgorithmChoice.MKLLUFactorization, | ||||||||
DefaultAlgorithmChoice.AppleAccelerateLUFactorization, | ||||||||
DefaultAlgorithmChoice.GenericLUFactorization)) | ||||||||
newex = quote | ||||||||
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) | ||||||||
if sol.retcode === ReturnCode.Failure && alg.safetyfallback | ||||||||
## TODO: Add verbosity logging here about using the fallback | ||||||||
sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; | ||||||||
retcode = sol.retcode, | ||||||||
iters = sol.iters, stats = sol.stats) | ||||||||
else | ||||||||
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; | ||||||||
retcode = sol.retcode, | ||||||||
iters = sol.iters, stats = sol.stats) | ||||||||
end | ||||||||
end | ||||||||
else | ||||||||
newex = quote | ||||||||
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) | ||||||||
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; | ||||||||
retcode = sol.retcode, | ||||||||
iters = sol.iters, stats = sol.stats) | ||||||||
end | ||||||||
end | ||||||||
alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg) | ||||||||
ex = if ex == :() | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -158,4 +158,18 @@ prob = LinearProblem(A, b) | |||||||||||||
@test_broken SciMLBase.successful_retcode(solve(prob)) | ||||||||||||||
|
||||||||||||||
prob2 = LinearProblem(A2, b) | ||||||||||||||
@test SciMLBase.successful_retcode(solve(prob2)) | ||||||||||||||
@test SciMLBase.successful_retcode(solve(prob2)) | ||||||||||||||
|
||||||||||||||
# Column-Pivoted QR fallback on failed LU | ||||||||||||||
A = [1.0 0 0 0 | ||||||||||||||
0 1 0 0 | ||||||||||||||
0 0 1 0 | ||||||||||||||
Comment on lines
+165
to
+167
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||
0 0 0 0] | ||||||||||||||
b = rand(4) | ||||||||||||||
prob = LinearProblem(A, b) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||
sol = solve(prob, LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false)) | ||||||||||||||
@test sol.retcode === ReturnCode.Failure | ||||||||||||||
@test sol.u == zeros(4) | ||||||||||||||
|
||||||||||||||
sol = solve(prob) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||
@test sol.u ≈ svd(A)\b |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -1,24 +1,13 @@ | ||||||||
using LinearSolve, RecursiveFactorization | ||||||||
using LinearSolve, LinearAlgebra, RecursiveFactorization | ||||||||
|
||||||||
alglist = ( | ||||||||
LUFactorization, | ||||||||
QRFactorization, | ||||||||
DiagonalFactorization, | ||||||||
DirectLdiv!, | ||||||||
SparspakFactorization, | ||||||||
KLUFactorization, | ||||||||
UMFPACKFactorization, | ||||||||
KrylovJL_GMRES, | ||||||||
GenericLUFactorization, | ||||||||
RFLUFactorization, | ||||||||
LDLtFactorization, | ||||||||
BunchKaufmanFactorization, | ||||||||
CHOLMODFactorization, | ||||||||
SVDFactorization, | ||||||||
CholeskyFactorization, | ||||||||
NormalCholeskyFactorization, | ||||||||
AppleAccelerateLUFactorization, | ||||||||
MKLLUFactorization, | ||||||||
KrylovJL_CRAIGMR, | ||||||||
KrylovJL_LSMR | ||||||||
) | ||||||||
|
@@ -28,14 +17,23 @@ alglist = ( | |||||||
A = [2.0 1.0; -1.0 1.0] | ||||||||
b = [-1.0, 1.0] | ||||||||
prob = LinearProblem(A, b) | ||||||||
linsolve = init(prob, alg) | ||||||||
linsolve = init(prob, alg()) | ||||||||
sol = solve!(linsolve) | ||||||||
@test SciMLBase.successful_retcode(sol.retcode) || sol.retcode == ReturnCode.Default # The latter seems off... | ||||||||
end | ||||||||
end | ||||||||
|
||||||||
lualgs = ( | ||||||||
LUFactorization(), | ||||||||
QRFactorization(), | ||||||||
GenericLUFactorization(), | ||||||||
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false), | ||||||||
RFLUFactorization(), | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
NormalCholeskyFactorization(), | ||||||||
) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
@testset "Failure" begin | ||||||||
for alg in alglist | ||||||||
for alg in lualgs | ||||||||
@show alg | ||||||||
A = [1.0 1.0; 1.0 1.0] | ||||||||
b = [-1.0, 1.0] | ||||||||
prob = LinearProblem(A, b) | ||||||||
|
@@ -44,3 +42,24 @@ end | |||||||
@test !SciMLBase.successful_retcode(sol.retcode) | ||||||||
end | ||||||||
end | ||||||||
|
||||||||
rankdeficientalgs = ( | ||||||||
QRFactorization(LinearAlgebra.ColumnNorm()), | ||||||||
KrylovJL_GMRES(), | ||||||||
SVDFactorization(), | ||||||||
KrylovJL_CRAIGMR(), | ||||||||
KrylovJL_LSMR(), | ||||||||
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization) | ||||||||
) | ||||||||
|
||||||||
@testset "Rank Deficient Success" begin | ||||||||
for alg in rankdeficientalgs | ||||||||
@show alg | ||||||||
A = [1.0 1.0; 1.0 1.0] | ||||||||
b = [-1.0, 1.0] | ||||||||
prob = LinearProblem(A, b) | ||||||||
linsolve = init(prob, alg) | ||||||||
sol = solve!(linsolve) | ||||||||
@test SciMLBase.successful_retcode(sol.retcode) | ||||||||
end | ||||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶