|
5 | 5 |
|
6 | 6 | use std::{ |
7 | 7 | cmp::{min, PartialEq}, |
| 8 | + convert::TryFrom, |
8 | 9 | fmt::{self, Debug}, |
9 | 10 | hash::{Hash, Hasher}, |
10 | 11 | iter::{DoubleEndedIterator, FromIterator, FusedIterator}, |
@@ -585,6 +586,64 @@ impl<T: Debug + PrimInt> Vob<T> { |
585 | 586 | } |
586 | 587 | } |
587 | 588 |
|
| 589 | + /// Counts the number of set bits. |
| 590 | + /// This method assumes the range is processed with process_range() |
| 591 | + fn count_set_bits(&self, range: Range<usize>) -> usize { |
| 592 | + // Early return for empty ranges |
| 593 | + if range.is_empty() { |
| 594 | + return 0; |
| 595 | + } |
| 596 | + let start_off = block_offset::<T>(range.start); |
| 597 | + debug_assert!( |
| 598 | + start_off < self.len(), |
| 599 | + "start_off {} >= self.len {}", |
| 600 | + start_off, |
| 601 | + self.len() |
| 602 | + ); |
| 603 | + |
| 604 | + // this -1 arithmetic is safe since we already tested for range.start & range.end equality |
| 605 | + let end_off = blocks_required::<T>(range.end) - 1; |
| 606 | + |
| 607 | + if start_off == end_off { |
| 608 | + // Range entirely within one word |
| 609 | + let b = self.vec[start_off]; |
| 610 | + let start_bit = range.start % bits_per_block::<T>(); |
| 611 | + let end_bit = range.end % bits_per_block::<T>(); |
| 612 | + |
| 613 | + // Remove bits before start_bit and bits after end_bit |
| 614 | + let count = if end_bit == 0 { |
| 615 | + // end_bit = 0 means we want everything from start_bit to end of word |
| 616 | + // After the right shift, we have what we want |
| 617 | + b >> start_bit |
| 618 | + } else { |
| 619 | + // We want bits from start_bit to end_bit |
| 620 | + // After right shift, we need to remove the high bits |
| 621 | + (b >> start_bit) << (start_bit + bits_per_block::<T>() - end_bit) |
| 622 | + } |
| 623 | + .count_ones(); |
| 624 | + return usize::try_from(count).unwrap(); |
| 625 | + } |
| 626 | + |
| 627 | + // First word: shift out bits before start_bit |
| 628 | + let start_bit = range.start % bits_per_block::<T>(); |
| 629 | + let mut count = usize::try_from((self.vec[start_off] >> start_bit).count_ones()).unwrap(); |
| 630 | + |
| 631 | + // Middle words |
| 632 | + for word_idx in (start_off + 1)..end_off { |
| 633 | + count += usize::try_from(self.vec[word_idx].count_ones()).unwrap(); |
| 634 | + } |
| 635 | + |
| 636 | + // Last word: shift out bits after end_bit |
| 637 | + let end_bit = range.end % bits_per_block::<T>(); |
| 638 | + let count_ones = if end_bit == 0 { |
| 639 | + // end_bit = 0 means we want to count the entire end_off word |
| 640 | + self.vec[end_off].count_ones() |
| 641 | + } else { |
| 642 | + (self.vec[end_off] << (bits_per_block::<T>() - end_bit)).count_ones() |
| 643 | + }; |
| 644 | + count + usize::try_from(count_ones).unwrap() |
| 645 | + } |
| 646 | + |
588 | 647 | /// Returns an iterator which efficiently produces the index of each unset bit in the specified |
589 | 648 | /// range. Assuming appropriate support from your CPU, this is much more efficient than |
590 | 649 | /// checking each bit individually. |
@@ -1082,6 +1141,10 @@ impl<T: Debug + PrimInt> Iterator for Iter<'_, T> { |
1082 | 1141 | fn size_hint(&self) -> (usize, Option<usize>) { |
1083 | 1142 | self.range.size_hint() |
1084 | 1143 | } |
| 1144 | + |
| 1145 | + fn count(self) -> usize { |
| 1146 | + self.range.count() |
| 1147 | + } |
1085 | 1148 | } |
1086 | 1149 |
|
1087 | 1150 | impl<T: Debug + PrimInt> DoubleEndedIterator for Iter<'_, T> { |
@@ -1148,6 +1211,10 @@ impl<T: Debug + PrimInt> Iterator for IterSetBits<'_, T> { |
1148 | 1211 | fn size_hint(&self) -> (usize, Option<usize>) { |
1149 | 1212 | self.range.size_hint() |
1150 | 1213 | } |
| 1214 | + |
| 1215 | + fn count(self) -> usize { |
| 1216 | + self.vob.count_set_bits(self.range) |
| 1217 | + } |
1151 | 1218 | } |
1152 | 1219 |
|
1153 | 1220 | impl<T: Debug + PrimInt> DoubleEndedIterator for IterSetBits<'_, T> { |
@@ -1228,6 +1295,12 @@ impl<T: Debug + PrimInt> Iterator for IterUnsetBits<'_, T> { |
1228 | 1295 | fn size_hint(&self) -> (usize, Option<usize>) { |
1229 | 1296 | self.range.size_hint() |
1230 | 1297 | } |
| 1298 | + |
| 1299 | + fn count(self) -> usize { |
| 1300 | + // This arithmetic is safe because (self.range.end - self.range.start) is the total number of bits, |
| 1301 | + // and self.vob.count_set_bits() always returns a value less than or equal to that. |
| 1302 | + (self.range.end - self.range.start) - self.vob.count_set_bits(self.range) |
| 1303 | + } |
1231 | 1304 | } |
1232 | 1305 |
|
1233 | 1306 | impl<T: Debug + PrimInt> DoubleEndedIterator for IterUnsetBits<'_, T> { |
@@ -1300,6 +1373,10 @@ impl<T: Debug + PrimInt> Iterator for StorageIter<'_, T> { |
1300 | 1373 | fn size_hint(&self) -> (usize, Option<usize>) { |
1301 | 1374 | self.iter.size_hint() |
1302 | 1375 | } |
| 1376 | + |
| 1377 | + fn count(self) -> usize { |
| 1378 | + self.iter.count() |
| 1379 | + } |
1303 | 1380 | } |
1304 | 1381 |
|
1305 | 1382 | #[inline(always)] |
@@ -1974,6 +2051,27 @@ mod tests { |
1974 | 2051 | for _ in 0..len { |
1975 | 2052 | vob.push(rng.random()); |
1976 | 2053 | } |
| 2054 | + // these tests can later be dialed down, as they noticeable slow down every random vob test. |
| 2055 | + assert_eq!( |
| 2056 | + vob.iter_set_bits(..).count(), |
| 2057 | + vob.iter_set_bits(..).filter(|_| true).count() |
| 2058 | + ); |
| 2059 | + assert_eq!( |
| 2060 | + vob.iter_unset_bits(..).count(), |
| 2061 | + vob.iter_unset_bits(..).filter(|_| true).count() |
| 2062 | + ); |
| 2063 | + if len > 2 { |
| 2064 | + // trigger the edge cases of count_set_bits() |
| 2065 | + let range = 1..len - 1; |
| 2066 | + assert_eq!( |
| 2067 | + vob.iter_set_bits(range.clone()).count(), |
| 2068 | + vob.iter_set_bits(range.clone()).filter(|_| true).count() |
| 2069 | + ); |
| 2070 | + assert_eq!( |
| 2071 | + vob.iter_unset_bits(range.clone()).count(), |
| 2072 | + vob.iter_unset_bits(range.clone()).filter(|_| true).count() |
| 2073 | + ); |
| 2074 | + } |
1977 | 2075 | vob |
1978 | 2076 | } |
1979 | 2077 |
|
@@ -2047,4 +2145,37 @@ mod tests { |
2047 | 2145 | v.push(true); |
2048 | 2146 | assert_eq!(v.vec.len(), 1); |
2049 | 2147 | } |
| 2148 | + |
| 2149 | + #[test] |
| 2150 | + fn test_count() { |
| 2151 | + let mut rng = rand::rng(); |
| 2152 | + |
| 2153 | + for test_len in 1..128 { |
| 2154 | + let vob = random_vob(test_len); |
| 2155 | + assert_eq!( |
| 2156 | + vob.iter_storage().count(), |
| 2157 | + vob.iter_storage().filter(|_| true).count() |
| 2158 | + ); |
| 2159 | + assert_eq!(vob.iter().count(), vob.iter().filter(|_| true).count()); |
| 2160 | + for i in 1..test_len - 1 { |
| 2161 | + let from = rng.random_range(0..i); |
| 2162 | + let to = rng.random_range(from..i); |
| 2163 | + assert_eq!( |
| 2164 | + vob.iter_set_bits(from..to).count(), |
| 2165 | + vob.iter_set_bits(from..to).filter(|_| true).count() |
| 2166 | + ); |
| 2167 | + assert_eq!( |
| 2168 | + vob.iter_unset_bits(from..to).count(), |
| 2169 | + vob.iter_unset_bits(from..to).filter(|_| true).count() |
| 2170 | + ); |
| 2171 | + } |
| 2172 | + } |
| 2173 | + } |
| 2174 | + |
| 2175 | + #[test] |
| 2176 | + fn test_collect_capacity() { |
| 2177 | + // a test to make sure that iter_set_bits().collect() does not always allocate .len() elements |
| 2178 | + let vec: Vec<usize> = Vob::from_elem(false, 100).iter_set_bits(..).collect(); |
| 2179 | + assert_eq!(vec.capacity(), 0); |
| 2180 | + } |
2050 | 2181 | } |
0 commit comments