Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions linear_algebra_matrix/linalg_longlong.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <cstdlib>
#include <vector>

// CUT begin
template <typename lint, typename mdint>
std::vector<std::vector<lint>> gauss_jordan(std::vector<std::vector<lint>> mtr, mdint mod) {
// Gauss-Jordan elimination 行基本変形のみを用いるガウス消去法
Expand All @@ -28,7 +27,7 @@ std::vector<std::vector<lint>> gauss_jordan(std::vector<std::vector<lint>> mtr,
mtr[piv][w] ? mod - mtr[piv][w] : 0; // To preserve sign of determinant
}
}
lint pivinv = mod_inverse<lint>(mtr[h][c], mod);
lint pivinv = inv_mod<lint>(mtr[h][c], mod);
for (int hh = 0; hh < H; hh++) {
if (hh == h) continue;
lint coeff = mtr[hh][c] * pivinv % mod;
Expand Down
31 changes: 10 additions & 21 deletions number/bare_mod_algebra.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#pragma once
#include <algorithm>
#include <cassert>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>

// CUT begin
// Solve ax+by=gcd(a, b)
template <class Int> Int extgcd(Int a, Int b, Int &x, Int &y) {
Int d = a;
Expand All @@ -16,9 +16,10 @@ template <class Int> Int extgcd(Int a, Int b, Int &x, Int &y) {
}
return d;
}
// Calculate a^(-1) (MOD m) s if gcd(a, m) == 1
// Calculate x s.t. ax == gcd(a, m) MOD m
template <class Int> Int mod_inverse(Int a, Int m) {

// Calculate a^(-1) (MOD m) if gcd(a, m) == 1
// Calculate x s.t. ax == gcd(a, m) MOD m and 0 <= x < m
template <class Int> Int inv_mod(Int a, Int m) {
Int x, y;
extgcd<Int>(a, m, x, y);
x %= m;
Expand Down Expand Up @@ -76,37 +77,25 @@ template <class Int>

// 蟻本 P.262
// 中国剰余定理を利用して,色々な素数で割った余りから元の値を復元
// 連立線形合同式 A * x = B mod M の解
// 連立線形合同式 A_i x = R_i mod M_i の解
// Requirement: M[i] > 0
// Output: x = first MOD second (if solution exists), (0, 0) (otherwise)
template <class Int>
std::pair<Int, Int>
linear_congruence(const std::vector<Int> &A, const std::vector<Int> &B, const std::vector<Int> &M) {
linear_congruence(const std::vector<Int> &A, const std::vector<Int> &R, const std::vector<Int> &M) {
Int r = 0, m = 1;
assert(A.size() == M.size());
assert(B.size() == M.size());
assert(R.size() == M.size());
for (int i = 0; i < (int)A.size(); i++) {
assert(M[i] > 0);
const Int ai = A[i] % M[i];
Int a = ai * m, b = B[i] - ai * r, d = std::__gcd(M[i], a);
Int a = ai * m, b = R[i] - ai * r, d = std::gcd(M[i], a);
if (b % d != 0) {
return std::make_pair(0, 0); // 解なし
}
Int t = b / d * mod_inverse<Int>(a / d, M[i] / d) % (M[i] / d);
Int t = b / d * inv_mod<Int>(a / d, M[i] / d) % (M[i] / d);
r += m * t;
m *= M[i] / d;
}
return std::make_pair((r < 0 ? r + m : r), m);
}

template <class Int = int, class Long = long long> Int pow_mod(Int x, long long n, Int md) {
static_assert(sizeof(Int) * 2 <= sizeof(Long), "Watch out for overflow");
if (md == 1) return 0;
Int ans = 1;
while (n > 0) {
if (n & 1) ans = (Long)ans * x % md;
x = (Long)x * x % md;
n >>= 1;
}
return ans;
}
47 changes: 47 additions & 0 deletions number/bare_mod_algebra.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
---
title: Modular arithmetic utilities (C++ の基本型整数に対する拡張 GCD・中国剰余定理・連立線形合同式などの実装)
documentation_of: ./bare_mod_algebra.hpp
---

modint を使わない `int` などの整数型上での剰余演算ユーティリティ集.拡張ユークリッドの互除法,モジュラ逆元,中国剰余定理 (CRT),連立線形合同式の求解などを提供する.

## 使用方法

### `extgcd(a, b, x, y)`

$ax + by = \gcd(a, b)$ を満たす整数 $x, y$ を求め,$\gcd(a, b)$ を返す.

```cpp
long long x, y;
long long g = extgcd(12LL, 8LL, x, y); // g = 4
```

### `inv_mod(a, m)`

$a$ の $m$ を法とする逆元を返す.$\gcd(a, m) = 1$ であることが必要.返り値は $[0, m)$ の範囲.

```cpp
long long inv = inv_mod(3LL, 7LL); // inv = 5 (3*5 = 15 ≡ 1 mod 7)
```

### `inv_gcd(a, b)`

$g = \gcd(a, b)$ および $xa \equiv g \pmod{b}$,$0 \le x < b/g$ を満たす $(g, x)$ を返す.$b \ge 1$ が必要.

### `crt(r, m)`

中国剰余定理.$x \equiv r_i \pmod{m_i}$ $(i = 0, \ldots, n-1)$ を同時に満たす $x$ を求める.解が存在すれば $(x \bmod \mathrm{lcm}, \mathrm{lcm})$ を,存在しなければ $(0, 0)$ を返す.

```cpp
vector<long long> r = {2, 3}, m = {3, 5};
auto [x, lcm] = crt(r, m); // x = 8, lcm = 15
```

### `linear_congruence(A, R, M)`

連立線形合同式 $A_i x \equiv R_i \pmod{M_i}$ の解を求める.解が存在すれば $(x \bmod m, m)$ を,存在しなければ $(0, 0)$ を返す.

```cpp
vector<long long> A = {1, 1}, R = {2, 3}, M = {3, 5};
auto [x, m] = linear_congruence(A, R, M); // x ≡ 8 mod 15
```
25 changes: 18 additions & 7 deletions number/factorize.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#pragma once
#include "../random/xorshift.hpp"
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

Expand All @@ -28,7 +27,7 @@ inline int get_id(long long n) {
// Miller-Rabin primality test
// https://ja.wikipedia.org/wiki/%E3%83%9F%E3%83%A9%E3%83%BC%E2%80%93%E3%83%A9%E3%83%93%E3%83%B3%E7%B4%A0%E6%95%B0%E5%88%A4%E5%AE%9A%E6%B3%95
// Complexity: O(lg n) per query
struct {
inline struct {
long long modpow(__int128 x, __int128 n, long long mod) noexcept {
__int128 ret = 1;
for (x %= mod; n; x = x * x % mod, n >>= 1) ret = (n & 1) ? ret * x % mod : ret;
Expand All @@ -53,7 +52,19 @@ struct {
}
} is_prime;

struct {
inline struct {
static uint32_t xorshift() noexcept {
static uint32_t x = 123456789;
static uint32_t y = 362436069;
static uint32_t z = 521288629;
static uint32_t w = 88675123;
uint32_t t = x ^ (x << 11);
x = y;
y = z;
z = w;
return w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));
}

// Pollard's rho algorithm: find factor greater than 1
long long find_factor(long long n) {
assert(n > 1);
Expand All @@ -63,7 +74,7 @@ struct {
auto f = [&](__int128 x) -> long long { return (x * x + c) % n; };

for (int t = 1;; t++) {
for (c = 0; c == 0 or c + 2 == n;) c = rand_int() % n;
for (c = 0; c == 0 or c + 2 == n;) c = xorshift() % n;
long long x0 = t, m = std::max(n >> 3, 1LL), x, ys, y = x0, r = 1, g, q = 1;
do {
x = y;
Expand All @@ -73,15 +84,15 @@ struct {
ys = y;
for (int i = std::min(m, r - k); i--;)
y = f(y), q = __int128(q) * std::abs(x - y) % n;
g = std::__gcd<long long>(q, n);
g = std::gcd<long long>(q, n);
k += m;
} while (k < r and g <= 1);
r <<= 1;
} while (g <= 1);
if (g == n) {
do {
ys = f(ys);
g = std::__gcd(std::abs(x - ys), n);
g = std::gcd(std::abs(x - ys), n);
} while (g <= 1);
}
if (g != n) return g;
Expand Down
21 changes: 21 additions & 0 deletions number/pow_mod.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once
#include <cassert>
#include <type_traits>

template <class Int> Int pow_mod(Int x, long long n, Int md) {
using Long =
std::conditional_t<std::is_same_v<Int, int>, long long,
std::conditional_t<std::is_same_v<Int, long long>, __int128, void>>;
assert(n >= 0 and md > 0);
if (md == 1) return 0;
if (n == 0) return 1;

x = (x % md + md) % md;
Int ans = 1;
while (n > 0) {
if (n & 1) ans = (Long)ans * x % md;
x = (Long)x * x % md;
n >>= 1;
}
return ans;
}
17 changes: 17 additions & 0 deletions number/pow_mod.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
title: Modular exponentiation (べき乗 mod)
documentation_of: ./pow_mod.hpp
---

整数 $x$, 非負整数 $n$, 正整数 $m$ に対し,$x^n \bmod m$ を $O(\log n)$ で計算する.繰り返し二乗法による実装.`Int` が `int` のとき内部で `long long`,`long long` のとき `__int128` を用いてオーバーフローを回避する.

## 使用方法

```cpp
int a = pow_mod(3, 100, 1000000007); // Int = int
long long b = pow_mod(3LL, 100LL, (long long)1e18 + 9); // Int = long long
```

- `x`: 底.
- `n`: 指数($n \ge 0$).
- `md`: 法($m \ge 1$).$m = 1$ のとき常に $0$ を返す.
6 changes: 3 additions & 3 deletions number/primitive_root.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once
#include "bare_mod_algebra.hpp"
#include "factorize.hpp"
#include "pow_mod.hpp"

// Find smallest primitive root for given number n (最小の原始根探索)
// n must be 2 / 4 / p^k / 2p^k (p: odd prime, k > 0)
Expand All @@ -25,10 +25,10 @@ inline long long find_smallest_primitive_root(long long n) {

for (long long g = 1; g < n; g++) {
if (std::gcd(n, g) != 1) continue;
if (pow_mod<long long, __int128>(g, phi, n) != 1) return -1;
if (pow_mod(g, phi, n) != 1) return -1;
bool ok = true;
for (auto pp : fac) {
if (pow_mod<long long, __int128>(g, phi / pp, n) == 1) {
if (pow_mod(g, phi / pp, n) == 1) {
ok = false;
break;
}
Expand Down
2 changes: 1 addition & 1 deletion number/test/miller-rabin.test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "../factorize.hpp"
#define PROBLEM "https://yukicoder.me/problems/no/3030"
#define PROBLEM "https://yukicoder.me/problems/no/8030"
#include <iostream>

int main() {
Expand Down
4 changes: 2 additions & 2 deletions number/test/sieve.stress.test.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A" // DUMMY
#include "../sieve.hpp"
#include "../modint_runtime.hpp"
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <iostream>
#include <numeric>
#include <vector>
using namespace std;

Expand All @@ -15,7 +15,7 @@ struct Case {

int euler_phi(int x) {
int ret = 0;
for (int d = 1; d <= x; d++) ret += (std::__gcd(d, x) == 1);
for (int d = 1; d <= x; d++) ret += (std::gcd(d, x) == 1);
return ret;
}

Expand Down
2 changes: 1 addition & 1 deletion rational/rational_number.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ template <class Int, bool AutoReduce = false> struct Rational {
Int num, den; // den >= 0

static constexpr Int my_gcd(Int a, Int b) {
// return __gcd(a, b);
// return gcd(a, b);
if (a < 0) a = -a;
if (b < 0) b = -b;
while (a and b) {
Expand Down
10 changes: 6 additions & 4 deletions utilities/pow_sum.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#pragma once
#include <algorithm>
#include <bit>
#include <type_traits>
#include <utility>

// CUT begin
// {x^n, x^0 + ... + x^(n - 1)} for n >= 1
// {x^n, x^0 + ... + x^(n - 1)} for n >= 0
// Verify: ABC212H
template <typename T, typename Int> std::pair<T, T> pow_sum(T x, Int n) {
using Uint = std::make_unsigned_t<Int>;
if (n == 0) return {1, 0};
T sum = 1, p = x; // ans = x^0 + ... + x^(len - 1), p = x^len
for (int d = std::__lg(n) - 1; d >= 0; d--) {
for (int d = std::bit_width(Uint(n)) - 2; d >= 0; d--) {
sum = sum * (p + 1);
p *= p;
if ((n >> d) & 1) {
Expand Down
Loading