diff --git a/linear_algebra_matrix/linalg_longlong.hpp b/linear_algebra_matrix/linalg_longlong.hpp index 5cf8d7f8..32f30748 100644 --- a/linear_algebra_matrix/linalg_longlong.hpp +++ b/linear_algebra_matrix/linalg_longlong.hpp @@ -4,7 +4,6 @@ #include #include -// CUT begin template std::vector> gauss_jordan(std::vector> mtr, mdint mod) { // Gauss-Jordan elimination 行基本変形のみを用いるガウス消去法 @@ -28,7 +27,7 @@ std::vector> gauss_jordan(std::vector> mtr, mtr[piv][w] ? mod - mtr[piv][w] : 0; // To preserve sign of determinant } } - lint pivinv = mod_inverse(mtr[h][c], mod); + lint pivinv = inv_mod(mtr[h][c], mod); for (int hh = 0; hh < H; hh++) { if (hh == h) continue; lint coeff = mtr[hh][c] * pivinv % mod; diff --git a/number/bare_mod_algebra.hpp b/number/bare_mod_algebra.hpp index 774a44d4..727dfbfe 100644 --- a/number/bare_mod_algebra.hpp +++ b/number/bare_mod_algebra.hpp @@ -1,11 +1,11 @@ #pragma once #include #include +#include #include #include #include -// CUT begin // Solve ax+by=gcd(a, b) template Int extgcd(Int a, Int b, Int &x, Int &y) { Int d = a; @@ -16,9 +16,10 @@ template Int extgcd(Int a, Int b, Int &x, Int &y) { } return d; } -// Calculate a^(-1) (MOD m) s if gcd(a, m) == 1 -// Calculate x s.t. ax == gcd(a, m) MOD m -template Int mod_inverse(Int a, Int m) { + +// Calculate a^(-1) (MOD m) if gcd(a, m) == 1 +// Calculate x s.t. ax == gcd(a, m) MOD m and 0 <= x < m +template Int inv_mod(Int a, Int m) { Int x, y; extgcd(a, m, x, y); x %= m; @@ -76,37 +77,25 @@ template // 蟻本 P.262 // 中国剰余定理を利用して,色々な素数で割った余りから元の値を復元 -// 連立線形合同式 A * x = B mod M の解 +// 連立線形合同式 A_i x = R_i mod M_i の解 // Requirement: M[i] > 0 // Output: x = first MOD second (if solution exists), (0, 0) (otherwise) template std::pair -linear_congruence(const std::vector &A, const std::vector &B, const std::vector &M) { +linear_congruence(const std::vector &A, const std::vector &R, const std::vector &M) { Int r = 0, m = 1; assert(A.size() == M.size()); - assert(B.size() == M.size()); + assert(R.size() == M.size()); for (int i = 0; i < (int)A.size(); i++) { assert(M[i] > 0); const Int ai = A[i] % M[i]; - Int a = ai * m, b = B[i] - ai * r, d = std::__gcd(M[i], a); + Int a = ai * m, b = R[i] - ai * r, d = std::gcd(M[i], a); if (b % d != 0) { return std::make_pair(0, 0); // 解なし } - Int t = b / d * mod_inverse(a / d, M[i] / d) % (M[i] / d); + Int t = b / d * inv_mod(a / d, M[i] / d) % (M[i] / d); r += m * t; m *= M[i] / d; } return std::make_pair((r < 0 ? r + m : r), m); } - -template Int pow_mod(Int x, long long n, Int md) { - static_assert(sizeof(Int) * 2 <= sizeof(Long), "Watch out for overflow"); - if (md == 1) return 0; - Int ans = 1; - while (n > 0) { - if (n & 1) ans = (Long)ans * x % md; - x = (Long)x * x % md; - n >>= 1; - } - return ans; -} diff --git a/number/bare_mod_algebra.md b/number/bare_mod_algebra.md new file mode 100644 index 00000000..19f5dfd4 --- /dev/null +++ b/number/bare_mod_algebra.md @@ -0,0 +1,47 @@ +--- +title: Modular arithmetic utilities (C++ の基本型整数に対する拡張 GCD・中国剰余定理・連立線形合同式などの実装) +documentation_of: ./bare_mod_algebra.hpp +--- + +modint を使わない `int` などの整数型上での剰余演算ユーティリティ集.拡張ユークリッドの互除法,モジュラ逆元,中国剰余定理 (CRT),連立線形合同式の求解などを提供する. + +## 使用方法 + +### `extgcd(a, b, x, y)` + +$ax + by = \gcd(a, b)$ を満たす整数 $x, y$ を求め,$\gcd(a, b)$ を返す. + +```cpp +long long x, y; +long long g = extgcd(12LL, 8LL, x, y); // g = 4 +``` + +### `inv_mod(a, m)` + +$a$ の $m$ を法とする逆元を返す.$\gcd(a, m) = 1$ であることが必要.返り値は $[0, m)$ の範囲. + +```cpp +long long inv = inv_mod(3LL, 7LL); // inv = 5 (3*5 = 15 ≡ 1 mod 7) +``` + +### `inv_gcd(a, b)` + +$g = \gcd(a, b)$ および $xa \equiv g \pmod{b}$,$0 \le x < b/g$ を満たす $(g, x)$ を返す.$b \ge 1$ が必要. + +### `crt(r, m)` + +中国剰余定理.$x \equiv r_i \pmod{m_i}$ $(i = 0, \ldots, n-1)$ を同時に満たす $x$ を求める.解が存在すれば $(x \bmod \mathrm{lcm}, \mathrm{lcm})$ を,存在しなければ $(0, 0)$ を返す. + +```cpp +vector r = {2, 3}, m = {3, 5}; +auto [x, lcm] = crt(r, m); // x = 8, lcm = 15 +``` + +### `linear_congruence(A, R, M)` + +連立線形合同式 $A_i x \equiv R_i \pmod{M_i}$ の解を求める.解が存在すれば $(x \bmod m, m)$ を,存在しなければ $(0, 0)$ を返す. + +```cpp +vector A = {1, 1}, R = {2, 3}, M = {3, 5}; +auto [x, m] = linear_congruence(A, R, M); // x ≡ 8 mod 15 +``` diff --git a/number/factorize.hpp b/number/factorize.hpp index 53a498fe..ee2f0218 100644 --- a/number/factorize.hpp +++ b/number/factorize.hpp @@ -1,8 +1,7 @@ #pragma once -#include "../random/xorshift.hpp" #include -#include #include +#include #include #include @@ -28,7 +27,7 @@ inline int get_id(long long n) { // Miller-Rabin primality test // https://ja.wikipedia.org/wiki/%E3%83%9F%E3%83%A9%E3%83%BC%E2%80%93%E3%83%A9%E3%83%93%E3%83%B3%E7%B4%A0%E6%95%B0%E5%88%A4%E5%AE%9A%E6%B3%95 // Complexity: O(lg n) per query -struct { +inline struct { long long modpow(__int128 x, __int128 n, long long mod) noexcept { __int128 ret = 1; for (x %= mod; n; x = x * x % mod, n >>= 1) ret = (n & 1) ? ret * x % mod : ret; @@ -53,7 +52,19 @@ struct { } } is_prime; -struct { +inline struct { + static uint32_t xorshift() noexcept { + static uint32_t x = 123456789; + static uint32_t y = 362436069; + static uint32_t z = 521288629; + static uint32_t w = 88675123; + uint32_t t = x ^ (x << 11); + x = y; + y = z; + z = w; + return w = (w ^ (w >> 19)) ^ (t ^ (t >> 8)); + } + // Pollard's rho algorithm: find factor greater than 1 long long find_factor(long long n) { assert(n > 1); @@ -63,7 +74,7 @@ struct { auto f = [&](__int128 x) -> long long { return (x * x + c) % n; }; for (int t = 1;; t++) { - for (c = 0; c == 0 or c + 2 == n;) c = rand_int() % n; + for (c = 0; c == 0 or c + 2 == n;) c = xorshift() % n; long long x0 = t, m = std::max(n >> 3, 1LL), x, ys, y = x0, r = 1, g, q = 1; do { x = y; @@ -73,7 +84,7 @@ struct { ys = y; for (int i = std::min(m, r - k); i--;) y = f(y), q = __int128(q) * std::abs(x - y) % n; - g = std::__gcd(q, n); + g = std::gcd(q, n); k += m; } while (k < r and g <= 1); r <<= 1; @@ -81,7 +92,7 @@ struct { if (g == n) { do { ys = f(ys); - g = std::__gcd(std::abs(x - ys), n); + g = std::gcd(std::abs(x - ys), n); } while (g <= 1); } if (g != n) return g; diff --git a/number/pow_mod.hpp b/number/pow_mod.hpp new file mode 100644 index 00000000..91f6fe0c --- /dev/null +++ b/number/pow_mod.hpp @@ -0,0 +1,21 @@ +#pragma once +#include +#include + +template Int pow_mod(Int x, long long n, Int md) { + using Long = + std::conditional_t, long long, + std::conditional_t, __int128, void>>; + assert(n >= 0 and md > 0); + if (md == 1) return 0; + if (n == 0) return 1; + + x = (x % md + md) % md; + Int ans = 1; + while (n > 0) { + if (n & 1) ans = (Long)ans * x % md; + x = (Long)x * x % md; + n >>= 1; + } + return ans; +} diff --git a/number/pow_mod.md b/number/pow_mod.md new file mode 100644 index 00000000..fdb3a2a4 --- /dev/null +++ b/number/pow_mod.md @@ -0,0 +1,17 @@ +--- +title: Modular exponentiation (べき乗 mod) +documentation_of: ./pow_mod.hpp +--- + +整数 $x$, 非負整数 $n$, 正整数 $m$ に対し,$x^n \bmod m$ を $O(\log n)$ で計算する.繰り返し二乗法による実装.`Int` が `int` のとき内部で `long long`,`long long` のとき `__int128` を用いてオーバーフローを回避する. + +## 使用方法 + +```cpp +int a = pow_mod(3, 100, 1000000007); // Int = int +long long b = pow_mod(3LL, 100LL, (long long)1e18 + 9); // Int = long long +``` + +- `x`: 底. +- `n`: 指数($n \ge 0$). +- `md`: 法($m \ge 1$).$m = 1$ のとき常に $0$ を返す. diff --git a/number/primitive_root.hpp b/number/primitive_root.hpp index 8f5a66b8..a6b522f0 100644 --- a/number/primitive_root.hpp +++ b/number/primitive_root.hpp @@ -1,6 +1,6 @@ #pragma once -#include "bare_mod_algebra.hpp" #include "factorize.hpp" +#include "pow_mod.hpp" // Find smallest primitive root for given number n (最小の原始根探索) // n must be 2 / 4 / p^k / 2p^k (p: odd prime, k > 0) @@ -25,10 +25,10 @@ inline long long find_smallest_primitive_root(long long n) { for (long long g = 1; g < n; g++) { if (std::gcd(n, g) != 1) continue; - if (pow_mod(g, phi, n) != 1) return -1; + if (pow_mod(g, phi, n) != 1) return -1; bool ok = true; for (auto pp : fac) { - if (pow_mod(g, phi / pp, n) == 1) { + if (pow_mod(g, phi / pp, n) == 1) { ok = false; break; } diff --git a/number/test/miller-rabin.test.cpp b/number/test/miller-rabin.test.cpp index 8823ed41..27cd13fe 100644 --- a/number/test/miller-rabin.test.cpp +++ b/number/test/miller-rabin.test.cpp @@ -1,5 +1,5 @@ #include "../factorize.hpp" -#define PROBLEM "https://yukicoder.me/problems/no/3030" +#define PROBLEM "https://yukicoder.me/problems/no/8030" #include int main() { diff --git a/number/test/sieve.stress.test.cpp b/number/test/sieve.stress.test.cpp index 1e8103c2..c89e6fea 100644 --- a/number/test/sieve.stress.test.cpp +++ b/number/test/sieve.stress.test.cpp @@ -1,10 +1,10 @@ #define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A" // DUMMY #include "../sieve.hpp" #include "../modint_runtime.hpp" -#include #include #include #include +#include #include using namespace std; @@ -15,7 +15,7 @@ struct Case { int euler_phi(int x) { int ret = 0; - for (int d = 1; d <= x; d++) ret += (std::__gcd(d, x) == 1); + for (int d = 1; d <= x; d++) ret += (std::gcd(d, x) == 1); return ret; } diff --git a/rational/rational_number.hpp b/rational/rational_number.hpp index 08d9439e..09266de2 100644 --- a/rational/rational_number.hpp +++ b/rational/rational_number.hpp @@ -8,7 +8,7 @@ template struct Rational { Int num, den; // den >= 0 static constexpr Int my_gcd(Int a, Int b) { - // return __gcd(a, b); + // return gcd(a, b); if (a < 0) a = -a; if (b < 0) b = -b; while (a and b) { diff --git a/utilities/pow_sum.hpp b/utilities/pow_sum.hpp index 914beb27..4c34eb84 100644 --- a/utilities/pow_sum.hpp +++ b/utilities/pow_sum.hpp @@ -1,13 +1,15 @@ #pragma once -#include +#include +#include #include -// CUT begin -// {x^n, x^0 + ... + x^(n - 1)} for n >= 1 +// {x^n, x^0 + ... + x^(n - 1)} for n >= 0 // Verify: ABC212H template std::pair pow_sum(T x, Int n) { + using Uint = std::make_unsigned_t; + if (n == 0) return {1, 0}; T sum = 1, p = x; // ans = x^0 + ... + x^(len - 1), p = x^len - for (int d = std::__lg(n) - 1; d >= 0; d--) { + for (int d = std::bit_width(Uint(n)) - 2; d >= 0; d--) { sum = sum * (p + 1); p *= p; if ((n >> d) & 1) {