11#pragma once
22#include < algorithm>
33#include < cassert>
4+ #include < numeric>
45#include < tuple>
56#include < utility>
67#include < vector>
78
8- // CUT begin
99// Solve ax+by=gcd(a, b)
1010template <class Int > Int extgcd (Int a, Int b, Int &x, Int &y) {
1111 Int d = a;
@@ -16,9 +16,10 @@ template <class Int> Int extgcd(Int a, Int b, Int &x, Int &y) {
1616 }
1717 return d;
1818}
19- // Calculate a^(-1) (MOD m) s if gcd(a, m) == 1
20- // Calculate x s.t. ax == gcd(a, m) MOD m
21- template <class Int > Int mod_inverse (Int a, Int m) {
19+
20+ // Calculate a^(-1) (MOD m) if gcd(a, m) == 1
21+ // Calculate x s.t. ax == gcd(a, m) MOD m and 0 <= x < m
22+ template <class Int > Int inv_mod (Int a, Int m) {
2223 Int x, y;
2324 extgcd<Int>(a, m, x, y);
2425 x %= m;
@@ -76,37 +77,25 @@ template <class Int>
7677
7778// 蟻本 P.262
7879// 中国剰余定理を利用して,色々な素数で割った余りから元の値を復元
79- // 連立線形合同式 A * x = B mod M の解
80+ // 連立線形合同式 A_i x = R_i mod M_i の解
8081// Requirement: M[i] > 0
8182// Output: x = first MOD second (if solution exists), (0, 0) (otherwise)
8283template <class Int >
8384std::pair<Int, Int>
84- linear_congruence (const std::vector<Int> &A, const std::vector<Int> &B , const std::vector<Int> &M) {
85+ linear_congruence (const std::vector<Int> &A, const std::vector<Int> &R , const std::vector<Int> &M) {
8586 Int r = 0 , m = 1 ;
8687 assert (A.size () == M.size ());
87- assert (B .size () == M.size ());
88+ assert (R .size () == M.size ());
8889 for (int i = 0 ; i < (int )A.size (); i++) {
8990 assert (M[i] > 0 );
9091 const Int ai = A[i] % M[i];
91- Int a = ai * m, b = B [i] - ai * r, d = std::__gcd (M[i], a);
92+ Int a = ai * m, b = R [i] - ai * r, d = std::gcd (M[i], a);
9293 if (b % d != 0 ) {
9394 return std::make_pair (0 , 0 ); // 解なし
9495 }
95- Int t = b / d * mod_inverse <Int>(a / d, M[i] / d) % (M[i] / d);
96+ Int t = b / d * inv_mod <Int>(a / d, M[i] / d) % (M[i] / d);
9697 r += m * t;
9798 m *= M[i] / d;
9899 }
99100 return std::make_pair ((r < 0 ? r + m : r), m);
100101}
101-
102- template <class Int = int , class Long = long long > Int pow_mod (Int x, long long n, Int md) {
103- static_assert (sizeof (Int) * 2 <= sizeof (Long), " Watch out for overflow" );
104- if (md == 1 ) return 0 ;
105- Int ans = 1 ;
106- while (n > 0 ) {
107- if (n & 1 ) ans = (Long)ans * x % md;
108- x = (Long)x * x % md;
109- n >>= 1 ;
110- }
111- return ans;
112- }
0 commit comments