66// option. This file may not be copied, modified, or distributed
77// except according to those terms.
88
9- use core:: fmt:: Debug ;
9+ use core:: { convert :: Infallible , fmt:: Debug } ;
1010
1111use rand_core:: {
12- CryptoRng , RngCore , SeedableRng ,
12+ SeedableRng , TryCryptoRng , TryRngCore ,
1313 block:: { BlockRng , CryptoGenerator , Generator } ,
1414} ;
1515
@@ -149,7 +149,7 @@ pub type BlockPos = U32x2;
149149const BUFFER_SIZE : usize = 64 ;
150150
151151// NB. this must remain consistent with some currently hard-coded numbers in this module
152- const BUF_BLOCKS : u8 = BUFFER_SIZE as u8 >> 4 ;
152+ const BUF_BLOCKS : u8 = BUFFER_SIZE as u8 / BLOCK_WORDS ;
153153
154154impl < R : Rounds , V : Variant > ChaChaCore < R , V > {
155155 /// Generates 4 blocks in parallel with avx2 & neon, but merely fills
@@ -291,22 +291,25 @@ macro_rules! impl_chacha_rng {
291291 }
292292 }
293293 }
294- impl RngCore for $ChaChaXRng {
294+ impl TryRngCore for $ChaChaXRng {
295+ type Error = Infallible ;
296+
295297 #[ inline]
296- fn next_u32 ( & mut self ) -> u32 {
297- self . core. next_word( )
298+ fn try_next_u32 ( & mut self ) -> Result < u32 , Self :: Error > {
299+ Ok ( self . core. next_word( ) )
298300 }
299301 #[ inline]
300- fn next_u64 ( & mut self ) -> u64 {
301- self . core. next_u64_from_u32( )
302+ fn try_next_u64 ( & mut self ) -> Result < u64 , Self :: Error > {
303+ Ok ( self . core. next_u64_from_u32( ) )
302304 }
303305 #[ inline]
304- fn fill_bytes( & mut self , dest: & mut [ u8 ] ) {
305- self . core. fill_bytes( dest)
306+ fn try_fill_bytes( & mut self , dest: & mut [ u8 ] ) -> Result <( ) , Self :: Error > {
307+ self . core. fill_bytes( dest) ;
308+ Ok ( ( ) )
306309 }
307310 }
308311 impl CryptoGenerator for $ChaChaXCore { }
309- impl CryptoRng for $ChaChaXRng { }
312+ impl TryCryptoRng for $ChaChaXRng { }
310313
311314 #[ cfg( feature = "zeroize" ) ]
312315 impl ZeroizeOnDrop for $ChaChaXCore { }
@@ -335,29 +338,37 @@ macro_rules! impl_chacha_rng {
335338 pub fn get_word_pos( & self ) -> u128 {
336339 let mut block_counter = ( u64 :: from( self . core. core. 0 . state[ 13 ] ) << 32 )
337340 | u64 :: from( self . core. core. 0 . state[ 12 ] ) ;
338- block_counter = block_counter. wrapping_sub( BUF_BLOCKS as u64 ) ;
341+ if self . core. word_offset( ) != 0 {
342+ block_counter = block_counter. wrapping_sub( BUF_BLOCKS as u64 ) ;
343+ }
339344 let word_pos =
340- block_counter as u128 * BLOCK_WORDS as u128 + self . core. index ( ) as u128 ;
345+ block_counter as u128 * BLOCK_WORDS as u128 + self . core. word_offset ( ) as u128 ;
341346 // eliminate bits above the 68th bit
342347 word_pos & ( ( 1 << 68 ) - 1 )
343348 }
344349
345- /// Set the offset from the start of the stream, in 32-bit words.
350+ /// Set the offset from the start of the stream, in 32-bit words. **This
351+ /// value will be erased when calling `set_stream()`, so call
352+ /// `set_stream()` before calling `set_word_pos()`** if you intend on
353+ /// using both of them together.
346354 ///
347355 /// As with `get_word_pos`, we use a 68-bit number. Since the generator
348356 /// simply cycles at the end of its period (1 ZiB), we ignore the upper
349357 /// 60 bits.
350358 #[ inline]
351359 pub fn set_word_pos( & mut self , word_offset: u128 ) {
352- let index = ( word_offset & 0b1111 ) as usize ;
353- let counter = word_offset >> 4 ;
360+ let index = ( word_offset % BLOCK_WORDS as u128 ) as usize ;
361+ let counter = word_offset / BLOCK_WORDS as u128 ;
354362 //self.set_block_pos(counter as u64);
355363 self . core. core. 0 . state[ 12 ] = counter as u32 ;
356364 self . core. core. 0 . state[ 13 ] = ( counter >> 32 ) as u32 ;
357- self . core. generate_and_set ( index) ;
365+ self . core. reset_and_skip ( index) ;
358366 }
359367
360- /// Set the block pos and reset the RNG's index.
368+ /// Sets the block pos and resets the RNG's index. **This value will be
369+ /// erased when calling `set_stream()`, so call `set_stream()` before
370+ /// calling `set_block_pos()`** if you intend on using both of them
371+ /// together.
361372 ///
362373 /// The word pos will be equal to `block_pos * 16 words per block`.
363374 ///
@@ -370,7 +381,7 @@ macro_rules! impl_chacha_rng {
370381 #[ inline]
371382 #[ allow( unused) ]
372383 pub fn set_block_pos<B : Into <BlockPos >>( & mut self , block_pos: B ) {
373- self . core. reset ( ) ;
384+ self . core. reset_and_skip ( 0 ) ;
374385 let block_pos = block_pos. into( ) . 0 ;
375386 self . core. core. 0 . state[ 12 ] = block_pos[ 0 ] ;
376387 self . core. core. 0 . state[ 13 ] = block_pos[ 1 ]
@@ -380,11 +391,20 @@ macro_rules! impl_chacha_rng {
380391 #[ inline]
381392 #[ allow( unused) ]
382393 pub fn get_block_pos( & self ) -> u64 {
383- self . core. core. 0 . state[ 12 ] as u64 | ( ( self . core. core. 0 . state[ 13 ] as u64 ) << 32 )
394+ let counter =
395+ self . core. core. 0 . state[ 12 ] as u64 | ( ( self . core. core. 0 . state[ 13 ] as u64 ) << 32 ) ;
396+ if self . core. word_offset( ) != 0 {
397+ counter - BUF_BLOCKS as u64 + self . core. word_offset( ) as u64 / 16
398+ } else {
399+ counter
400+ }
384401 }
385402
386- /// Set the stream number. The lower 64 bits are used and the rest are
387- /// discarded. This method takes any of the following:
403+ /// Sets the stream number, resetting the `index` and `block_pos` to 0,
404+ /// effectively setting the `word_pos` to 0 as well. Consider storing
405+ /// the `word_pos` prior to calling this method.
406+ ///
407+ /// This method takes any of the following:
388408 /// * `u64`
389409 /// * `[u32; 2]`
390410 /// * `[u8; 8]`
@@ -405,20 +425,23 @@ macro_rules! impl_chacha_rng {
405425 /// let mut rng = ChaCha20Rng::from_seed(seed);
406426 ///
407427 /// // set state[12] to 0, state[13] to 1, state[14] to 2, state[15] to 3
408- /// rng.set_block_pos([0u32, 1u32]);
409428 /// rng.set_stream([2u32, 3u32]);
429+ /// rng.set_block_pos([0u32, 1u32]);
410430 ///
411431 /// // confirm that state is set correctly
412432 /// assert_eq!(rng.get_block_pos(), 1 << 32);
413433 /// assert_eq!(rng.get_stream(), (3 << 32) + 2);
434+ ///
435+ /// // restoring `word_pos`/`index` after calling `set_stream`:
436+ /// let word_pos = rng.get_word_pos();
437+ /// rng.set_stream(4);
438+ /// rng.set_word_pos(word_pos);
414439 /// ```
415440 #[ inline]
416441 pub fn set_stream<S : Into <StreamId >>( & mut self , stream: S ) {
417442 let stream: StreamId = stream. into( ) ;
418443 self . core. core. 0 . state[ 14 ..] . copy_from_slice( & stream. 0 ) ;
419- if self . core. index( ) != BUFFER_SIZE {
420- self . core. generate_and_set( self . core. index( ) ) ;
421- }
444+ self . set_block_pos( 0 ) ;
422445 }
423446
424447 /// Get the stream number.
@@ -532,6 +555,7 @@ impl_chacha_rng!(ChaCha20Rng, ChaCha20Core, R20, abst20);
532555pub ( crate ) mod tests {
533556
534557 use hex_literal:: hex;
558+ use rand_core:: RngCore ;
535559
536560 use super :: * ;
537561
@@ -864,6 +888,11 @@ pub(crate) mod tests {
864888 }
865889 rng2. set_stream ( 51 ) ; // switch part way through block
866890 for _ in 7 ..16 {
891+ assert_ne ! ( rng1. next_u64( ) , rng2. next_u64( ) ) ;
892+ }
893+ rng1. set_stream ( 51 ) ;
894+ rng2. set_stream ( 51 ) ;
895+ for _ in 0 ..16 {
867896 assert_eq ! ( rng1. next_u64( ) , rng2. next_u64( ) ) ;
868897 }
869898 }
@@ -892,7 +921,7 @@ pub(crate) mod tests {
892921 fn test_chacha_word_pos_zero ( ) {
893922 let mut rng = ChaChaRng :: from_seed ( Default :: default ( ) ) ;
894923 assert_eq ! ( rng. core. core. 0 . state[ 12 ] , 0 ) ;
895- assert_eq ! ( rng. core. index ( ) , 64 ) ;
924+ assert_eq ! ( rng. core. word_offset ( ) , 0 ) ;
896925 assert_eq ! ( rng. get_word_pos( ) , 0 ) ;
897926 rng. set_word_pos ( 0 ) ;
898927 assert_eq ! ( rng. get_word_pos( ) , 0 ) ;
@@ -1006,7 +1035,7 @@ pub(crate) mod tests {
10061035
10071036 let seed = Default :: default ( ) ;
10081037 let mut rng1 = ChaChaRng :: from_seed ( seed) ;
1009- let rng2 = & mut ChaChaRng :: from_seed ( seed) as & mut dyn CryptoRng ;
1038+ let mut rng2 = & mut ChaChaRng :: from_seed ( seed) as & mut dyn CryptoRng ;
10101039 for _ in 0 ..1000 {
10111040 assert_eq ! ( rng1. next_u64( ) , rng2. next_u64( ) ) ;
10121041 }
@@ -1015,15 +1044,58 @@ pub(crate) mod tests {
10151044 #[ test]
10161045 fn stream_id_endianness ( ) {
10171046 let mut rng = ChaCha20Rng :: from_seed ( [ 0u8 ; 32 ] ) ;
1047+ assert_eq ! ( rng. get_word_pos( ) , 0 ) ;
10181048 rng. set_stream ( [ 3 , 3333 ] ) ;
1049+ assert_eq ! ( rng. get_word_pos( ) , 0 ) ;
10191050 let expected = 1152671828 ;
10201051 assert_eq ! ( rng. next_u32( ) , expected) ;
1052+ let mut word_pos = rng. get_word_pos ( ) ;
1053+
1054+ assert_eq ! ( word_pos, 1 ) ;
1055+
1056+ rng. set_stream ( 1234567 ) ;
1057+ assert_eq ! ( rng. get_word_pos( ) , 0 ) ;
1058+ let mut block = [ 0u32 ; 16 ] ;
1059+ for word in 0 ..block. len ( ) {
1060+ block[ word] = rng. next_u32 ( ) ;
1061+ }
1062+ assert_eq ! ( rng. get_word_pos( ) , 16 ) ;
1063+ assert_eq ! ( rng. core. word_offset( ) , 16 ) ;
1064+ assert_eq ! ( rng. get_block_pos( ) , 1 ) ;
10211065 rng. set_stream ( 1234567 ) ;
1066+ let mut block_2 = [ 0u32 ; 16 ] ;
1067+ for word in 0 ..block_2. len ( ) {
1068+ block_2[ word] = rng. next_u32 ( ) ;
1069+ }
1070+ assert_eq ! ( rng. get_word_pos( ) , 16 ) ;
1071+ assert_eq ! ( rng. core. word_offset( ) , 16 ) ;
1072+ assert_eq ! ( rng. get_block_pos( ) , 1 ) ;
1073+ assert_eq ! ( block, block_2) ;
1074+ rng. set_stream ( 1234567 ) ;
1075+ assert_eq ! ( rng. get_block_pos( ) , 0 ) ;
1076+ assert_eq ! ( rng. get_word_pos( ) , 0 ) ;
1077+ let _ = rng. next_u32 ( ) ;
1078+
1079+ word_pos = rng. get_word_pos ( ) ;
1080+ assert_eq ! ( word_pos, 1 ) ;
1081+ let test = rng. next_u32 ( ) ;
10221082 let expected = 3110319182 ;
1023- assert_eq ! ( rng. next_u32( ) , expected) ;
1083+ rng. set_word_pos ( 65 ) ; // old set_stream added 64 to the word_pos
1084+ assert ! ( rng. next_u32( ) == expected) ;
1085+ rng. set_word_pos ( word_pos) ;
1086+ assert_eq ! ( rng. next_u32( ) , test) ;
1087+
1088+ word_pos = rng. get_word_pos ( ) ;
1089+ assert_eq ! ( word_pos, 2 ) ;
10241090 rng. set_stream ( [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] ) ;
1091+ rng. next_u32 ( ) ;
1092+ rng. next_u32 ( ) ;
1093+ let test = rng. next_u32 ( ) ;
1094+ rng. set_word_pos ( 130 ) ; // old set_stream added another 64 to the word_pos
10251095 let expected = 3790367479 ;
10261096 assert_eq ! ( rng. next_u32( ) , expected) ;
1097+ rng. set_word_pos ( word_pos) ;
1098+ assert_eq ! ( rng. next_u32( ) , test) ;
10271099 }
10281100
10291101 /// If this test fails, the backend may not be
@@ -1043,7 +1115,7 @@ pub(crate) mod tests {
10431115 let mut result = [ 0u8 ; 64 * 5 ] ;
10441116 rng. fill_bytes ( & mut result) ;
10451117 assert_eq ! ( first_blocks_end_word_pos, rng. get_word_pos( ) ) ;
1046- assert_eq ! ( first_blocks_end_block_counter, rng. get_block_pos( ) - 3 ) ;
1118+ assert_eq ! ( first_blocks_end_block_counter, rng. get_block_pos( ) ) ;
10471119
10481120 if first_blocks[ 0 ..64 * 4 ] . ne ( & result[ 64 ..] ) {
10491121 for ( i, ( a, b) ) in first_blocks. iter ( ) . zip ( result. iter ( ) . skip ( 64 ) ) . enumerate ( ) {
0 commit comments