Skip to content

Commit ab074d4

Browse files
authored
Merge pull request #422 from hitonanode/refactor-bare-mod-algebra
refactor bare_mod_algebra and so on
2 parents 5d6373d + 11e36c3 commit ab074d4

11 files changed

Lines changed: 127 additions & 41 deletions

File tree

linear_algebra_matrix/linalg_longlong.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <cstdlib>
55
#include <vector>
66

7-
// CUT begin
87
template <typename lint, typename mdint>
98
std::vector<std::vector<lint>> gauss_jordan(std::vector<std::vector<lint>> mtr, mdint mod) {
109
// Gauss-Jordan elimination 行基本変形のみを用いるガウス消去法
@@ -28,7 +27,7 @@ std::vector<std::vector<lint>> gauss_jordan(std::vector<std::vector<lint>> mtr,
2827
mtr[piv][w] ? mod - mtr[piv][w] : 0; // To preserve sign of determinant
2928
}
3029
}
31-
lint pivinv = mod_inverse<lint>(mtr[h][c], mod);
30+
lint pivinv = inv_mod<lint>(mtr[h][c], mod);
3231
for (int hh = 0; hh < H; hh++) {
3332
if (hh == h) continue;
3433
lint coeff = mtr[hh][c] * pivinv % mod;

number/bare_mod_algebra.hpp

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#pragma once
22
#include <algorithm>
33
#include <cassert>
4+
#include <numeric>
45
#include <tuple>
56
#include <utility>
67
#include <vector>
78

8-
// CUT begin
99
// Solve ax+by=gcd(a, b)
1010
template <class Int> Int extgcd(Int a, Int b, Int &x, Int &y) {
1111
Int d = a;
@@ -16,9 +16,10 @@ template <class Int> Int extgcd(Int a, Int b, Int &x, Int &y) {
1616
}
1717
return d;
1818
}
19-
// Calculate a^(-1) (MOD m) s if gcd(a, m) == 1
20-
// Calculate x s.t. ax == gcd(a, m) MOD m
21-
template <class Int> Int mod_inverse(Int a, Int m) {
19+
20+
// Calculate a^(-1) (MOD m) if gcd(a, m) == 1
21+
// Calculate x s.t. ax == gcd(a, m) MOD m and 0 <= x < m
22+
template <class Int> Int inv_mod(Int a, Int m) {
2223
Int x, y;
2324
extgcd<Int>(a, m, x, y);
2425
x %= m;
@@ -76,37 +77,25 @@ template <class Int>
7677

7778
// 蟻本 P.262
7879
// 中国剰余定理を利用して,色々な素数で割った余りから元の値を復元
79-
// 連立線形合同式 A * x = B mod M の解
80+
// 連立線形合同式 A_i x = R_i mod M_i の解
8081
// Requirement: M[i] > 0
8182
// Output: x = first MOD second (if solution exists), (0, 0) (otherwise)
8283
template <class Int>
8384
std::pair<Int, Int>
84-
linear_congruence(const std::vector<Int> &A, const std::vector<Int> &B, const std::vector<Int> &M) {
85+
linear_congruence(const std::vector<Int> &A, const std::vector<Int> &R, const std::vector<Int> &M) {
8586
Int r = 0, m = 1;
8687
assert(A.size() == M.size());
87-
assert(B.size() == M.size());
88+
assert(R.size() == M.size());
8889
for (int i = 0; i < (int)A.size(); i++) {
8990
assert(M[i] > 0);
9091
const Int ai = A[i] % M[i];
91-
Int a = ai * m, b = B[i] - ai * r, d = std::__gcd(M[i], a);
92+
Int a = ai * m, b = R[i] - ai * r, d = std::gcd(M[i], a);
9293
if (b % d != 0) {
9394
return std::make_pair(0, 0); // 解なし
9495
}
95-
Int t = b / d * mod_inverse<Int>(a / d, M[i] / d) % (M[i] / d);
96+
Int t = b / d * inv_mod<Int>(a / d, M[i] / d) % (M[i] / d);
9697
r += m * t;
9798
m *= M[i] / d;
9899
}
99100
return std::make_pair((r < 0 ? r + m : r), m);
100101
}
101-
102-
template <class Int = int, class Long = long long> Int pow_mod(Int x, long long n, Int md) {
103-
static_assert(sizeof(Int) * 2 <= sizeof(Long), "Watch out for overflow");
104-
if (md == 1) return 0;
105-
Int ans = 1;
106-
while (n > 0) {
107-
if (n & 1) ans = (Long)ans * x % md;
108-
x = (Long)x * x % md;
109-
n >>= 1;
110-
}
111-
return ans;
112-
}

number/bare_mod_algebra.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
---
2+
title: Modular arithmetic utilities (C++ の基本型整数に対する拡張 GCD・中国剰余定理・連立線形合同式などの実装)
3+
documentation_of: ./bare_mod_algebra.hpp
4+
---
5+
6+
modint を使わない `int` などの整数型上での剰余演算ユーティリティ集.拡張ユークリッドの互除法,モジュラ逆元,中国剰余定理 (CRT),連立線形合同式の求解などを提供する.
7+
8+
## 使用方法
9+
10+
### `extgcd(a, b, x, y)`
11+
12+
$ax + by = \gcd(a, b)$ を満たす整数 $x, y$ を求め,$\gcd(a, b)$ を返す.
13+
14+
```cpp
15+
long long x, y;
16+
long long g = extgcd(12LL, 8LL, x, y); // g = 4
17+
```
18+
19+
### `inv_mod(a, m)`
20+
21+
$a$ の $m$ を法とする逆元を返す.$\gcd(a, m) = 1$ であることが必要.返り値は $[0, m)$ の範囲.
22+
23+
```cpp
24+
long long inv = inv_mod(3LL, 7LL); // inv = 5 (3*5 = 15 ≡ 1 mod 7)
25+
```
26+
27+
### `inv_gcd(a, b)`
28+
29+
$g = \gcd(a, b)$ および $xa \equiv g \pmod{b}$,$0 \le x < b/g$ を満たす $(g, x)$ を返す.$b \ge 1$ が必要.
30+
31+
### `crt(r, m)`
32+
33+
中国剰余定理.$x \equiv r_i \pmod{m_i}$ $(i = 0, \ldots, n-1)$ を同時に満たす $x$ を求める.解が存在すれば $(x \bmod \mathrm{lcm}, \mathrm{lcm})$ を,存在しなければ $(0, 0)$ を返す.
34+
35+
```cpp
36+
vector<long long> r = {2, 3}, m = {3, 5};
37+
auto [x, lcm] = crt(r, m); // x = 8, lcm = 15
38+
```
39+
40+
### `linear_congruence(A, R, M)`
41+
42+
連立線形合同式 $A_i x \equiv R_i \pmod{M_i}$ の解を求める.解が存在すれば $(x \bmod m, m)$ を,存在しなければ $(0, 0)$ を返す.
43+
44+
```cpp
45+
vector<long long> A = {1, 1}, R = {2, 3}, M = {3, 5};
46+
auto [x, m] = linear_congruence(A, R, M); // x ≡ 8 mod 15
47+
```

number/factorize.hpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#pragma once
2-
#include "../random/xorshift.hpp"
32
#include <algorithm>
4-
#include <array>
53
#include <cassert>
4+
#include <cstdint>
65
#include <numeric>
76
#include <vector>
87

@@ -28,7 +27,7 @@ inline int get_id(long long n) {
2827
// Miller-Rabin primality test
2928
// https://ja.wikipedia.org/wiki/%E3%83%9F%E3%83%A9%E3%83%BC%E2%80%93%E3%83%A9%E3%83%93%E3%83%B3%E7%B4%A0%E6%95%B0%E5%88%A4%E5%AE%9A%E6%B3%95
3029
// Complexity: O(lg n) per query
31-
struct {
30+
inline struct {
3231
long long modpow(__int128 x, __int128 n, long long mod) noexcept {
3332
__int128 ret = 1;
3433
for (x %= mod; n; x = x * x % mod, n >>= 1) ret = (n & 1) ? ret * x % mod : ret;
@@ -53,7 +52,19 @@ struct {
5352
}
5453
} is_prime;
5554

56-
struct {
55+
inline struct {
56+
static uint32_t xorshift() noexcept {
57+
static uint32_t x = 123456789;
58+
static uint32_t y = 362436069;
59+
static uint32_t z = 521288629;
60+
static uint32_t w = 88675123;
61+
uint32_t t = x ^ (x << 11);
62+
x = y;
63+
y = z;
64+
z = w;
65+
return w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));
66+
}
67+
5768
// Pollard's rho algorithm: find factor greater than 1
5869
long long find_factor(long long n) {
5970
assert(n > 1);
@@ -63,7 +74,7 @@ struct {
6374
auto f = [&](__int128 x) -> long long { return (x * x + c) % n; };
6475

6576
for (int t = 1;; t++) {
66-
for (c = 0; c == 0 or c + 2 == n;) c = rand_int() % n;
77+
for (c = 0; c == 0 or c + 2 == n;) c = xorshift() % n;
6778
long long x0 = t, m = std::max(n >> 3, 1LL), x, ys, y = x0, r = 1, g, q = 1;
6879
do {
6980
x = y;
@@ -73,15 +84,15 @@ struct {
7384
ys = y;
7485
for (int i = std::min(m, r - k); i--;)
7586
y = f(y), q = __int128(q) * std::abs(x - y) % n;
76-
g = std::__gcd<long long>(q, n);
87+
g = std::gcd<long long>(q, n);
7788
k += m;
7889
} while (k < r and g <= 1);
7990
r <<= 1;
8091
} while (g <= 1);
8192
if (g == n) {
8293
do {
8394
ys = f(ys);
84-
g = std::__gcd(std::abs(x - ys), n);
95+
g = std::gcd(std::abs(x - ys), n);
8596
} while (g <= 1);
8697
}
8798
if (g != n) return g;

number/pow_mod.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
#include <cassert>
3+
#include <type_traits>
4+
5+
template <class Int> Int pow_mod(Int x, long long n, Int md) {
6+
using Long =
7+
std::conditional_t<std::is_same_v<Int, int>, long long,
8+
std::conditional_t<std::is_same_v<Int, long long>, __int128, void>>;
9+
assert(n >= 0 and md > 0);
10+
if (md == 1) return 0;
11+
if (n == 0) return 1;
12+
13+
x = (x % md + md) % md;
14+
Int ans = 1;
15+
while (n > 0) {
16+
if (n & 1) ans = (Long)ans * x % md;
17+
x = (Long)x * x % md;
18+
n >>= 1;
19+
}
20+
return ans;
21+
}

number/pow_mod.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
---
2+
title: Modular exponentiation (べき乗 mod)
3+
documentation_of: ./pow_mod.hpp
4+
---
5+
6+
整数 $x$, 非負整数 $n$, 正整数 $m$ に対し,$x^n \bmod m$ を $O(\log n)$ で計算する.繰り返し二乗法による実装.`Int``int` のとき内部で `long long``long long` のとき `__int128` を用いてオーバーフローを回避する.
7+
8+
## 使用方法
9+
10+
```cpp
11+
int a = pow_mod(3, 100, 1000000007); // Int = int
12+
long long b = pow_mod(3LL, 100LL, (long long)1e18 + 9); // Int = long long
13+
```
14+
15+
- `x`: 底.
16+
- `n`: 指数($n \ge 0$).
17+
- `md`: 法($m \ge 1$).$m = 1$ のとき常に $0$ を返す.

number/primitive_root.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
2-
#include "bare_mod_algebra.hpp"
32
#include "factorize.hpp"
3+
#include "pow_mod.hpp"
44

55
// Find smallest primitive root for given number n (最小の原始根探索)
66
// n must be 2 / 4 / p^k / 2p^k (p: odd prime, k > 0)
@@ -25,10 +25,10 @@ inline long long find_smallest_primitive_root(long long n) {
2525

2626
for (long long g = 1; g < n; g++) {
2727
if (std::gcd(n, g) != 1) continue;
28-
if (pow_mod<long long, __int128>(g, phi, n) != 1) return -1;
28+
if (pow_mod(g, phi, n) != 1) return -1;
2929
bool ok = true;
3030
for (auto pp : fac) {
31-
if (pow_mod<long long, __int128>(g, phi / pp, n) == 1) {
31+
if (pow_mod(g, phi / pp, n) == 1) {
3232
ok = false;
3333
break;
3434
}

number/test/miller-rabin.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "../factorize.hpp"
2-
#define PROBLEM "https://yukicoder.me/problems/no/3030"
2+
#define PROBLEM "https://yukicoder.me/problems/no/8030"
33
#include <iostream>
44

55
int main() {

number/test/sieve.stress.test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A" // DUMMY
22
#include "../sieve.hpp"
33
#include "../modint_runtime.hpp"
4-
#include <algorithm>
54
#include <cassert>
65
#include <cstdio>
76
#include <iostream>
7+
#include <numeric>
88
#include <vector>
99
using namespace std;
1010

@@ -15,7 +15,7 @@ struct Case {
1515

1616
int euler_phi(int x) {
1717
int ret = 0;
18-
for (int d = 1; d <= x; d++) ret += (std::__gcd(d, x) == 1);
18+
for (int d = 1; d <= x; d++) ret += (std::gcd(d, x) == 1);
1919
return ret;
2020
}
2121

rational/rational_number.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ template <class Int, bool AutoReduce = false> struct Rational {
88
Int num, den; // den >= 0
99

1010
static constexpr Int my_gcd(Int a, Int b) {
11-
// return __gcd(a, b);
11+
// return gcd(a, b);
1212
if (a < 0) a = -a;
1313
if (b < 0) b = -b;
1414
while (a and b) {

0 commit comments

Comments
 (0)