diff --git a/Cargo.lock b/Cargo.lock index aa21ea4c..97fd18fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,7 +61,7 @@ dependencies = [ "hex-literal", "proptest", "rand_chacha", - "rand_core 0.10.0-rc-3", + "rand_core 0.10.0-rc-4", "zeroize", ] @@ -287,9 +287,9 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.10.0-rc-3" +version = "0.10.0-rc-4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f66ee92bc15280519ef199a274fe0cafff4245d31bc39aaa31c011ad56cb1f05" +checksum = "43bb1e3655c24705492d72208c9bacdefe07c30c14b8f7664c556a3e1953b72c" [[package]] name = "rand_xorshift" diff --git a/chacha20/Cargo.toml b/chacha20/Cargo.toml index b72ac425..5a54d6b5 100644 --- a/chacha20/Cargo.toml +++ b/chacha20/Cargo.toml @@ -21,7 +21,7 @@ rand_core-compatible RNGs based on those ciphers. [dependencies] cfg-if = "1" cipher = { version = "0.5.0-rc.3", optional = true, features = ["stream-wrapper"] } -rand_core = { version = "0.10.0-rc-3", optional = true, default-features = false } +rand_core = { version = "0.10.0-rc-4", optional = true, default-features = false } # `zeroize` is an explicit dependency because this crate may be used without the `cipher` crate zeroize = { version = "1.8.1", optional = true, default-features = false } diff --git a/chacha20/src/rng.rs b/chacha20/src/rng.rs index eade9265..1b549a43 100644 --- a/chacha20/src/rng.rs +++ b/chacha20/src/rng.rs @@ -6,10 +6,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use core::fmt::Debug; +use core::{convert::Infallible, fmt::Debug}; use rand_core::{ - CryptoRng, RngCore, SeedableRng, + SeedableRng, TryCryptoRng, TryRngCore, block::{BlockRng, CryptoGenerator, Generator}, }; @@ -149,7 +149,7 @@ pub type BlockPos = U32x2; const BUFFER_SIZE: usize = 64; // NB. this must remain consistent with some currently hard-coded numbers in this module -const BUF_BLOCKS: u8 = BUFFER_SIZE as u8 >> 4; +const BUF_BLOCKS: u8 = BUFFER_SIZE as u8 / BLOCK_WORDS; impl ChaChaCore { /// Generates 4 blocks in parallel with avx2 & neon, but merely fills @@ -291,22 +291,25 @@ macro_rules! impl_chacha_rng { } } } - impl RngCore for $ChaChaXRng { + impl TryRngCore for $ChaChaXRng { + type Error = Infallible; + #[inline] - fn next_u32(&mut self) -> u32 { - self.core.next_word() + fn try_next_u32(&mut self) -> Result { + Ok(self.core.next_word()) } #[inline] - fn next_u64(&mut self) -> u64 { - self.core.next_u64_from_u32() + fn try_next_u64(&mut self) -> Result { + Ok(self.core.next_u64_from_u32()) } #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.core.fill_bytes(dest) + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> { + self.core.fill_bytes(dest); + Ok(()) } } impl CryptoGenerator for $ChaChaXCore {} - impl CryptoRng for $ChaChaXRng {} + impl TryCryptoRng for $ChaChaXRng {} #[cfg(feature = "zeroize")] impl ZeroizeOnDrop for $ChaChaXCore {} @@ -335,29 +338,37 @@ macro_rules! impl_chacha_rng { pub fn get_word_pos(&self) -> u128 { let mut block_counter = (u64::from(self.core.core.0.state[13]) << 32) | u64::from(self.core.core.0.state[12]); - block_counter = block_counter.wrapping_sub(BUF_BLOCKS as u64); + if self.core.word_offset() != 0 { + block_counter = block_counter.wrapping_sub(BUF_BLOCKS as u64); + } let word_pos = - block_counter as u128 * BLOCK_WORDS as u128 + self.core.index() as u128; + block_counter as u128 * BLOCK_WORDS as u128 + self.core.word_offset() as u128; // eliminate bits above the 68th bit word_pos & ((1 << 68) - 1) } - /// Set the offset from the start of the stream, in 32-bit words. + /// Set the offset from the start of the stream, in 32-bit words. **This + /// value will be erased when calling `set_stream()`, so call + /// `set_stream()` before calling `set_word_pos()`** if you intend on + /// using both of them together. /// /// As with `get_word_pos`, we use a 68-bit number. Since the generator /// simply cycles at the end of its period (1 ZiB), we ignore the upper /// 60 bits. #[inline] pub fn set_word_pos(&mut self, word_offset: u128) { - let index = (word_offset & 0b1111) as usize; - let counter = word_offset >> 4; + let index = (word_offset % BLOCK_WORDS as u128) as usize; + let counter = word_offset / BLOCK_WORDS as u128; //self.set_block_pos(counter as u64); self.core.core.0.state[12] = counter as u32; self.core.core.0.state[13] = (counter >> 32) as u32; - self.core.generate_and_set(index); + self.core.reset_and_skip(index); } - /// Set the block pos and reset the RNG's index. + /// Sets the block pos and resets the RNG's index. **This value will be + /// erased when calling `set_stream()`, so call `set_stream()` before + /// calling `set_block_pos()`** if you intend on using both of them + /// together. /// /// The word pos will be equal to `block_pos * 16 words per block`. /// @@ -370,7 +381,7 @@ macro_rules! impl_chacha_rng { #[inline] #[allow(unused)] pub fn set_block_pos>(&mut self, block_pos: B) { - self.core.reset(); + self.core.reset_and_skip(0); let block_pos = block_pos.into().0; self.core.core.0.state[12] = block_pos[0]; self.core.core.0.state[13] = block_pos[1] @@ -380,11 +391,20 @@ macro_rules! impl_chacha_rng { #[inline] #[allow(unused)] pub fn get_block_pos(&self) -> u64 { - self.core.core.0.state[12] as u64 | ((self.core.core.0.state[13] as u64) << 32) + let counter = + self.core.core.0.state[12] as u64 | ((self.core.core.0.state[13] as u64) << 32); + if self.core.word_offset() != 0 { + counter - BUF_BLOCKS as u64 + self.core.word_offset() as u64 / 16 + } else { + counter + } } - /// Set the stream number. The lower 64 bits are used and the rest are - /// discarded. This method takes any of the following: + /// Sets the stream number, resetting the `index` and `block_pos` to 0, + /// effectively setting the `word_pos` to 0 as well. Consider storing + /// the `word_pos` prior to calling this method. + /// + /// This method takes any of the following: /// * `u64` /// * `[u32; 2]` /// * `[u8; 8]` @@ -405,20 +425,23 @@ macro_rules! impl_chacha_rng { /// let mut rng = ChaCha20Rng::from_seed(seed); /// /// // set state[12] to 0, state[13] to 1, state[14] to 2, state[15] to 3 - /// rng.set_block_pos([0u32, 1u32]); /// rng.set_stream([2u32, 3u32]); + /// rng.set_block_pos([0u32, 1u32]); /// /// // confirm that state is set correctly /// assert_eq!(rng.get_block_pos(), 1 << 32); /// assert_eq!(rng.get_stream(), (3 << 32) + 2); + /// + /// // restoring `word_pos`/`index` after calling `set_stream`: + /// let word_pos = rng.get_word_pos(); + /// rng.set_stream(4); + /// rng.set_word_pos(word_pos); /// ``` #[inline] pub fn set_stream>(&mut self, stream: S) { let stream: StreamId = stream.into(); self.core.core.0.state[14..].copy_from_slice(&stream.0); - if self.core.index() != BUFFER_SIZE { - self.core.generate_and_set(self.core.index()); - } + self.set_block_pos(0); } /// Get the stream number. @@ -532,6 +555,7 @@ impl_chacha_rng!(ChaCha20Rng, ChaCha20Core, R20, abst20); pub(crate) mod tests { use hex_literal::hex; + use rand_core::RngCore; use super::*; @@ -864,6 +888,11 @@ pub(crate) mod tests { } rng2.set_stream(51); // switch part way through block for _ in 7..16 { + assert_ne!(rng1.next_u64(), rng2.next_u64()); + } + rng1.set_stream(51); + rng2.set_stream(51); + for _ in 0..16 { assert_eq!(rng1.next_u64(), rng2.next_u64()); } } @@ -892,7 +921,7 @@ pub(crate) mod tests { fn test_chacha_word_pos_zero() { let mut rng = ChaChaRng::from_seed(Default::default()); assert_eq!(rng.core.core.0.state[12], 0); - assert_eq!(rng.core.index(), 64); + assert_eq!(rng.core.word_offset(), 0); assert_eq!(rng.get_word_pos(), 0); rng.set_word_pos(0); assert_eq!(rng.get_word_pos(), 0); @@ -1006,7 +1035,7 @@ pub(crate) mod tests { let seed = Default::default(); let mut rng1 = ChaChaRng::from_seed(seed); - let rng2 = &mut ChaChaRng::from_seed(seed) as &mut dyn CryptoRng; + let mut rng2 = &mut ChaChaRng::from_seed(seed) as &mut dyn CryptoRng; for _ in 0..1000 { assert_eq!(rng1.next_u64(), rng2.next_u64()); } @@ -1015,15 +1044,58 @@ pub(crate) mod tests { #[test] fn stream_id_endianness() { let mut rng = ChaCha20Rng::from_seed([0u8; 32]); + assert_eq!(rng.get_word_pos(), 0); rng.set_stream([3, 3333]); + assert_eq!(rng.get_word_pos(), 0); let expected = 1152671828; assert_eq!(rng.next_u32(), expected); + let mut word_pos = rng.get_word_pos(); + + assert_eq!(word_pos, 1); + + rng.set_stream(1234567); + assert_eq!(rng.get_word_pos(), 0); + let mut block = [0u32; 16]; + for word in 0..block.len() { + block[word] = rng.next_u32(); + } + assert_eq!(rng.get_word_pos(), 16); + assert_eq!(rng.core.word_offset(), 16); + assert_eq!(rng.get_block_pos(), 1); rng.set_stream(1234567); + let mut block_2 = [0u32; 16]; + for word in 0..block_2.len() { + block_2[word] = rng.next_u32(); + } + assert_eq!(rng.get_word_pos(), 16); + assert_eq!(rng.core.word_offset(), 16); + assert_eq!(rng.get_block_pos(), 1); + assert_eq!(block, block_2); + rng.set_stream(1234567); + assert_eq!(rng.get_block_pos(), 0); + assert_eq!(rng.get_word_pos(), 0); + let _ = rng.next_u32(); + + word_pos = rng.get_word_pos(); + assert_eq!(word_pos, 1); + let test = rng.next_u32(); let expected = 3110319182; - assert_eq!(rng.next_u32(), expected); + rng.set_word_pos(65); // old set_stream added 64 to the word_pos + assert!(rng.next_u32() == expected); + rng.set_word_pos(word_pos); + assert_eq!(rng.next_u32(), test); + + word_pos = rng.get_word_pos(); + assert_eq!(word_pos, 2); rng.set_stream([1, 2, 3, 4, 5, 6, 7, 8]); + rng.next_u32(); + rng.next_u32(); + let test = rng.next_u32(); + rng.set_word_pos(130); // old set_stream added another 64 to the word_pos let expected = 3790367479; assert_eq!(rng.next_u32(), expected); + rng.set_word_pos(word_pos); + assert_eq!(rng.next_u32(), test); } /// If this test fails, the backend may not be @@ -1043,7 +1115,7 @@ pub(crate) mod tests { let mut result = [0u8; 64 * 5]; rng.fill_bytes(&mut result); assert_eq!(first_blocks_end_word_pos, rng.get_word_pos()); - assert_eq!(first_blocks_end_block_counter, rng.get_block_pos() - 3); + assert_eq!(first_blocks_end_block_counter, rng.get_block_pos()); if first_blocks[0..64 * 4].ne(&result[64..]) { for (i, (a, b)) in first_blocks.iter().zip(result.iter().skip(64)).enumerate() {