Skip to content

Commit de15dd5

Browse files
committed
Optimize convolution large, aggregate checkpoints by labels
1 parent 45ee676 commit de15dd5

File tree

4 files changed

+112
-54
lines changed

4 files changed

+112
-54
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
main.test.cpp
2+
.competitive-verifier/*
3+
verify_files.json

cp-algo/math/fft.hpp

Lines changed: 99 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace cp_algo::math::fft {
1919
}
2020
static u64x4 mod, imod;
2121

22-
void init() {
22+
static void init() {
2323
if(!_init) {
2424
factor = 1 + random::rng() % (base::mod() - 1);
2525
ifactor = base(1) / factor;
@@ -40,16 +40,16 @@ namespace cp_algo::math::fft {
4040
};
4141
u64x4 step4 = u64x4{} + (bpow(factor, 4) * b2x32).getr();
4242
u64x4 stepn = u64x4{} + (bpow(factor, n) * b2x32).getr();
43-
for(size_t i = 0; i < std::min(n, size(a)); i += flen) {
43+
for(size_t i = 0; i < std::min(n, std::size(a)); i += flen) {
4444
auto splt = [&](size_t i, auto mul) {
45-
if(i >= size(a)) {
45+
if(i >= std::size(a)) {
4646
return std::pair{vftype(), vftype()};
4747
}
4848
u64x4 au = {
49-
i < size(a) ? a[i].getr() : 0,
50-
i + 1 < size(a) ? a[i + 1].getr() : 0,
51-
i + 2 < size(a) ? a[i + 2].getr() : 0,
52-
i + 3 < size(a) ? a[i + 3].getr() : 0
49+
i < std::size(a) ? a[i].getr() : 0,
50+
i + 1 < std::size(a) ? a[i + 1].getr() : 0,
51+
i + 2 < std::size(a) ? a[i + 2].getr() : 0,
52+
i + 3 < std::size(a) ? a[i + 3].getr() : 0
5353
};
5454
au = montgomery_mul(au, mul, mod, imod);
5555
au = au >= base::mod() ? au - base::mod() : au;
@@ -101,7 +101,8 @@ namespace cp_algo::math::fft {
101101
}
102102

103103
void recover_mod(auto &&C, auto &res, size_t k) {
104-
res.assign((k / flen + 1) * flen, base(0));
104+
size_t check = (k + flen - 1) / flen * flen;
105+
assert(res.size() >= check);
105106
size_t n = A.size();
106107
auto const splitsplit = base(split() * split()).getr();
107108
base b2x32 = bpow(base(2), 32);
@@ -134,7 +135,6 @@ namespace cp_algo::math::fft {
134135
}
135136
cur = montgomery_mul(cur, step4, mod, imod);
136137
}
137-
res.resize(k);
138138
checkpoint("recover mod");
139139
}
140140

@@ -158,12 +158,12 @@ namespace cp_algo::math::fft {
158158
mul(cvector(B.A), B.B, res, k);
159159
}
160160
std::vector<base, big_alloc<base>> operator *= (dft &B) {
161-
std::vector<base, big_alloc<base>> res;
161+
std::vector<base, big_alloc<base>> res(2 * A.size());
162162
mul_inplace(B, res, 2 * A.size());
163163
return res;
164164
}
165165
std::vector<base, big_alloc<base>> operator *= (dft const& B) {
166-
std::vector<base, big_alloc<base>> res;
166+
std::vector<base, big_alloc<base>> res(2 * A.size());
167167
mul(B, res, 2 * A.size());
168168
return res;
169169
}
@@ -180,11 +180,11 @@ namespace cp_algo::math::fft {
180180
template<modint_type base> u64x4 dft<base>::imod = {};
181181

182182
void mul_slow(auto &a, auto const& b, size_t k) {
183-
if(empty(a) || empty(b)) {
183+
if(std::empty(a) || std::empty(b)) {
184184
a.clear();
185185
} else {
186-
size_t n = std::min(k, size(a));
187-
size_t m = std::min(k, size(b));
186+
size_t n = std::min(k, std::size(a));
187+
size_t m = std::min(k, std::size(b));
188188
a.resize(k);
189189
for(int j = int(k - 1); j >= 0; j--) {
190190
a[j] *= b[0];
@@ -202,55 +202,103 @@ namespace cp_algo::math::fft {
202202
}
203203
void mul_truncate(auto &a, auto const& b, size_t k) {
204204
using base = std::decay_t<decltype(a[0])>;
205-
if(std::min({k, size(a), size(b)}) < magic) {
205+
if(std::min({k, std::size(a), std::size(b)}) < magic) {
206206
mul_slow(a, b, k);
207207
return;
208208
}
209209
auto n = std::max(flen, std::bit_ceil(
210-
std::min(k, size(a)) + std::min(k, size(b)) - 1
210+
std::min(k, std::size(a)) + std::min(k, std::size(b)) - 1
211211
) / 2);
212212
auto A = dft<base>(a | std::views::take(k), n);
213-
if(&a == &b) {
214-
A.mul(A, a, k);
215-
} else {
216-
A.mul_inplace(dft<base>(b | std::views::take(k), n), a, k);
213+
auto B = dft<base>(b | std::views::take(k), n);
214+
a.resize((k + flen - 1) / flen * flen);
215+
A.mul_inplace(B, a, k);
216+
a.resize(k);
217+
}
218+
219+
// In place, for a polynomial x of length 2n: store x mod (z^n - k) in the first half and x mod (z^n + k) in the second half
220+
void mod_split(auto &&x, size_t n, auto k) {
221+
using base = std::decay_t<decltype(k)>;
222+
dft<base>::init();
223+
assert(std::size(x) == 2 * n);
224+
u64x4 cur = u64x4{} + (k * bpow(base(2), 32)).getr();
225+
for(size_t i = 0; i < n; i += flen) {
226+
u64x4 xl = {
227+
x[i].getr(),
228+
x[i + 1].getr(),
229+
x[i + 2].getr(),
230+
x[i + 3].getr()
231+
};
232+
u64x4 xr = {
233+
x[n + i].getr(),
234+
x[n + i + 1].getr(),
235+
x[n + i + 2].getr(),
236+
x[n + i + 3].getr()
237+
};
238+
xr = montgomery_mul(xr, cur, dft<base>::mod, dft<base>::imod);
239+
xr = xr >= base::mod() ? xr - base::mod() : xr;
240+
auto t = xr;
241+
xr = xl - t;
242+
xl += t;
243+
xl = xl >= base::mod() ? xl - base::mod() : xl;
244+
xr = xr >= base::mod() ? xr + base::mod() : xr;
245+
for(size_t k = 0; k < flen; k++) {
246+
x[i + k].setr(typename base::UInt(xl[k]));
247+
x[n + i + k].setr(typename base::UInt(xr[k]));
248+
}
217249
}
250+
cp_algo::checkpoint("mod split");
218251
}
219-
void mul(auto &a, auto const& b) {
252+
void cyclic_mul(auto &a, auto &&b, size_t k) {
253+
assert(std::popcount(k) == 1);
254+
assert(std::size(a) == std::size(b) && std::size(a) == k);
255+
using base = std::decay_t<decltype(a[0])>;
256+
dft<base>::init();
257+
if(k <= (1 << 16)) {
258+
auto ap = std::ranges::to<std::vector<base, big_alloc<base>>>(a);
259+
mul_truncate(ap, b, 2 * k);
260+
mod_split(ap, k, bpow(dft<base>::factor, k));
261+
std::ranges::copy(ap | std::views::take(k), begin(a));
262+
return;
263+
}
264+
k /= 2;
265+
auto factor = bpow(dft<base>::factor, k);
266+
mod_split(a, k, factor);
267+
mod_split(b, k, factor);
268+
auto la = std::span(a).first(k);
269+
auto lb = std::span(b).first(k);
270+
auto ra = std::span(a).last(k);
271+
auto rb = std::span(b).last(k);
272+
cyclic_mul(la, lb, k);
273+
auto A = dft<base>(ra, k / 2);
274+
auto B = dft<base>(rb, k / 2);
275+
A.mul_inplace(B, ra, k);
276+
base i2 = base(2).inv();
277+
factor = factor.inv() * i2;
278+
for(size_t i = 0; i < k; i++) {
279+
auto t = (a[i] + a[i + k]) * i2;
280+
a[i + k] = (a[i] - a[i + k]) * factor;
281+
a[i] = t;
282+
}
283+
cp_algo::checkpoint("mod join");
284+
}
285+
void cyclic_mul(auto &a, auto const& b, size_t k) {
286+
return cyclic_mul(a, make_copy(b), k);
287+
}
288+
void mul(auto &a, auto &&b) {
220289
size_t N = size(a) + size(b) - 1;
221-
if(std::max(size(a), size(b)) > (1 << 23)) {
222-
using T = std::decay_t<decltype(a[0])>;
223-
// do karatsuba to save memory
224-
auto n = (std::max(size(a), size(b)) + 1) / 2;
225-
auto a0 = to<std::vector<T, big_alloc<T>>>(a | std::views::take(n));
226-
auto a1 = to<std::vector<T, big_alloc<T>>>(a | std::views::drop(n));
227-
auto b0 = to<std::vector<T, big_alloc<T>>>(b | std::views::take(n));
228-
auto b1 = to<std::vector<T, big_alloc<T>>>(b | std::views::drop(n));
229-
a0.resize(n); a1.resize(n);
230-
b0.resize(n); b1.resize(n);
231-
auto a01 = to<std::vector<T, big_alloc<T>>>(std::views::zip_transform(std::plus{}, a0, a1));
232-
auto b01 = to<std::vector<T, big_alloc<T>>>(std::views::zip_transform(std::plus{}, b0, b1));
233-
checkpoint("karatsuba split");
234-
mul(a0, b0);
235-
mul(a1, b1);
236-
mul(a01, b01);
237-
a.assign(4 * n, 0);
238-
for(auto [i, ai]: a0 | std::views::enumerate) {
239-
a[i] += ai;
240-
a[i + n] -= ai;
241-
}
242-
for(auto [i, ai]: a1 | std::views::enumerate) {
243-
a[i + n] -= ai;
244-
a[i + 2 * n] += ai;
245-
}
246-
for(auto [i, ai]: a01 | std::views::enumerate) {
247-
a[i + n] += ai;
248-
}
290+
if(N > (1 << 19)) {
291+
size_t NN = std::bit_ceil(N);
292+
a.resize(NN);
293+
b.resize(NN);
294+
cyclic_mul(a, b, NN);
249295
a.resize(N);
250-
checkpoint("karatsuba join");
251-
} else if(size(a)) {
296+
} else {
252297
mul_truncate(a, b, N);
253298
}
254299
}
300+
void mul(auto &a, auto const& b) {
301+
mul(a, make_copy(b));
302+
}
255303
}
256304
#endif // CP_ALGO_MATH_FFT_HPP

cp-algo/util/checkpoint.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,24 @@
33
#include <iostream>
44
#include <chrono>
55
#include <string>
6+
#include <map>
67
namespace cp_algo {
8+
// Accumulated time per checkpoint label, in seconds (printed as ms by checkpoint<true>). Declared inline because this header-defined global would otherwise be defined once per including translation unit.
inline std::map<std::string, double> checkpoints;
79
template<bool final = false>
810
void checkpoint([[maybe_unused]] std::string const& msg = "") {
911
#ifdef CP_ALGO_CHECKPOINT
1012
static double last = 0;
1113
double now = (double)clock() / CLOCKS_PER_SEC;
1214
double delta = now - last;
1315
last = now;
14-
if(msg.size()) {
15-
std::cerr << msg << ": " << (final ? now : delta) * 1000 << " ms\n";
16+
if(msg.size() && !final) {
17+
checkpoints[msg] += delta;
18+
}
19+
if(final) {
20+
for(auto const& [key, value] : checkpoints) {
21+
std::cerr << key << ": " << value * 1000 << " ms\n";
22+
}
23+
std::cerr << "Total: " << now * 1000 << " ms\n";
1624
}
1725
#endif
1826
}

cp-algo/util/complex.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace cp_algo {
3131
T const imag() const {return y;}
3232
T& real() {return x;}
3333
T& imag() {return y;}
34-
static constexpr complex polar(T r, T theta) {return {r * cos(theta), r * sin(theta)};}
34+
static constexpr complex polar(T r, T theta) {return {T(r * cos(theta)), T(r * sin(theta))};}
3535
auto operator <=> (complex const& t) const = default;
3636
};
3737
template<typename T>

0 commit comments

Comments
 (0)