diff --git a/docs/src/solvers/solvers.md b/docs/src/solvers/solvers.md index 8b2d54c45..f6c828edb 100644 --- a/docs/src/solvers/solvers.md +++ b/docs/src/solvers/solvers.md @@ -82,6 +82,12 @@ use `Krylov_GMRES()`. ## Full List of Methods +### Polyalgorithms + +```@docs +LinearSolve.DefaultLinearSolver +``` + ### RecursiveFactorization.jl !!! note diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index af2d94eb4..63a58a0c3 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -119,8 +119,23 @@ EnumX.@enumx DefaultAlgorithmChoice begin KrylovJL_LSMR end +""" + DefaultLinearSolver(;safetyfallback=true) + +The default linear solver. This is the algorithm chosen when `solve(prob)` +is called. It's a polyalgorithm that detects the optimal method for a given +`A, b` and hardware (Intel, AMD, GPU, etc.). + +## Keyword Arguments + +* `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization + when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults + to true. +""" struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm alg::DefaultAlgorithmChoice.T + safetyfallback::Bool + DefaultLinearSolver(alg; safetyfallback=true) = new(alg,safetyfallback) end const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64} diff --git a/src/default.jl b/src/default.jl index 545337b3f..acfe8201f 100644 --- a/src/default.jl +++ b/src/default.jl @@ -41,6 +41,11 @@ end ex = Expr(:if, ex.args...) end +# Handle special case of Column-pivoted QR fallback for LU +function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted) + setfield!(cache, :QRFactorizationPivoted, v) +end + # Legacy fallback # For SciML algorithms already using `defaultalg`, all assume square matrix. defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true)) @@ -352,11 +357,32 @@ end kwargs...) ex = :() for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) - newex = quote - sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) - SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; - retcode = sol.retcode, - iters = sol.iters, stats = sol.stats) + if alg in Symbol.((DefaultAlgorithmChoice.LUFactorization, + DefaultAlgorithmChoice.RFLUFactorization, + DefaultAlgorithmChoice.MKLLUFactorization, + DefaultAlgorithmChoice.AppleAccelerateLUFactorization, + DefaultAlgorithmChoice.GenericLUFactorization)) + newex = quote + sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) + if sol.retcode === ReturnCode.Failure && alg.safetyfallback + ## TODO: Add verbosity logging here about using the fallback + sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...) + SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; + retcode = sol.retcode, + iters = sol.iters, stats = sol.stats) + else + SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; + retcode = sol.retcode, + iters = sol.iters, stats = sol.stats) + end + end + else + newex = quote + sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...) + SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache; + retcode = sol.retcode, + iters = sol.iters, stats = sol.stats) + end end alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg) ex = if ex == :() diff --git a/test/default_algs.jl b/test/default_algs.jl index 58f30db50..cf0becfb8 100644 --- a/test/default_algs.jl +++ b/test/default_algs.jl @@ -158,4 +158,18 @@ prob = LinearProblem(A, b) @test_broken SciMLBase.successful_retcode(solve(prob)) prob2 = LinearProblem(A2, b) -@test SciMLBase.successful_retcode(solve(prob2)) \ No newline at end of file +@test SciMLBase.successful_retcode(solve(prob2)) + +# Column-Pivoted QR fallback on failed LU +A = [1.0 0 0 0 + 0 1 0 0 + 0 0 1 0 + 0 0 0 0] +b = rand(4) +prob = LinearProblem(A, b) +sol = solve(prob, LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false)) +@test sol.retcode === ReturnCode.Failure +@test sol.u == zeros(4) + +sol = solve(prob) +@test sol.u ≈ svd(A)\b diff --git a/test/retcodes.jl b/test/retcodes.jl index 1e33e8adb..af59de6eb 100644 --- a/test/retcodes.jl +++ b/test/retcodes.jl @@ -1,24 +1,13 @@ -using LinearSolve, RecursiveFactorization +using LinearSolve, LinearAlgebra, RecursiveFactorization alglist = ( LUFactorization, QRFactorization, - DiagonalFactorization, - DirectLdiv!, - SparspakFactorization, - KLUFactorization, - UMFPACKFactorization, KrylovJL_GMRES, GenericLUFactorization, RFLUFactorization, - LDLtFactorization, - BunchKaufmanFactorization, - CHOLMODFactorization, SVDFactorization, - CholeskyFactorization, NormalCholeskyFactorization, - AppleAccelerateLUFactorization, - MKLLUFactorization, KrylovJL_CRAIGMR, KrylovJL_LSMR ) @@ -28,14 +17,23 @@ alglist = ( A = [2.0 1.0; -1.0 1.0] b = [-1.0, 1.0] prob = LinearProblem(A, b) - linsolve = init(prob, alg) + linsolve = init(prob, alg()) sol = solve!(linsolve) @test SciMLBase.successful_retcode(sol.retcode) || sol.retcode == ReturnCode.Default # The latter seems off... end end +lualgs = ( + LUFactorization(), + QRFactorization(), + GenericLUFactorization(), + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false), + RFLUFactorization(), + NormalCholeskyFactorization(), +) @testset "Failure" begin - for alg in alglist + for alg in lualgs + @show alg A = [1.0 1.0; 1.0 1.0] b = [-1.0, 1.0] prob = LinearProblem(A, b) @@ -44,3 +42,24 @@ end @test !SciMLBase.successful_retcode(sol.retcode) end end + +rankdeficientalgs = ( + QRFactorization(LinearAlgebra.ColumnNorm()), + KrylovJL_GMRES(), + SVDFactorization(), + KrylovJL_CRAIGMR(), + KrylovJL_LSMR(), + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization) +) + +@testset "Rank Deficient Success" begin + for alg in rankdeficientalgs + @show alg + A = [1.0 1.0; 1.0 1.0] + b = [-1.0, 1.0] + prob = LinearProblem(A, b) + linsolve = init(prob, alg) + sol = solve!(linsolve) + @test SciMLBase.successful_retcode(sol.retcode) + end +end \ No newline at end of file