-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathuniform.rs
More file actions
271 lines (236 loc) · 9.1 KB
/
uniform.rs
File metadata and controls
271 lines (236 loc) · 9.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
// Copyright © 2025 Niklas Siemer
//
// This file is part of qFALL-math.
//
// qFALL-math is free software: you can redistribute it and/or modify it under
// the terms of the Mozilla Public License Version 2.0 as published by the
// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
//! This module includes core functionality to sample according to the
//! uniform random distribution.
use crate::{error::MathError, integer::Z};
use flint_sys::fmpz::{fmpz_addmul_ui, fmpz_set_ui};
use rand::{Rng, rngs::ThreadRng};
/// Samples a [`Z`] uniformly at random from the interval `[0, interval_size)`.
///
/// Attributes:
/// - `interval_size`: defines the interval `[0, interval_size)`, which we sample from
/// - `two_pow_32`: holds `2^32`; multiplying by it shifts a value left by 32 bits
/// - `nr_iterations`: defines how many full [u32] samples are appended to each candidate
/// - `upper_modulo`: is a power of two by which the first sampled [u32] is reduced,
///   removing superfluously sampled bits to increase the probability of
///   accepting a sample to at least 1/2
/// - `rng`: defines the [`ThreadRng`] that's used to sample uniform [u32] integers
///
/// # Examples
/// ```
/// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
/// let interval_size = Z::from(20);
///
/// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
///
/// let sample = uis.sample();
///
/// assert!(Z::ZERO <= sample);
/// assert!(sample < interval_size);
/// ```
pub struct UniformIntegerSampler {
interval_size: Z,
two_pow_32: u64,
nr_iterations: u32,
upper_modulo: u32,
rng: ThreadRng,
}
impl UniformIntegerSampler {
    /// Creates a [`UniformIntegerSampler`] for the interval `[0, interval_size)`.
    ///
    /// The fields are set up as follows:
    /// - `interval_size` is a clone of the provided `interval_size`,
    /// - `two_pow_32` is a [u64] containing 2^32,
    /// - `nr_iterations` is `(interval_size - 1).bits() / 32` floored,
    /// - `upper_modulo` is 2^{(interval_size - 1).bits() mod 32},
    /// - `rng` is a fresh [`ThreadRng`].
    ///
    /// Parameters:
    /// - `interval_size`: specifies the interval `[0, interval_size)`
    ///   from which the samples are drawn
    ///
    /// Returns a [`UniformIntegerSampler`] or a [`MathError`],
    /// if the interval size is chosen smaller than `1`.
    ///
    /// # Examples
    /// ```
    /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
    /// let interval_size = Z::from(20);
    ///
    /// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
    /// ```
    ///
    /// # Errors and Failures
    /// - Returns a [`MathError`] of type [`InvalidInterval`](MathError::InvalidInterval)
    ///   if the interval is chosen smaller than `1`.
    pub fn init(interval_size: &Z) -> Result<Self, MathError> {
        if interval_size < &Z::ONE {
            return Err(MathError::InvalidInterval(format!(
                "An invalid interval size {interval_size} was provided."
            )));
        }

        // Number of bits of the largest value inside the interval, i.e. `interval_size - 1`
        let bit_size = (interval_size - Z::ONE).bits() as u32;

        Ok(Self {
            interval_size: interval_size.clone(),
            // 2^32, used to shift bits to the left by 32 bits using multiplication
            two_pow_32: 1_u64 << 32,
            // Integer division rounds towards 0, which matches sampling the
            // partial top chunk first and then appending this many full u32 chunks
            nr_iterations: bit_size / 32,
            // 2^{bit_size mod 32}: reducing the top chunk by this value discards
            // the bits that were sampled beyond `bit_size`
            upper_modulo: 1_u32 << (bit_size % 32),
            rng: rand::rng(),
        })
    }

    /// Computes a uniformly chosen [`Z`] sample in `[0, interval_size)`
    /// using rejection sampling that accepts samples with probability at least 1/2.
    ///
    /// For an interval of size `1`, the only possible value `0` is returned
    /// without drawing any randomness.
    ///
    /// # Examples
    /// ```
    /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
    /// let interval_size = Z::from(20);
    ///
    /// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
    ///
    /// let sample = uis.sample();
    ///
    /// assert!(Z::ZERO <= sample);
    /// assert!(sample < interval_size);
    /// ```
    pub fn sample(&mut self) -> Z {
        if self.interval_size.is_one() {
            return Z::ZERO;
        }

        // Draw candidates until one falls inside of the interval
        loop {
            let candidate = self.sample_bits_uniform();
            if candidate < self.interval_size {
                return candidate;
            }
        }
    }

    /// Computes `self.nr_iterations * 32 + log2(self.upper_modulo)` many
    /// uniformly chosen bits.
    ///
    /// Returns a [`Z`] assembled from one [u32] sample reduced modulo
    /// `self.upper_modulo` followed by `self.nr_iterations` full [u32] samples,
    /// i.e. the result is not yet guaranteed to be smaller than `interval_size`.
    ///
    /// # Examples
    /// ```
    /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
    /// let interval = Z::from(u16::MAX) + 1;
    ///
    /// let mut uis = UniformIntegerSampler::init(&interval).unwrap();
    ///
    /// let sample = uis.sample_bits_uniform();
    ///
    /// assert!(Z::ZERO <= sample);
    /// assert!(sample < interval);
    /// ```
    pub fn sample_bits_uniform(&mut self) -> Z {
        // Remove superfluously sampled bits of the top chunk to increase the
        // chance of acceptance to at least 1/2
        let top_bits = self.rng.next_u32() % self.upper_modulo;
        let mut acc = Z::from(top_bits);

        for _ in 0..self.nr_iterations {
            let chunk = u64::from(self.rng.next_u32());
            let mut shifted = Z::default();
            // SAFETY: `shifted` and `acc` are distinct, live `Z` values for the
            // whole block, so the pointers derived from their `value` fields are
            // valid and do not alias.
            // NOTE(review): relies on `Z::default()` and `Z::from` yielding
            // properly initialized fmpz values - confirm against `Z`.
            unsafe {
                fmpz_set_ui(&mut shifted.value, chunk);
                // Sets shifted = shifted + acc * 2^32, i.e. appends `chunk` as the
                // new lowest 32 bits of `acc`.
                // Could be optimized by shifting bits left by 32 bits once lshift
                // is part of flint-sys.
                fmpz_addmul_ui(&mut shifted.value, &acc.value, self.two_pow_32);
            }
            acc = shifted;
        }

        acc
    }
}
#[cfg(test)]
mod test_uis {
    use super::{UniformIntegerSampler, Z};
    use std::collections::HashSet;

    /// Checks whether sampling works fine for small interval sizes.
    #[test]
    fn small_interval() {
        let size_2 = Z::from(2);
        let size_7 = Z::from(7);

        let mut uis_2 = UniformIntegerSampler::init(&size_2).unwrap();
        let mut uis_7 = UniformIntegerSampler::init(&size_7).unwrap();

        for _ in 0..3 {
            let sample_2 = uis_2.sample();
            let sample_7 = uis_7.sample();

            assert!(Z::ZERO <= sample_2);
            assert!(sample_2 < size_2);
            assert!(Z::ZERO <= sample_7);
            assert!(sample_7 < size_7);
        }
    }

    /// Checks whether an interval of size 1 is accepted and always yields 0,
    /// the only value in `[0, 1)`.
    #[test]
    fn interval_of_size_one() {
        let mut uis = UniformIntegerSampler::init(&Z::ONE).unwrap();

        for _ in 0..3 {
            assert!(uis.sample() == Z::ZERO);
        }
    }

    /// Checks whether sampling works fine for large interval sizes.
    #[test]
    fn large_interval() {
        let size_0 = Z::from(u64::MAX);
        let size_1 = Z::from(u64::MAX) * 2 + 1;

        let mut uis_0 = UniformIntegerSampler::init(&size_0).unwrap();
        let mut uis_1 = UniformIntegerSampler::init(&size_1).unwrap();

        for _ in 0..u8::MAX {
            let sample_0 = uis_0.sample();
            let sample_1 = uis_1.sample();

            assert!(Z::ZERO <= sample_0);
            assert!(sample_0 < size_0);
            assert!(Z::ZERO <= sample_1);
            assert!(sample_1 < size_1);
        }
    }

    /// Checks whether it samples from the entire interval.
    #[test]
    fn entire_interval() {
        for interval_size in [6_u32, 7, 16] {
            let interval = Z::from(interval_size);
            let mut uis = UniformIntegerSampler::init(&interval).unwrap();

            let mut samples = HashSet::new();
            // 2^interval_size draws make it very unlikely to miss any value
            for _ in 0..2_u32.pow(interval_size) {
                samples.insert(uis.sample());
            }

            // if len(samples) == interval_size, then every element in [0, interval_size)
            // needs to be represented in samples
            assert_eq!(
                interval_size,
                samples.len() as u32,
                "This test may fail with low probability."
            );
        }
    }

    /// Checks whether interval sizes smaller than 1 result in an error.
    #[test]
    fn invalid_interval() {
        assert!(UniformIntegerSampler::init(&Z::ZERO).is_err());
        assert!(UniformIntegerSampler::init(&Z::MINUS_ONE).is_err());
    }

    /// Checks whether random bit sampling doesn't fill more bits than required.
    ///
    /// All interval sizes used here are powers of two, so every output of
    /// `sample_bits_uniform` has to lie inside of the interval already.
    #[test]
    fn sample_bits_uniform_necessary_nr_bytes() {
        let size_0 = Z::from(8);
        let size_1 = Z::from(256);
        let size_2 = Z::from(u32::MAX) + Z::ONE;

        let mut uis_0 = UniformIntegerSampler::init(&size_0).unwrap();
        let mut uis_1 = UniformIntegerSampler::init(&size_1).unwrap();
        let mut uis_2 = UniformIntegerSampler::init(&size_2).unwrap();

        for _ in 0..u8::MAX {
            let sample_0 = uis_0.sample_bits_uniform();
            let sample_1 = uis_1.sample_bits_uniform();
            let sample_2 = uis_2.sample_bits_uniform();

            assert!(Z::ZERO <= sample_0);
            assert!(sample_0 < size_0);
            assert!(Z::ZERO <= sample_1);
            assert!(sample_1 < size_1);
            assert!(Z::ZERO <= sample_2);
            assert!(sample_2 < size_2);
        }
    }
}