-
-
Notifications
You must be signed in to change notification settings - Fork 61
Remove ipiv allocation from GenericLUFactorization #618
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2c004a2
7332ca1
d51e5d0
7e8c2fc
d315d13
62cdcbf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -1,3 +1,25 @@ | ||||||||
@generated function SciMLBase.solve!(cache::LinearCache, alg::AbstractFactorization; | ||||||||
kwargs...) | ||||||||
quote | ||||||||
if cache.isfresh | ||||||||
fact = do_factorization(alg, cache.A, cache.b, cache.u) | ||||||||
cache.cacheval = fact | ||||||||
|
||||||||
# If factorization was not successful, return failure. Don't reset `isfresh` | ||||||||
if _notsuccessful(fact) | ||||||||
return SciMLBase.build_linear_solution( | ||||||||
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure) | ||||||||
end | ||||||||
|
||||||||
cache.isfresh = false | ||||||||
end | ||||||||
|
||||||||
y = _ldiv!(cache.u, @get_cacheval(cache, $(Meta.quot(defaultalg_symbol(alg)))), | ||||||||
cache.b) | ||||||||
return SciMLBase.build_linear_solution(alg, y, nothing, cache; retcode = ReturnCode.Success) | ||||||||
end | ||||||||
end | ||||||||
|
||||||||
macro get_cacheval(cache, algsym) | ||||||||
quote | ||||||||
if $(esc(cache)).alg isa DefaultLinearSolver | ||||||||
|
@@ -8,6 +30,8 @@ macro get_cacheval(cache, algsym) | |||||||
end | ||||||||
end | ||||||||
|
||||||||
const PREALLOCATED_IPIV = Vector{LinearAlgebra.BlasInt}(undef, 0) | ||||||||
|
||||||||
_ldiv!(x, A, b) = ldiv!(x, A, b) | ||||||||
|
||||||||
_ldiv!(x, A, b::SVector) = (x .= A \ b) | ||||||||
|
@@ -41,8 +65,7 @@ function LinearSolve.init_cacheval( | |||||||
alg::RFLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr, | ||||||||
maxiters::Int, | ||||||||
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) | ||||||||
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0) | ||||||||
PREALLOCATED_LU, ipiv | ||||||||
PREALLOCATED_LU, PREALLOCATED_IPIV | ||||||||
end | ||||||||
|
||||||||
function LinearSolve.init_cacheval(alg::RFLUFactorization, | ||||||||
|
@@ -144,41 +167,85 @@ function do_factorization(alg::LUFactorization, A, b, u) | |||||||
return fact | ||||||||
end | ||||||||
|
||||||||
function do_factorization(alg::GenericLUFactorization, A, b, u) | ||||||||
function init_cacheval( | ||||||||
alg::GenericLUFactorization, A, b, u, Pl, Pr, | ||||||||
maxiters::Int, abstol, reltol, verbose::Bool, | ||||||||
assumptions::OperatorAssumptions) | ||||||||
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)) | ||||||||
ArrayInterface.lu_instance(convert(AbstractMatrix, A)), ipiv | ||||||||
end | ||||||||
|
||||||||
function init_cacheval( | ||||||||
alg::GenericLUFactorization, A::Matrix{Float64}, b, u, Pl, Pr, | ||||||||
maxiters::Int, abstol, reltol, verbose::Bool, | ||||||||
assumptions::OperatorAssumptions) | ||||||||
PREALLOCATED_LU, PREALLOCATED_IPIV | ||||||||
end | ||||||||
|
||||||||
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::GenericLUFactorization; | ||||||||
kwargs...) | ||||||||
A = cache.A | ||||||||
A = convert(AbstractMatrix, A) | ||||||||
fact = LinearAlgebra.generic_lufact!(A, alg.pivot, check = false) | ||||||||
return fact | ||||||||
fact, ipiv = LinearSolve.@get_cacheval(cache, :GenericLUFactorization) | ||||||||
|
||||||||
if cache.isfresh | ||||||||
if length(ipiv) != min(size(A)...) | ||||||||
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)) | ||||||||
end | ||||||||
fact = generic_lufact!(A, alg.pivot, ipiv; check = false) | ||||||||
cache.cacheval = (fact, ipiv) | ||||||||
|
||||||||
if !LinearAlgebra.issuccess(fact) | ||||||||
return SciMLBase.build_linear_solution( | ||||||||
alg, cache.u, nothing, cache; retcode = ReturnCode.Failure) | ||||||||
end | ||||||||
|
||||||||
cache.isfresh = false | ||||||||
end | ||||||||
y = ldiv!(cache.u, LinearSolve.@get_cacheval(cache, :GenericLUFactorization)[1], cache.b) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||
SciMLBase.build_linear_solution(alg, y, nothing, cache) | ||||||||
end | ||||||||
|
||||||||
function init_cacheval( | ||||||||
alg::Union{LUFactorization, GenericLUFactorization}, A, b, u, Pl, Pr, | ||||||||
alg::LUFactorization, A, b, u, Pl, Pr, | ||||||||
maxiters::Int, abstol, reltol, verbose::Bool, | ||||||||
assumptions::OperatorAssumptions) | ||||||||
ArrayInterface.lu_instance(convert(AbstractMatrix, A)) | ||||||||
end | ||||||||
|
||||||||
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, | ||||||||
function init_cacheval(alg::LUFactorization, | ||||||||
A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, | ||||||||
verbose::Bool, assumptions::OperatorAssumptions) | ||||||||
error_no_cudss_lu(A) | ||||||||
if alg isa LUFactorization | ||||||||
return lu(A; check = false) | ||||||||
else | ||||||||
A isa GPUArraysCore.AnyGPUArray && return nothing | ||||||||
return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check = false) | ||||||||
end | ||||||||
return lu(A; check = false) | ||||||||
end | ||||||||
|
||||||||
function init_cacheval(alg::GenericLUFactorization, | ||||||||
A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, | ||||||||
verbose::Bool, assumptions::OperatorAssumptions) | ||||||||
error_no_cudss_lu(A) | ||||||||
A isa GPUArraysCore.AnyGPUArray && return nothing | ||||||||
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0) | ||||||||
return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check = false), ipiv | ||||||||
end | ||||||||
|
||||||||
const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1)) | ||||||||
|
||||||||
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, | ||||||||
function init_cacheval(alg::LUFactorization, | ||||||||
A::Matrix{Float64}, b, u, Pl, Pr, | ||||||||
maxiters::Int, abstol, reltol, verbose::Bool, | ||||||||
assumptions::OperatorAssumptions) | ||||||||
PREALLOCATED_LU | ||||||||
end | ||||||||
|
||||||||
function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, | ||||||||
function init_cacheval(alg::LUFactorization, | ||||||||
A::AbstractSciMLOperator, b, u, Pl, Pr, | ||||||||
maxiters::Int, abstol, reltol, verbose::Bool, | ||||||||
assumptions::OperatorAssumptions) | ||||||||
nothing | ||||||||
end | ||||||||
|
||||||||
function init_cacheval(alg::GenericLUFactorization, | ||||||||
A::AbstractSciMLOperator, b, u, Pl, Pr, | ||||||||
maxiters::Int, abstol, reltol, verbose::Bool, | ||||||||
assumptions::OperatorAssumptions) | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,134 @@ | ||||||||||||||||||
# From LinearAlgebra.lu.jl | ||||||||||||||||||
# Modified to be non-allocating | ||||||||||||||||||
@static if VERSION < v"1.11" | ||||||||||||||||||
function generic_lufact!(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = LinearAlgebra.lupivottype(T), | ||||||||||||||||||
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)); | ||||||||||||||||||
check::Bool = true, allowsingular::Bool = false) where {T} | ||||||||||||||||||
Comment on lines
+4
to
+6
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
check && LinearAlgebra.LAPACK.chkfinite(A) | ||||||||||||||||||
# Extract values | ||||||||||||||||||
m, n = size(A) | ||||||||||||||||||
minmn = min(m,n) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
|
||||||||||||||||||
# Initialize variables | ||||||||||||||||||
info = 0 | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
@inbounds begin | ||||||||||||||||||
for k = 1:minmn | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# find index max | ||||||||||||||||||
kp = k | ||||||||||||||||||
if pivot === LinearAlgebra.RowMaximum() && k < m | ||||||||||||||||||
amax = abs(A[k, k]) | ||||||||||||||||||
for i = k+1:m | ||||||||||||||||||
absi = abs(A[i,k]) | ||||||||||||||||||
Comment on lines
+21
to
+22
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
if absi > amax | ||||||||||||||||||
kp = i | ||||||||||||||||||
amax = absi | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
elseif pivot === LinearAlgebra.RowNonZero() | ||||||||||||||||||
for i = k:m | ||||||||||||||||||
if !iszero(A[i,k]) | ||||||||||||||||||
Comment on lines
+29
to
+30
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
kp = i | ||||||||||||||||||
break | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
ipiv[k] = kp | ||||||||||||||||||
if !iszero(A[kp,k]) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
if k != kp | ||||||||||||||||||
# Interchange | ||||||||||||||||||
for i = 1:n | ||||||||||||||||||
tmp = A[k,i] | ||||||||||||||||||
A[k,i] = A[kp,i] | ||||||||||||||||||
A[kp,i] = tmp | ||||||||||||||||||
Comment on lines
+40
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
# Scale first column | ||||||||||||||||||
Akkinv = inv(A[k,k]) | ||||||||||||||||||
for i = k+1:m | ||||||||||||||||||
A[i,k] *= Akkinv | ||||||||||||||||||
Comment on lines
+47
to
+49
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
elseif info == 0 | ||||||||||||||||||
info = k | ||||||||||||||||||
end | ||||||||||||||||||
# Update the rest | ||||||||||||||||||
for j = k+1:n | ||||||||||||||||||
for i = k+1:m | ||||||||||||||||||
A[i,j] -= A[i,k]*A[k,j] | ||||||||||||||||||
Comment on lines
+55
to
+57
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
check && LinearAlgebra.checknonsingular(info, pivot) | ||||||||||||||||||
return LinearAlgebra.LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(LinearAlgebra.BlasInt, info)) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
elseif VERSION < v"1.13" | ||||||||||||||||||
function generic_lufact!(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = LinearAlgebra.lupivottype(T), | ||||||||||||||||||
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)); | ||||||||||||||||||
check::Bool = true, allowsingular::Bool = false) where {T} | ||||||||||||||||||
Comment on lines
+66
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
check && LAPACK.chkfinite(A) | ||||||||||||||||||
# Extract values | ||||||||||||||||||
m, n = size(A) | ||||||||||||||||||
minmn = min(m,n) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
|
||||||||||||||||||
# Initialize variables | ||||||||||||||||||
info = 0 | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
@inbounds begin | ||||||||||||||||||
for k = 1:minmn | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
# find index max | ||||||||||||||||||
kp = k | ||||||||||||||||||
if pivot === LinearAlgebra.RowMaximum() && k < m | ||||||||||||||||||
amax = abs(A[k, k]) | ||||||||||||||||||
for i = k+1:m | ||||||||||||||||||
absi = abs(A[i,k]) | ||||||||||||||||||
Comment on lines
+83
to
+84
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
if absi > amax | ||||||||||||||||||
kp = i | ||||||||||||||||||
amax = absi | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
elseif pivot === LinearAlgebra.RowNonZero() | ||||||||||||||||||
for i = k:m | ||||||||||||||||||
if !iszero(A[i,k]) | ||||||||||||||||||
Comment on lines
+91
to
+92
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
kp = i | ||||||||||||||||||
break | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
ipiv[k] = kp | ||||||||||||||||||
if !iszero(A[kp,k]) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
if k != kp | ||||||||||||||||||
# Interchange | ||||||||||||||||||
for i = 1:n | ||||||||||||||||||
tmp = A[k,i] | ||||||||||||||||||
A[k,i] = A[kp,i] | ||||||||||||||||||
A[kp,i] = tmp | ||||||||||||||||||
Comment on lines
+102
to
+105
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
# Scale first column | ||||||||||||||||||
Akkinv = inv(A[k,k]) | ||||||||||||||||||
for i = k+1:m | ||||||||||||||||||
A[i,k] *= Akkinv | ||||||||||||||||||
Comment on lines
+109
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
elseif info == 0 | ||||||||||||||||||
info = k | ||||||||||||||||||
end | ||||||||||||||||||
# Update the rest | ||||||||||||||||||
for j = k+1:n | ||||||||||||||||||
for i = k+1:m | ||||||||||||||||||
A[i,j] -= A[i,k]*A[k,j] | ||||||||||||||||||
Comment on lines
+117
to
+119
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
end | ||||||||||||||||||
if pivot === LinearAlgebra.NoPivot() | ||||||||||||||||||
# Use a negative value to distinguish a failed factorization (zero in pivot | ||||||||||||||||||
# position during unpivoted LU) from a valid but rank-deficient factorization | ||||||||||||||||||
info = -info | ||||||||||||||||||
end | ||||||||||||||||||
check && LinearAlgebra._check_lu_success(info, allowsingular) | ||||||||||||||||||
return LinearAlgebra.LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(LinearAlgebra.BlasInt, info)) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||
end | ||||||||||||||||||
else | ||||||||||||||||||
generic_lufact!(args...; kwargs...) = LinearAlgebra.generic_lufact!(args...; kwargs...) | ||||||||||||||||||
end | ||||||||||||||||||
Comment on lines
+132
to
+134
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶