@@ -199,7 +199,6 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
199199std::vector<mint> convolution_fft (std::vector<mint> a, std::vector<mint> b) {
200200 int n = int (a.size ()), m = int (b.size ());
201201 int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
202- assert (mint::mod () % z == 1 );
203202 a.resize (z);
204203 internal::butterfly (a);
205204 b.resize (z);
@@ -220,6 +219,10 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
220219std::vector<mint> convolution (std::vector<mint>&& a, std::vector<mint>&& b) {
221220 int n = int (a.size ()), m = int (b.size ());
222221 if (!n || !m) return {};
222+
223+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
224+ assert (mint::mod () % z == 1 );
225+
223226 if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
224227 return internal::convolution_fft (a, b);
225228}
@@ -229,6 +232,10 @@ std::vector<mint> convolution(const std::vector<mint>& a,
229232 const std::vector<mint>& b) {
230233 int n = int (a.size ()), m = int (b.size ());
231234 if (!n || !m) return {};
235+
236+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
237+ assert (mint::mod () % z == 1 );
238+
232239 if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
233240 return internal::convolution_fft (a, b);
234241}
@@ -241,6 +248,10 @@ std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
241248 if (!n || !m) return {};
242249
243250 using mint = static_modint<mod>;
251+
252+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
253+ assert (mint::mod () % z == 1 );
254+
244255 std::vector<mint> a2 (n), b2 (m);
245256 for (int i = 0 ; i < n; i++) {
246257 a2[i] = mint (a[i]);
@@ -280,7 +291,7 @@ std::vector<long long> convolution_ll(const std::vector<long long>& a,
280291 static_assert (MOD1 % (1ull << MAX_AB_BIT) == 1 , " MOD1 isn't enough to support an array length of 2^24." );
281292 static_assert (MOD2 % (1ull << MAX_AB_BIT) == 1 , " MOD2 isn't enough to support an array length of 2^24." );
282293 static_assert (MOD3 % (1ull << MAX_AB_BIT) == 1 , " MOD3 isn't enough to support an array length of 2^24." );
283- assert (a. size () + b. size () - 1 <= (1ull << MAX_AB_BIT));
294+ assert (n + m - 1 <= (1 << MAX_AB_BIT));
284295
285296 auto c1 = convolution<MOD1>(a, b);
286297 auto c2 = convolution<MOD2>(a, b);
0 commit comments