129
129
# build something that works for us here and worry about it later.
130
130
nonzerosmap (a:: CLILVector ) = NonZeros (a)
131
131
132
+ using FindFirstFunctions: findfirstequal
133
+
132
134
function bareiss_update_virtual_colswap_mtk! (zero!, M:: SparseMatrixCLIL , k, swapto, pivot,
133
135
last_pivot; pivot_equal_optimization = true )
134
136
# for ei in nzrows(>= k)
@@ -168,12 +170,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
168
170
# conservative, we leave it at this, as this captures the most important
169
171
# case for MTK (where most pivots are `1` or `-1`).
170
172
pivot_equal = pivot_equal_optimization && abs (pivot) == abs (last_pivot)
171
-
172
- for ei in (k + 1 ): size (M, 1 )
173
+ @inbounds for ei in (k + 1 ): size (M, 1 )
173
174
# eliminate `v`
174
175
coeff = 0
175
176
ivars = eadj[ei]
176
- vj = findfirst ( isequal ( vpivot) , ivars)
177
+ vj = findfirstequal ( vpivot, ivars)
177
178
if vj != = nothing
178
179
coeff = old_cadj[ei][vj]
179
180
deleteat! (old_cadj[ei], vj)
@@ -189,24 +190,118 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
189
190
ivars = eadj[ei]
190
191
icoeffs = old_cadj[ei]
191
192
192
- tmp_incidence = similar (eadj[ei], 0 )
193
- tmp_coeffs = similar (old_cadj[ei], 0 )
194
- # TODO : We know both ivars and kvars are sorted, we could just write
195
- # a quick iterator here that does this without allocation/faster.
196
- vars = sort (union (ivars, kvars))
197
-
198
- for v in vars
199
- v == vpivot && continue
200
- ck = getcoeff (kvars, kcoeffs, v)
201
- ci = getcoeff (ivars, icoeffs, v)
202
- p1 = Base. Checked. checked_mul (pivot, ci)
203
- p2 = Base. Checked. checked_mul (coeff, ck)
204
- ci = exactdiv (Base. Checked. checked_sub (p1, p2), last_pivot)
205
- if ! iszero (ci)
206
- push! (tmp_incidence, v)
207
- push! (tmp_coeffs, ci)
193
+ numkvars = length (kvars)
194
+ numivars = length (ivars)
195
+ tmp_incidence = similar (eadj[ei], numkvars + numivars)
196
+ tmp_coeffs = similar (old_cadj[ei], numkvars + numivars)
197
+ tmp_len = 0
198
+ kvind = ivind = 0
199
+ if _debug_mode
200
+ # in debug mode, we at least check to confirm we're iterating over
201
+ # `v`s in the correct order
202
+ vars = sort (union (ivars, kvars))
203
+ vi = 0
204
+ end
205
+ if numivars > 0 && numkvars > 0
206
+ kvv = kvars[kvind += 1 ]
207
+ ivv = ivars[ivind += 1 ]
208
+ dobreak = false
209
+ while true
210
+ if kvv == ivv
211
+ v = kvv
212
+ ck = kcoeffs[kvind]
213
+ ci = icoeffs[ivind]
214
+ kvind += 1
215
+ ivind += 1
216
+ if kvind > numkvars
217
+ dobreak = true
218
+ else
219
+ kvv = kvars[kvind]
220
+ end
221
+ if ivind > numivars
222
+ dobreak = true
223
+ else
224
+ ivv = ivars[ivind]
225
+ end
226
+ p1 = Base. Checked. checked_mul (pivot, ci)
227
+ p2 = Base. Checked. checked_mul (coeff, ck)
228
+ ci = exactdiv (Base. Checked. checked_sub (p1, p2), last_pivot)
229
+ elseif kvv < ivv
230
+ v = kvv
231
+ ck = kcoeffs[kvind]
232
+ kvind += 1
233
+ if kvind > numkvars
234
+ dobreak = true
235
+ else
236
+ kvv = kvars[kvind]
237
+ end
238
+ p2 = Base. Checked. checked_mul (coeff, ck)
239
+ ci = exactdiv (Base. Checked. checked_neg (p2), last_pivot)
240
+ else # kvv > ivv
241
+ v = ivv
242
+ ci = icoeffs[ivind]
243
+ ivind += 1
244
+ if ivind > numivars
245
+ dobreak = true
246
+ else
247
+ ivv = ivars[ivind]
248
+ end
249
+ ci = exactdiv (Base. Checked. checked_mul (pivot, ci), last_pivot)
250
+ end
251
+ if _debug_mode
252
+ @assert v == vars[vi += 1 ]
253
+ end
254
+ if v != vpivot && ! iszero (ci)
255
+ tmp_incidence[tmp_len += 1 ] = v
256
+ tmp_coeffs[tmp_len] = ci
257
+ end
258
+ dobreak && break
259
+ end
260
+ elseif numkvars > 0
261
+ ivind = 1
262
+ kvv = kvars[kvind += 1 ]
263
+ elseif numivars > 0
264
+ kvind = 1
265
+ ivv = ivars[ivind += 1 ]
266
+ end
267
+ if kvind <= numkvars
268
+ v = kvv
269
+ while true
270
+ if _debug_mode
271
+ @assert v == vars[vi += 1 ]
272
+ end
273
+ if v != vpivot
274
+ ck = kcoeffs[kvind]
275
+ p2 = Base. Checked. checked_mul (coeff, ck)
276
+ ci = exactdiv (Base. Checked. checked_neg (p2), last_pivot)
277
+ if ! iszero (ci)
278
+ tmp_incidence[tmp_len += 1 ] = v
279
+ tmp_coeffs[tmp_len] = ci
280
+ end
281
+ end
282
+ (kvind == numkvars) && break
283
+ v = kvars[kvind += 1 ]
284
+ end
285
+ elseif ivind <= numivars
286
+ v = ivv
287
+ while true
288
+ if _debug_mode
289
+ @assert v == vars[vi += 1 ]
290
+ end
291
+ if v != vpivot
292
+ p1 = Base. Checked. checked_mul (pivot, icoeffs[ivind])
293
+ ci = exactdiv (p1, last_pivot)
294
+ if ! iszero (ci)
295
+ tmp_incidence[tmp_len += 1 ] = v
296
+ tmp_coeffs[tmp_len] = ci
297
+ end
298
+ end
299
+ (ivind == numivars) && break
300
+ v = ivars[ivind += 1 ]
208
301
end
209
302
end
303
+ resize! (tmp_incidence, tmp_len)
304
+ resize! (tmp_coeffs, tmp_len)
210
305
eadj[ei] = tmp_incidence
211
306
old_cadj[ei] = tmp_coeffs
212
307
end
0 commit comments