diff --git a/ext/LinearSolveSparseArraysExt.jl b/ext/LinearSolveSparseArraysExt.jl index 9e1a632e0..d04e8117f 100644 --- a/ext/LinearSolveSparseArraysExt.jl +++ b/ext/LinearSolveSparseArraysExt.jl @@ -64,7 +64,15 @@ const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0 Int[], Float64[])) function LinearSolve.init_cacheval( - alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{<:Number, <:Integer}, b, u, + alg::LUFactorization, A::AbstractSparseArray{<:Number, <:Integer}, b, u, + Pl, Pr, + maxiters::Int, abstol, reltol, + verbose::Bool, assumptions::OperatorAssumptions) + nothing +end + +function LinearSolve.init_cacheval( + alg::GenericLUFactorization, A::AbstractSparseArray{<:Number, <:Integer}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) @@ -80,7 +88,7 @@ function LinearSolve.init_cacheval( end function LinearSolve.init_cacheval( - alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{Float64, Int64}, b, u, + alg::LUFactorization, A::AbstractSparseArray{Float64, Int64}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) @@ -88,7 +96,7 @@ function LinearSolve.init_cacheval( end function LinearSolve.init_cacheval( - alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{T, Int64}, b, u, + alg::LUFactorization, A::AbstractSparseArray{T, Int64}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES} @@ -96,7 +104,7 @@ function LinearSolve.init_cacheval( end function LinearSolve.init_cacheval( - alg::Union{LUFactorization, GenericLUFactorization}, A::AbstractSparseArray{T, Int32}, b, u, + alg::LUFactorization, A::AbstractSparseArray{T, Int32}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) where {T<:BLASELTYPES} diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 63a58a0c3..c7680903b 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -140,6 +140,7 @@ end const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64} +include("generic_lufact.jl") include("common.jl") include("extension_algs.jl") include("factorization.jl") @@ -171,28 +172,6 @@ end @inline _notsuccessful(F) = hasmethod(LinearAlgebra.issuccess, (typeof(F),)) ? !LinearAlgebra.issuccess(F) : false -@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization; - kwargs...) - quote - if cache.isfresh - fact = do_factorization(alg, cache.A, cache.b, cache.u) - cache.cacheval = fact - - # If factorization was not successful, return failure. Don't reset `isfresh` - if _notsuccessful(fact) - return SciMLBase.build_linear_solution( - alg, cache.u, nothing, cache; retcode = ReturnCode.Failure) - end - - cache.isfresh = false - end - - y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))), - cache.b) - return SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success) - end -end - # Solver Specific Traits ## Needs Square Matrix """ diff --git a/src/factorization.jl b/src/factorization.jl index 84ac5a41d..8d03a5d2c 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -1,3 +1,25 @@ +@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization; + kwargs...) + quote + if cache.isfresh + fact = do_factorization(alg, cache.A, cache.b, cache.u) + cache.cacheval = fact + + # If factorization was not successful, return failure. Don't reset `isfresh` + if _notsuccessful(fact) + return SciMLBase.build_linear_solution( + alg, cache.u, nothing, cache; retcode = ReturnCode.Failure) + end + + cache.isfresh = false + end + + y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))), + cache.b) + return SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success) + end +end + macro get_cacheval(cache, algsym) quote if $(esc(cache)).alg isa DefaultLinearSolver @@ -8,6 +30,8 @@ macro get_cacheval(cache, algsym) end end +const PREALLOCATED_IPIV = Vector{LinearAlgebra.BlasInt}(undef, 0) + _ldiv!(x, A, b) = ldiv!(x, A, b) _ldiv!(x, A, b::SVector) = (x .= A \ b) @@ -41,8 +65,7 @@ function LinearSolve.init_cacheval( alg::RFLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0) - PREALLOCATED_LU, ipiv + PREALLOCATED_LU, PREALLOCATED_IPIV end function LinearSolve.init_cacheval(alg::RFLUFactorization, @@ -144,41 +167,85 @@ function do_factorization(alg::LUFactorization, A, b, u) return fact end -function do_factorization(alg::GenericLUFactorization, A, b, u) +function init_cacheval( + alg::GenericLUFactorization, A, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)) + ArrayInterface.lu_instance(convert(AbstractMatrix, A)), ipiv +end + +function init_cacheval( + alg::GenericLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + PREALLOCATED_LU, PREALLOCATED_IPIV +end + +function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::GenericLUFactorization; + kwargs...) + A = cache.A A = convert(AbstractMatrix, A) - fact = LinearAlgebra.generic_lufact!(A, alg.pivot, check = false) - return fact + fact, ipiv = LinearSolve.@get_cacheval(cache, :GenericLUFactorization) + + if cache.isfresh + if length(ipiv) != min(size(A)...) + ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)) + end + fact = generic_lufact!(A, alg.pivot, ipiv; check = false) + cache.cacheval = (fact, ipiv) + + if !LinearAlgebra.issuccess(fact) + return SciMLBase.build_linear_solution( + alg, cache.u, nothing, cache; retcode = ReturnCode.Failure) + end + + cache.isfresh = false + end + y = ldiv!(cache.u, LinearSolve.@get_cacheval(cache, :GenericLUFactorization)[1], cache.b) + SciMLBase.build_linear_solution(alg, y, nothing, cache) end function init_cacheval( - alg::Union{LUFactorization, GenericLUFactorization}, A, b, u, Pl, Pr, + alg::LUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) ArrayInterface.lu_instance(convert(AbstractMatrix, A)) end -function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, +function init_cacheval(alg::LUFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) error_no_cudss_lu(A) - if alg isa LUFactorization - return lu(A; check = false) - else - A isa GPUArraysCore.AnyGPUArray && return nothing - return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check = false) - end + return lu(A; check = false) +end + +function init_cacheval(alg::GenericLUFactorization, + A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, + verbose::Bool, assumptions::OperatorAssumptions) + error_no_cudss_lu(A) + A isa GPUArraysCore.AnyGPUArray && return nothing + ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0) + return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check = false), ipiv end const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1)) -function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, +function init_cacheval(alg::LUFactorization, A::Matrix{Float64}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) PREALLOCATED_LU end -function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, +function init_cacheval(alg::LUFactorization, + A::AbstractSciMLOperator, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + nothing +end + +function init_cacheval(alg::GenericLUFactorization, A::AbstractSciMLOperator, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) diff --git a/src/generic_lufact.jl b/src/generic_lufact.jl new file mode 100644 index 000000000..36478376e --- /dev/null +++ b/src/generic_lufact.jl @@ -0,0 +1,134 @@ +# From LinearAlgebra.lu.jl +# Modified to be non-allocating +@static if VERSION < v"1.11" + function generic_lufact!(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = LinearAlgebra.lupivottype(T), + ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)); + check::Bool = true, allowsingular::Bool = false) where {T} + check && LinearAlgebra.LAPACK.chkfinite(A) + # Extract values + m, n = size(A) + minmn = min(m,n) + + # Initialize variables + info = 0 + + @inbounds begin + for k = 1:minmn + # find index max + kp = k + if pivot === LinearAlgebra.RowMaximum() && k < m + amax = abs(A[k, k]) + for i = k+1:m + absi = abs(A[i,k]) + if absi > amax + kp = i + amax = absi + end + end + elseif pivot === LinearAlgebra.RowNonZero() + for i = k:m + if !iszero(A[i,k]) + kp = i + break + end + end + end + ipiv[k] = kp + if !iszero(A[kp,k]) + if k != kp + # Interchange + for i = 1:n + tmp = A[k,i] + A[k,i] = A[kp,i] + A[kp,i] = tmp + end + end + # Scale first column + Akkinv = inv(A[k,k]) + for i = k+1:m + A[i,k] *= Akkinv + end + elseif info == 0 + info = k + end + # Update the rest + for j = k+1:n + for i = k+1:m + A[i,j] -= A[i,k]*A[k,j] + end + end + end + end + check && LinearAlgebra.checknonsingular(info, pivot) + return LinearAlgebra.LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(LinearAlgebra.BlasInt, info)) + end +elseif VERSION < v"1.13" + function generic_lufact!(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = LinearAlgebra.lupivottype(T), + ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)); + check::Bool = true, allowsingular::Bool = false) where {T} + check && LAPACK.chkfinite(A) + # Extract values + m, n = size(A) + minmn = min(m,n) + + # Initialize variables + info = 0 + + @inbounds begin + for k = 1:minmn + # find index max + kp = k + if pivot === LinearAlgebra.RowMaximum() && k < m + amax = abs(A[k, k]) + for i = k+1:m + absi = abs(A[i,k]) + if absi > amax + kp = i + amax = absi + end + end + elseif pivot === LinearAlgebra.RowNonZero() + for i = k:m + if !iszero(A[i,k]) + kp = i + break + end + end + end + ipiv[k] = kp + if !iszero(A[kp,k]) + if k != kp + # Interchange + for i = 1:n + tmp = A[k,i] + A[k,i] = A[kp,i] + A[kp,i] = tmp + end + end + # Scale first column + Akkinv = inv(A[k,k]) + for i = k+1:m + A[i,k] *= Akkinv + end + elseif info == 0 + info = k + end + # Update the rest + for j = k+1:n + for i = k+1:m + A[i,j] -= A[i,k]*A[k,j] + end + end + end + end + if pivot === LinearAlgebra.NoPivot() + # Use a negative value to distinguish a failed factorization (zero in pivot + # position during unpivoted LU) from a valid but rank-deficient factorization + info = -info + end + check && LinearAlgebra._check_lu_success(info, allowsingular) + return LinearAlgebra.LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(LinearAlgebra.BlasInt, info)) + end +else + generic_lufact!(args...; kwargs...) = LinearAlgebra.generic_lufact!(args...; kwargs...) +end \ No newline at end of file