Skip to content

Commit a1f122a

Browse files
add sanity checks and test non-prime input
Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
1 parent 220d83b commit a1f122a

2 files changed

Lines changed: 96 additions & 31 deletions

File tree

src/modular/prime_params.rs

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ pub struct PrimeParams<const LIMBS: usize> {
2727

2828
impl<const LIMBS: usize> PrimeParams<LIMBS> {
2929
/// Instantiates a new set of [`PrimeParams`] given [`FixedMontyParams`] for a prime modulus.
30+
///
31+
/// This method will return `None` if the modulus is determined to be non-prime, however
32+
/// this is not an exhaustive check and non-prime values can be accepted.
3033
#[must_use]
3134
#[allow(clippy::unwrap_in_result, clippy::missing_panics_doc)]
3235
pub const fn new_vartime(params: &FixedMontyParams<LIMBS>) -> Option<Self> {
33-
// A primitive root exists if and only if n is 1, 2, 4, p^k or 2p^k, k > 0, p is an odd prime
34-
3536
let p = params.modulus();
3637
let p_minus_one = p.as_ref().set_bit_vartime(0, false);
3738
let s = NonZeroU32::new(p_minus_one.trailing_zeros_vartime()).expect("ensured non-zero");
@@ -53,6 +54,11 @@ impl<const LIMBS: usize> PrimeParams<LIMBS> {
5354
let exp = t.shr_vartime(1);
5455
// the s'th root of unity is calculated as `generator^t`
5556
let root = FixedMontyForm::new(&gen_uint, params).pow_vartime(&t);
57+
// root^(2^(s-1)) must be equal to -1
58+
let check = root.square_repeat_vartime(s.get() - 1);
59+
if !Uint::eq(&check.retrieve(), &p_minus_one).to_bool_vartime() {
60+
return None;
61+
}
5662
(exp, root)
5763
};
5864

@@ -125,6 +131,7 @@ impl<const LIMBS: usize> subtle::ConditionallySelectable for PrimeParams<LIMBS>
125131
const fn find_primitive_root<const LIMBS: usize>(
126132
p: &Odd<Uint<LIMBS>>,
127133
) -> Option<(NonZeroU32, Uint<LIMBS>)> {
134+
// A primitive root exists iff p is 1, 2, 4, q^k or 2q^k, k > 0, q is an odd prime.
128135
// Find a quadratic non-residue (primitive roots are non-residue for powers of a prime)
129136
let mut g = NonZeroU32::new(2u32).expect("ensured non-zero");
130137
loop {
@@ -148,3 +155,55 @@ const fn find_primitive_root<const LIMBS: usize>(
148155
}
149156
}
150157
}
158+
159+
#[cfg(test)]
160+
mod tests {
161+
use super::PrimeParams;
162+
use crate::{Choice, CtEq, CtSelect, Odd, U128, modular::MontyParams};
163+
164+
#[test]
165+
fn check_expected() {
166+
let monty_params =
167+
MontyParams::new_vartime(Odd::<U128>::from_be_hex("e38af050d74b8567f73c8713cbc7bc47"));
168+
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
169+
assert_eq!(prime_params.s.get(), 1);
170+
assert_eq!(prime_params.generator.get(), 5);
171+
}
172+
173+
#[test]
174+
fn check_non_prime() {
175+
let monty_params =
176+
MontyParams::new_vartime(Odd::<U128>::from_be_hex("e38af050d74b8567f73c8713cbc7bc01"));
177+
assert!(PrimeParams::new_vartime(&monty_params).is_none());
178+
}
179+
180+
#[test]
181+
fn check_equality() {
182+
let monty_params_1 =
183+
MontyParams::new_vartime(Odd::<U128>::from_be_hex("e38af050d74b8567f73c8713cbc7bc47"));
184+
let prime_params_1 =
185+
PrimeParams::new_vartime(&monty_params_1).expect("failed creating params");
186+
187+
let monty_params_2 =
188+
MontyParams::new_vartime(Odd::<U128>::from_be_hex("f2799d643ab7ff983437c3a86cdb1beb"));
189+
let prime_params_2 =
190+
PrimeParams::new_vartime(&monty_params_2).expect("failed creating params");
191+
192+
assert!(CtEq::ct_eq(&prime_params_1, &prime_params_1).to_bool_vartime());
193+
assert!(CtEq::ct_ne(&prime_params_1, &prime_params_2).to_bool_vartime());
194+
assert!(
195+
CtEq::ct_eq(
196+
&CtSelect::ct_select(&prime_params_1, &prime_params_2, Choice::FALSE),
197+
&prime_params_1,
198+
)
199+
.to_bool_vartime()
200+
);
201+
assert!(
202+
CtEq::ct_eq(
203+
&CtSelect::ct_select(&prime_params_1, &prime_params_2, Choice::TRUE),
204+
&prime_params_2,
205+
)
206+
.to_bool_vartime()
207+
);
208+
}
209+
}

src/modular/sqrt.rs

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
use crate::{
22
Choice, CtOption, Uint,
3-
modular::{MontyForm, MontyParams, prime_params::PrimeParams},
3+
modular::{FixedMontyForm, FixedMontyParams, prime_params::PrimeParams},
44
};
55

66
/// Compute a modular square root (if it exists) given [`MontyParams`]
77
/// and [`PrimeParams`] corresponding to `monty_value`, in Montgomery form.
88
#[must_use]
99
pub const fn sqrt_montgomery_form<const LIMBS: usize>(
1010
monty_value: &Uint<LIMBS>,
11-
monty_params: &MontyParams<LIMBS>,
11+
monty_params: &FixedMontyParams<LIMBS>,
1212
prime_params: &PrimeParams<LIMBS>,
1313
) -> CtOption<Uint<LIMBS>> {
14-
let value = MontyForm::from_montgomery(*monty_value, *monty_params);
14+
let value = FixedMontyForm::from_montgomery(*monty_value, monty_params);
1515
let b = value.pow_vartime(&prime_params.sqrt_exp);
1616

1717
// Constant-time versions of modular square root algorithms based on:
@@ -25,23 +25,24 @@ pub const fn sqrt_montgomery_form<const LIMBS: usize>(
2525
}
2626
2 => {
2727
// Algorithm 3: p = 5 mod 8 (Atkins variant)
28-
let ru = MontyForm::from_montgomery(prime_params.monty_root_unity, *monty_params);
28+
let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
2929
let cb = value.mul(&b);
3030
let zeta = cb.mul(&b);
3131
let is_one = Uint::eq(zeta.as_montgomery(), monty_params.one());
3232
monty_select(&cb.mul(&ru), &cb, is_one)
3333
}
3434
3 => {
3535
// Algorithm 4: p = 9 mod 16
36-
let ru = MontyForm::from_montgomery(prime_params.monty_root_unity, *monty_params);
37-
let ru_2 = MontyForm::from_montgomery(prime_params.monty_root_unity_p2, *monty_params);
36+
let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
37+
let ru_2 =
38+
FixedMontyForm::from_montgomery(prime_params.monty_root_unity_p2, monty_params);
3839
let ru_3 = ru.mul(&ru_2);
3940
let cb = value.mul(&b);
4041
let zeta = cb.mul(&b);
4142

4243
let mut m = monty_select(
4344
&ru,
44-
&MontyForm::one(*ru.params()),
45+
&FixedMontyForm::one(ru.params()),
4546
Uint::eq(zeta.as_montgomery(), monty_params.one()),
4647
);
4748
// m = ru^2 if zeta = -1
@@ -57,8 +58,9 @@ pub const fn sqrt_montgomery_form<const LIMBS: usize>(
5758
}
5859
4 => {
5960
// Algorithm 5: p = 17 mod 32
60-
let ru = MontyForm::from_montgomery(prime_params.monty_root_unity, *monty_params);
61-
let ru_2 = MontyForm::from_montgomery(prime_params.monty_root_unity_p2, *monty_params);
61+
let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
62+
let ru_2 =
63+
FixedMontyForm::from_montgomery(prime_params.monty_root_unity_p2, monty_params);
6264
let ru_4 = ru_2.square();
6365
let ru_6 = ru_2.mul(&ru_4);
6466
let cb = value.mul(&b);
@@ -71,7 +73,7 @@ pub const fn sqrt_montgomery_form<const LIMBS: usize>(
7173

7274
// m = B if -zeta in (B, C), else 1
7375
let mut m = monty_select(
74-
&MontyForm::one(*ru.params()),
76+
&FixedMontyForm::one(ru.params()),
7577
&ru_2,
7678
neg_zeta_b.or(monty_eq(&neg_zeta, &ru_4)),
7779
);
@@ -99,7 +101,8 @@ pub const fn sqrt_montgomery_form<const LIMBS: usize>(
99101
// Tonelli-Shanks
100102
let mut x = value.mul(&b);
101103
let mut d = x.mul(&b);
102-
let mut z = MontyForm::from_montgomery(prime_params.monty_root_unity, *monty_params);
104+
let mut z =
105+
FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
103106
let mut v = prime_params.s.get();
104107
let mut max_v = v;
105108

@@ -134,18 +137,21 @@ pub const fn sqrt_montgomery_form<const LIMBS: usize>(
134137
CtOption::new(x.to_montgomery(), monty_eq(&x.square(), &value))
135138
}
136139

137-
const fn monty_eq<const LIMBS: usize>(a: &MontyForm<LIMBS>, b: &MontyForm<LIMBS>) -> Choice {
140+
const fn monty_eq<const LIMBS: usize>(
141+
a: &FixedMontyForm<LIMBS>,
142+
b: &FixedMontyForm<LIMBS>,
143+
) -> Choice {
138144
Uint::eq(a.as_montgomery(), b.as_montgomery())
139145
}
140146

141147
const fn monty_select<const LIMBS: usize>(
142-
a: &MontyForm<LIMBS>,
143-
b: &MontyForm<LIMBS>,
148+
a: &FixedMontyForm<LIMBS>,
149+
b: &FixedMontyForm<LIMBS>,
144150
c: Choice,
145-
) -> MontyForm<LIMBS> {
146-
MontyForm::from_montgomery(
151+
) -> FixedMontyForm<LIMBS> {
152+
FixedMontyForm::from_montgomery(
147153
Uint::select(a.as_montgomery(), b.as_montgomery(), c),
148-
*a.params(),
154+
a.params(),
149155
)
150156
}
151157

@@ -154,29 +160,29 @@ mod tests {
154160
use super::sqrt_montgomery_form;
155161
use crate::{
156162
Odd, U256, U576, Uint,
157-
modular::{MontyForm, MontyParams, PrimeParams},
163+
modular::{FixedMontyForm, FixedMontyParams, PrimeParams},
158164
};
159165

160166
fn root_of_unity<const LIMBS: usize>(
161-
monty_params: &MontyParams<LIMBS>,
167+
monty_params: &FixedMontyParams<LIMBS>,
162168
prime_params: &PrimeParams<LIMBS>,
163169
) -> Uint<LIMBS> {
164-
MontyForm::from_montgomery(prime_params.monty_root_unity, *monty_params).retrieve()
170+
FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params).retrieve()
165171
}
166172

167173
fn test_monty_sqrt<const LIMBS: usize>(
168-
monty_params: MontyParams<LIMBS>,
174+
monty_params: FixedMontyParams<LIMBS>,
169175
prime_params: PrimeParams<LIMBS>,
170176
) {
171177
let modulus = monty_params.modulus.get();
172178
let rounds = if cfg!(miri) { 1..=2 } else { 0..=256 };
173179
for i in rounds {
174180
let s = i * i;
175-
let s_monty = MontyForm::new(&Uint::from_u32(s), &monty_params);
181+
let s_monty = FixedMontyForm::new(&Uint::from_u32(s), &monty_params);
176182
let rt_monty =
177183
sqrt_montgomery_form(s_monty.as_montgomery(), &monty_params, &prime_params)
178184
.expect("no sqrt found");
179-
let rt = MontyForm::from_montgomery(rt_monty, monty_params).retrieve();
185+
let rt = FixedMontyForm::from_montgomery(rt_monty, &monty_params).retrieve();
180186
let i = Uint::from_u32(i);
181187
assert!(
182188
Uint::eq(&rt, &i)
@@ -187,7 +193,7 @@ mod tests {
187193

188194
// generator must be non-residue
189195
let generator = Uint::from_u32(prime_params.generator.get());
190-
let gen_monty = MontyForm::new(&generator, &monty_params);
196+
let gen_monty = FixedMontyForm::new(&generator, &monty_params);
191197
assert!(
192198
sqrt_montgomery_form(gen_monty.as_montgomery(), &monty_params, &prime_params)
193199
.is_none()
@@ -199,7 +205,7 @@ mod tests {
199205
fn mod_sqrt_s_1() {
200206
// p = 3 mod 4, s = 1
201207
// P-256 field modulus
202-
let monty_params = MontyParams::new_vartime(Odd::<U256>::from_be_hex(
208+
let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
203209
"ffffffff00000001000000000000000000000000ffffffffffffffffffffffff",
204210
));
205211
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
@@ -217,7 +223,7 @@ mod tests {
217223
fn mod_sqrt_s_2() {
218224
// p = 5 mod 8, s = 2
219225
// ed25519 base field
220-
let monty_params = MontyParams::new_vartime(Odd::<U256>::from_be_hex(
226+
let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
221227
"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed",
222228
));
223229
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
@@ -235,7 +241,7 @@ mod tests {
235241
fn mod_sqrt_s_3() {
236242
// p = 9 mod 16, s = 3
237243
// brainpoolP384 scalar field
238-
let monty_params = MontyParams::new_vartime(Odd::<U576>::from_be_hex(
244+
let monty_params = FixedMontyParams::new_vartime(Odd::<U576>::from_be_hex(
239245
"00000000000001fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409",
240246
));
241247
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
@@ -255,7 +261,7 @@ mod tests {
255261
fn mod_sqrt_s_4() {
256262
// p = 17 mod 32, s = 4
257263
// P-256 scalar field
258-
let monty_params = MontyParams::new_vartime(Odd::<U256>::from_be_hex(
264+
let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
259265
"ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551",
260266
));
261267
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");
@@ -273,7 +279,7 @@ mod tests {
273279
fn mod_sqrt_s_6() {
274280
// s = 6
275281
// K-256 scalar field
276-
let monty_params = MontyParams::new_vartime(Odd::<U256>::from_be_hex(
282+
let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
277283
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
278284
));
279285
let prime_params = PrimeParams::new_vartime(&monty_params).expect("failed creating params");

0 commit comments

Comments
 (0)