Skip to content

Commit 7685996

Browse files
authored
Merge pull request #2393 from chriselrod/getcoeffchunk
Getcoeffchunk
2 parents 1b11b47 + 0b1092d commit 7685996

File tree

3 files changed

+118
-28
lines changed

3 files changed

+118
-28
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1818
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2020
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
21+
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
2122
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2223
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
2324
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -74,6 +75,7 @@ Distributions = "0.23, 0.24, 0.25"
7475
DocStringExtensions = "0.7, 0.8, 0.9"
7576
DomainSets = "0.6"
7677
DynamicQuantities = "^0.11.2"
78+
FindFirstFunctions = "1"
7779
ForwardDiff = "0.10.3"
7880
FunctionWrappersWrappers = "0.1"
7981
Graphs = "1.5.2"

src/systems/alias_elimination.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ the `constraint`.
153153
mask,
154154
constraint)
155155
eadj = M.row_cols
156-
for i in range
156+
@inbounds for i in range
157157
vertices = eadj[i]
158158
if constraint(length(vertices))
159159
for (j, v) in enumerate(vertices)
@@ -170,7 +170,7 @@ end
170170
range,
171171
mask,
172172
constraint)
173-
for i in range
173+
@inbounds for i in range
174174
row = @view M[i, :]
175175
if constraint(count(!iszero, row))
176176
for (v, val) in enumerate(row)
@@ -382,13 +382,6 @@ end
382382

383383
swap!(v, i, j) = v[i], v[j] = v[j], v[i]
384384

385-
function getcoeff(vars, coeffs, var)
386-
for (vj, v) in enumerate(vars)
387-
v == var && return coeffs[vj]
388-
end
389-
return 0
390-
end
391-
392385
"""
393386
$(SIGNATURES)
394387

src/systems/sparsematrixclil.jl

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ end
129129
# build something that works for us here and worry about it later.
130130
nonzerosmap(a::CLILVector) = NonZeros(a)
131131

132+
using FindFirstFunctions: findfirstequal
133+
132134
function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swapto, pivot,
133135
last_pivot; pivot_equal_optimization = true)
134136
# for ei in nzrows(>= k)
@@ -168,12 +170,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
168170
# conservative, we leave it at this, as this captures the most important
169171
# case for MTK (where most pivots are `1` or `-1`).
170172
pivot_equal = pivot_equal_optimization && abs(pivot) == abs(last_pivot)
171-
172-
for ei in (k + 1):size(M, 1)
173+
@inbounds for ei in (k + 1):size(M, 1)
173174
# eliminate `v`
174175
coeff = 0
175176
ivars = eadj[ei]
176-
vj = findfirst(isequal(vpivot), ivars)
177+
vj = findfirstequal(vpivot, ivars)
177178
if vj !== nothing
178179
coeff = old_cadj[ei][vj]
179180
deleteat!(old_cadj[ei], vj)
@@ -189,24 +190,118 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
189190
ivars = eadj[ei]
190191
icoeffs = old_cadj[ei]
191192

192-
tmp_incidence = similar(eadj[ei], 0)
193-
tmp_coeffs = similar(old_cadj[ei], 0)
194-
# TODO: We know both ivars and kvars are sorted, we could just write
195-
# a quick iterator here that does this without allocation/faster.
196-
vars = sort(union(ivars, kvars))
197-
198-
for v in vars
199-
v == vpivot && continue
200-
ck = getcoeff(kvars, kcoeffs, v)
201-
ci = getcoeff(ivars, icoeffs, v)
202-
p1 = Base.Checked.checked_mul(pivot, ci)
203-
p2 = Base.Checked.checked_mul(coeff, ck)
204-
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
205-
if !iszero(ci)
206-
push!(tmp_incidence, v)
207-
push!(tmp_coeffs, ci)
193+
numkvars = length(kvars)
194+
numivars = length(ivars)
195+
tmp_incidence = similar(eadj[ei], numkvars + numivars)
196+
tmp_coeffs = similar(old_cadj[ei], numkvars + numivars)
197+
tmp_len = 0
198+
kvind = ivind = 0
199+
if _debug_mode
200+
# in debug mode, we at least check to confirm we're iterating over
201+
# `v`s in the correct order
202+
vars = sort(union(ivars, kvars))
203+
vi = 0
204+
end
205+
if numivars > 0 && numkvars > 0
206+
kvv = kvars[kvind += 1]
207+
ivv = ivars[ivind += 1]
208+
dobreak = false
209+
while true
210+
if kvv == ivv
211+
v = kvv
212+
ck = kcoeffs[kvind]
213+
ci = icoeffs[ivind]
214+
kvind += 1
215+
ivind += 1
216+
if kvind > numkvars
217+
dobreak = true
218+
else
219+
kvv = kvars[kvind]
220+
end
221+
if ivind > numivars
222+
dobreak = true
223+
else
224+
ivv = ivars[ivind]
225+
end
226+
p1 = Base.Checked.checked_mul(pivot, ci)
227+
p2 = Base.Checked.checked_mul(coeff, ck)
228+
ci = exactdiv(Base.Checked.checked_sub(p1, p2), last_pivot)
229+
elseif kvv < ivv
230+
v = kvv
231+
ck = kcoeffs[kvind]
232+
kvind += 1
233+
if kvind > numkvars
234+
dobreak = true
235+
else
236+
kvv = kvars[kvind]
237+
end
238+
p2 = Base.Checked.checked_mul(coeff, ck)
239+
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
240+
else # kvv > ivv
241+
v = ivv
242+
ci = icoeffs[ivind]
243+
ivind += 1
244+
if ivind > numivars
245+
dobreak = true
246+
else
247+
ivv = ivars[ivind]
248+
end
249+
ci = exactdiv(Base.Checked.checked_mul(pivot, ci), last_pivot)
250+
end
251+
if _debug_mode
252+
@assert v == vars[vi += 1]
253+
end
254+
if v != vpivot && !iszero(ci)
255+
tmp_incidence[tmp_len += 1] = v
256+
tmp_coeffs[tmp_len] = ci
257+
end
258+
dobreak && break
259+
end
260+
elseif numkvars > 0
261+
ivind = 1
262+
kvv = kvars[kvind += 1]
263+
elseif numivars > 0
264+
kvind = 1
265+
ivv = ivars[ivind += 1]
266+
end
267+
if kvind <= numkvars
268+
v = kvv
269+
while true
270+
if _debug_mode
271+
@assert v == vars[vi += 1]
272+
end
273+
if v != vpivot
274+
ck = kcoeffs[kvind]
275+
p2 = Base.Checked.checked_mul(coeff, ck)
276+
ci = exactdiv(Base.Checked.checked_neg(p2), last_pivot)
277+
if !iszero(ci)
278+
tmp_incidence[tmp_len += 1] = v
279+
tmp_coeffs[tmp_len] = ci
280+
end
281+
end
282+
(kvind == numkvars) && break
283+
v = kvars[kvind += 1]
284+
end
285+
elseif ivind <= numivars
286+
v = ivv
287+
while true
288+
if _debug_mode
289+
@assert v == vars[vi += 1]
290+
end
291+
if v != vpivot
292+
p1 = Base.Checked.checked_mul(pivot, icoeffs[ivind])
293+
ci = exactdiv(p1, last_pivot)
294+
if !iszero(ci)
295+
tmp_incidence[tmp_len += 1] = v
296+
tmp_coeffs[tmp_len] = ci
297+
end
298+
end
299+
(ivind == numivars) && break
300+
v = ivars[ivind += 1]
208301
end
209302
end
303+
resize!(tmp_incidence, tmp_len)
304+
resize!(tmp_coeffs, tmp_len)
210305
eadj[ei] = tmp_incidence
211306
old_cadj[ei] = tmp_coeffs
212307
end

0 commit comments

Comments
 (0)