@@ -19,7 +19,7 @@ namespace cp_algo::math::fft {
19
19
}
20
20
static u64x4 mod, imod;
21
21
22
- void init () {
22
+ static void init () {
23
23
if (!_init) {
24
24
factor = 1 + random ::rng () % (base::mod () - 1 );
25
25
ifactor = base (1 ) / factor;
@@ -40,16 +40,16 @@ namespace cp_algo::math::fft {
40
40
};
41
41
u64x4 step4 = u64x4{} + (bpow (factor, 4 ) * b2x32).getr ();
42
42
u64x4 stepn = u64x4{} + (bpow (factor, n) * b2x32).getr ();
43
- for (size_t i = 0 ; i < std::min (n, size (a)); i += flen) {
43
+ for (size_t i = 0 ; i < std::min (n, std:: size (a)); i += flen) {
44
44
auto splt = [&](size_t i, auto mul) {
45
- if (i >= size (a)) {
45
+ if (i >= std:: size (a)) {
46
46
return std::pair{vftype (), vftype ()};
47
47
}
48
48
u64x4 au = {
49
- i < size (a) ? a[i].getr () : 0 ,
50
- i + 1 < size (a) ? a[i + 1 ].getr () : 0 ,
51
- i + 2 < size (a) ? a[i + 2 ].getr () : 0 ,
52
- i + 3 < size (a) ? a[i + 3 ].getr () : 0
49
+ i < std:: size (a) ? a[i].getr () : 0 ,
50
+ i + 1 < std:: size (a) ? a[i + 1 ].getr () : 0 ,
51
+ i + 2 < std:: size (a) ? a[i + 2 ].getr () : 0 ,
52
+ i + 3 < std:: size (a) ? a[i + 3 ].getr () : 0
53
53
};
54
54
au = montgomery_mul (au, mul, mod, imod);
55
55
au = au >= base::mod () ? au - base::mod () : au;
@@ -101,7 +101,8 @@ namespace cp_algo::math::fft {
101
101
}
102
102
103
103
void recover_mod (auto &&C, auto &res, size_t k) {
104
- res.assign ((k / flen + 1 ) * flen, base (0 ));
104
+ size_t check = (k + flen - 1 ) / flen * flen;
105
+ assert (res.size () >= check);
105
106
size_t n = A.size ();
106
107
auto const splitsplit = base (split () * split ()).getr ();
107
108
base b2x32 = bpow (base (2 ), 32 );
@@ -134,7 +135,6 @@ namespace cp_algo::math::fft {
134
135
}
135
136
cur = montgomery_mul (cur, step4, mod, imod);
136
137
}
137
- res.resize (k);
138
138
checkpoint (" recover mod" );
139
139
}
140
140
@@ -158,12 +158,12 @@ namespace cp_algo::math::fft {
158
158
mul (cvector (B.A ), B.B , res, k);
159
159
}
160
160
// Multiplies this transform by B via mul_inplace and returns the result as a
// freshly allocated vector with 2 * A.size() entries. B is taken by non-const
// reference and handed to mul_inplace as-is.
std::vector<base, big_alloc<base>> operator *= (dft &B) {
    auto const out_len = 2 * A.size();
    // The output is sized up front: the callee writes into it and does not grow it here.
    std::vector<base, big_alloc<base>> product(out_len);
    mul_inplace(B, product, out_len);
    return product;
}
165
165
// Multiplies this transform by a read-only B via mul and returns the result as
// a freshly allocated vector with 2 * A.size() entries.
std::vector<base, big_alloc<base>> operator *= (dft const& B) {
    auto const out_len = 2 * A.size();
    // The output is sized up front: the callee writes into it and does not grow it here.
    std::vector<base, big_alloc<base>> product(out_len);
    mul(B, product, out_len);
    return product;
}
@@ -180,11 +180,11 @@ namespace cp_algo::math::fft {
180
180
template <modint_type base> u64x4 dft<base>::imod = {};
181
181
182
182
void mul_slow (auto &a, auto const & b, size_t k) {
183
- if (empty (a) || empty (b)) {
183
+ if (std:: empty (a) || std:: empty (b)) {
184
184
a.clear ();
185
185
} else {
186
- size_t n = std::min (k, size (a));
187
- size_t m = std::min (k, size (b));
186
+ size_t n = std::min (k, std:: size (a));
187
+ size_t m = std::min (k, std:: size (b));
188
188
a.resize (k);
189
189
for (int j = int (k - 1 ); j >= 0 ; j--) {
190
190
a[j] *= b[0 ];
@@ -202,55 +202,103 @@ namespace cp_algo::math::fft {
202
202
}
203
203
void mul_truncate (auto &a, auto const & b, size_t k) {
204
204
using base = std::decay_t <decltype (a[0 ])>;
205
- if (std::min ({k, size (a), size (b)}) < magic) {
205
+ if (std::min ({k, std:: size (a), std:: size (b)}) < magic) {
206
206
mul_slow (a, b, k);
207
207
return ;
208
208
}
209
209
auto n = std::max (flen, std::bit_ceil (
210
- std::min (k, size (a)) + std::min (k, size (b)) - 1
210
+ std::min (k, std:: size (a)) + std::min (k, std:: size (b)) - 1
211
211
) / 2 );
212
212
auto A = dft<base>(a | std::views::take (k), n);
213
- if (&a == &b) {
214
- A.mul (A, a, k);
215
- } else {
216
- A.mul_inplace (dft<base>(b | std::views::take (k), n), a, k);
213
+ auto B = dft<base>(b | std::views::take (k), n);
214
+ a.resize ((k + flen - 1 ) / flen * flen);
215
+ A.mul_inplace (B, a, k);
216
+ a.resize (k);
217
+ }
218
+
219
+ // store mod x^n-k in first half, x^n+k in second half
220
+ void mod_split (auto &&x, size_t n, auto k) {
221
+ using base = std::decay_t <decltype (k)>;
222
+ dft<base>::init ();
223
+ assert (std::size (x) == 2 * n);
224
+ u64x4 cur = u64x4{} + (k * bpow (base (2 ), 32 )).getr ();
225
+ for (size_t i = 0 ; i < n; i += flen) {
226
+ u64x4 xl = {
227
+ x[i].getr (),
228
+ x[i + 1 ].getr (),
229
+ x[i + 2 ].getr (),
230
+ x[i + 3 ].getr ()
231
+ };
232
+ u64x4 xr = {
233
+ x[n + i].getr (),
234
+ x[n + i + 1 ].getr (),
235
+ x[n + i + 2 ].getr (),
236
+ x[n + i + 3 ].getr ()
237
+ };
238
+ xr = montgomery_mul (xr, cur, dft<base>::mod, dft<base>::imod);
239
+ xr = xr >= base::mod () ? xr - base::mod () : xr;
240
+ auto t = xr;
241
+ xr = xl - t;
242
+ xl += t;
243
+ xl = xl >= base::mod () ? xl - base::mod () : xl;
244
+ xr = xr >= base::mod () ? xr + base::mod () : xr;
245
+ for (size_t k = 0 ; k < flen; k++) {
246
+ x[i + k].setr (typename base::UInt (xl[k]));
247
+ x[n + i + k].setr (typename base::UInt (xr[k]));
248
+ }
217
249
}
250
+ cp_algo::checkpoint (" mod split" );
218
251
}
219
- void mul (auto &a, auto const & b) {
252
+ void cyclic_mul (auto &a, auto &&b, size_t k) {
253
+ assert (std::popcount (k) == 1 );
254
+ assert (std::size (a) == std::size (b) && std::size (a) == k);
255
+ using base = std::decay_t <decltype (a[0 ])>;
256
+ dft<base>::init ();
257
+ if (k <= (1 << 16 )) {
258
+ auto ap = std::ranges::to<std::vector<base, big_alloc<base>>>(a);
259
+ mul_truncate (ap, b, 2 * k);
260
+ mod_split (ap, k, bpow (dft<base>::factor, k));
261
+ std::ranges::copy (ap | std::views::take (k), begin (a));
262
+ return ;
263
+ }
264
+ k /= 2 ;
265
+ auto factor = bpow (dft<base>::factor, k);
266
+ mod_split (a, k, factor);
267
+ mod_split (b, k, factor);
268
+ auto la = std::span (a).first (k);
269
+ auto lb = std::span (b).first (k);
270
+ auto ra = std::span (a).last (k);
271
+ auto rb = std::span (b).last (k);
272
+ cyclic_mul (la, lb, k);
273
+ auto A = dft<base>(ra, k / 2 );
274
+ auto B = dft<base>(rb, k / 2 );
275
+ A.mul_inplace (B, ra, k);
276
+ base i2 = base (2 ).inv ();
277
+ factor = factor.inv () * i2;
278
+ for (size_t i = 0 ; i < k; i++) {
279
+ auto t = (a[i] + a[i + k]) * i2;
280
+ a[i + k] = (a[i] - a[i + k]) * factor;
281
+ a[i] = t;
282
+ }
283
+ cp_algo::checkpoint (" mod join" );
284
+ }
285
// Overload for a read-only b: forwards the result of make_copy(b) to the
// overload above, which needs an operand it may overwrite.
void cyclic_mul(auto &a, auto const& b, size_t k) {
    cyclic_mul(a, make_copy(b), k);
}
288
+ void mul (auto &a, auto &&b) {
220
289
size_t N = size (a) + size (b) - 1 ;
221
- if (std::max (size (a), size (b)) > (1 << 23 )) {
222
- using T = std::decay_t <decltype (a[0 ])>;
223
- // do karatsuba to save memory
224
- auto n = (std::max (size (a), size (b)) + 1 ) / 2 ;
225
- auto a0 = to<std::vector<T, big_alloc<T>>>(a | std::views::take (n));
226
- auto a1 = to<std::vector<T, big_alloc<T>>>(a | std::views::drop (n));
227
- auto b0 = to<std::vector<T, big_alloc<T>>>(b | std::views::take (n));
228
- auto b1 = to<std::vector<T, big_alloc<T>>>(b | std::views::drop (n));
229
- a0.resize (n); a1.resize (n);
230
- b0.resize (n); b1.resize (n);
231
- auto a01 = to<std::vector<T, big_alloc<T>>>(std::views::zip_transform (std::plus{}, a0, a1));
232
- auto b01 = to<std::vector<T, big_alloc<T>>>(std::views::zip_transform (std::plus{}, b0, b1));
233
- checkpoint (" karatsuba split" );
234
- mul (a0, b0);
235
- mul (a1, b1);
236
- mul (a01, b01);
237
- a.assign (4 * n, 0 );
238
- for (auto [i, ai]: a0 | std::views::enumerate) {
239
- a[i] += ai;
240
- a[i + n] -= ai;
241
- }
242
- for (auto [i, ai]: a1 | std::views::enumerate) {
243
- a[i + n] -= ai;
244
- a[i + 2 * n] += ai;
245
- }
246
- for (auto [i, ai]: a01 | std::views::enumerate) {
247
- a[i + n] += ai;
248
- }
290
+ if (N > (1 << 19 )) {
291
+ size_t NN = std::bit_ceil (N);
292
+ a.resize (NN);
293
+ b.resize (NN);
294
+ cyclic_mul (a, b, NN);
249
295
a.resize (N);
250
- checkpoint (" karatsuba join" );
251
- } else if (size (a)) {
296
+ } else {
252
297
mul_truncate (a, b, N);
253
298
}
254
299
}
300
// Overload for a read-only b: multiplies a by the result of make_copy(b),
// since the overload taking auto&& may resize and overwrite its second operand.
void mul(auto &a, auto const& b) {
    auto &&scratch = make_copy(b);
    mul(a, std::forward<decltype(scratch)>(scratch));
}
255
303
}
256
304
#endif // CP_ALGO_MATH_FFT_HPP
0 commit comments