|
5 | 5 |
|
6 | 6 | use std::{ |
7 | 7 | cmp::{min, PartialEq}, |
| 8 | + convert::TryFrom, |
8 | 9 | fmt::{self, Debug}, |
9 | 10 | hash::{Hash, Hasher}, |
10 | 11 | iter::{DoubleEndedIterator, FromIterator, FusedIterator}, |
@@ -585,6 +586,70 @@ impl<T: Debug + PrimInt> Vob<T> { |
585 | 586 | } |
586 | 587 | } |
587 | 588 |
|
| 589 | + /// Counts the number of set bits. |
| 590 | + /// This method assumes the range is processed with process_range() |
| 591 | + fn count_set_bits(&self, range: Range<usize>) -> usize { |
| 592 | + // Early return for empty ranges |
| 593 | + if range.is_empty() { |
| 594 | + return 0; |
| 595 | + } |
| 596 | + let start_off = block_offset::<T>(range.start); |
| 597 | + debug_assert!( |
| 598 | + start_off < self.vec.len(), |
| 599 | + "start_off {} >= self.vec.len() {}", |
| 600 | + start_off, |
| 601 | + self.vec.len() |
| 602 | + ); |
| 603 | + |
| 604 | + // this -1 arithmetic is safe since we already tested for range.start & range.end equality |
| 605 | + let end_off = blocks_required::<T>(range.end) - 1; |
| 606 | + debug_assert!( |
| 607 | + end_off < self.vec.len(), |
| 608 | + "end_off {} >= self.vec.len() {}", |
| 609 | + end_off, |
| 610 | + self.vec.len() |
| 611 | + ); |
| 612 | + |
| 613 | + if start_off == end_off { |
| 614 | + // Range entirely within one word |
| 615 | + let b = self.vec[start_off]; |
| 616 | + let start_bit = range.start % bits_per_block::<T>(); |
| 617 | + let end_bit = range.end % bits_per_block::<T>(); |
| 618 | + |
| 619 | + // Remove bits before start_bit and bits after end_bit |
| 620 | + let count = if end_bit == 0 { |
| 621 | + // end_bit = 0 means we want everything from start_bit to end of word |
| 622 | + // After the right shift, we have what we want |
| 623 | + b >> start_bit |
| 624 | + } else { |
| 625 | + // We want bits from start_bit to end_bit |
| 626 | + // After right shift, we need to remove the high bits |
| 627 | + (b >> start_bit) << (start_bit + bits_per_block::<T>() - end_bit) |
| 628 | + } |
| 629 | + .count_ones(); |
| 630 | + return usize::try_from(count).unwrap(); |
| 631 | + } |
| 632 | + |
| 633 | + // First word: shift out bits before start_bit |
| 634 | + let start_bit = range.start % bits_per_block::<T>(); |
| 635 | + let mut count = usize::try_from((self.vec[start_off] >> start_bit).count_ones()).unwrap(); |
| 636 | + |
| 637 | + // Middle words |
| 638 | + for word_idx in (start_off + 1)..end_off { |
| 639 | + count += usize::try_from(self.vec[word_idx].count_ones()).unwrap(); |
| 640 | + } |
| 641 | + |
| 642 | + // Last word: shift out bits after end_bit |
| 643 | + let end_bit = range.end % bits_per_block::<T>(); |
| 644 | + let count_ones = if end_bit == 0 { |
| 645 | + // end_bit = 0 means we want to count the entire end_off word |
| 646 | + self.vec[end_off].count_ones() |
| 647 | + } else { |
| 648 | + (self.vec[end_off] << (bits_per_block::<T>() - end_bit)).count_ones() |
| 649 | + }; |
| 650 | + count + usize::try_from(count_ones).unwrap() |
| 651 | + } |
| 652 | + |
588 | 653 | /// Returns an iterator which efficiently produces the index of each unset bit in the specified |
589 | 654 | /// range. Assuming appropriate support from your CPU, this is much more efficient than |
590 | 655 | /// checking each bit individually. |
@@ -1082,6 +1147,10 @@ impl<T: Debug + PrimInt> Iterator for Iter<'_, T> { |
1082 | 1147 | fn size_hint(&self) -> (usize, Option<usize>) { |
1083 | 1148 | self.range.size_hint() |
1084 | 1149 | } |
| 1150 | + |
| 1151 | + fn count(self) -> usize { |
| 1152 | + self.range.count() |
| 1153 | + } |
1085 | 1154 | } |
1086 | 1155 |
|
1087 | 1156 | impl<T: Debug + PrimInt> DoubleEndedIterator for Iter<'_, T> { |
@@ -1148,6 +1217,10 @@ impl<T: Debug + PrimInt> Iterator for IterSetBits<'_, T> { |
1148 | 1217 | fn size_hint(&self) -> (usize, Option<usize>) { |
1149 | 1218 | self.range.size_hint() |
1150 | 1219 | } |
| 1220 | + |
| 1221 | + fn count(self) -> usize { |
| 1222 | + self.vob.count_set_bits(self.range) |
| 1223 | + } |
1151 | 1224 | } |
1152 | 1225 |
|
1153 | 1226 | impl<T: Debug + PrimInt> DoubleEndedIterator for IterSetBits<'_, T> { |
@@ -1228,6 +1301,12 @@ impl<T: Debug + PrimInt> Iterator for IterUnsetBits<'_, T> { |
1228 | 1301 | fn size_hint(&self) -> (usize, Option<usize>) { |
1229 | 1302 | self.range.size_hint() |
1230 | 1303 | } |
| 1304 | + |
| 1305 | + fn count(self) -> usize { |
| 1306 | + // This arithmetic is safe because (self.range.end - self.range.start) is the total number of bits, |
| 1307 | + // and self.vob.count_set_bits() always returns a value less than or equal to that. |
| 1308 | + (self.range.end - self.range.start) - self.vob.count_set_bits(self.range) |
| 1309 | + } |
1231 | 1310 | } |
1232 | 1311 |
|
1233 | 1312 | impl<T: Debug + PrimInt> DoubleEndedIterator for IterUnsetBits<'_, T> { |
@@ -1300,6 +1379,10 @@ impl<T: Debug + PrimInt> Iterator for StorageIter<'_, T> { |
1300 | 1379 | fn size_hint(&self) -> (usize, Option<usize>) { |
1301 | 1380 | self.iter.size_hint() |
1302 | 1381 | } |
| 1382 | + |
| 1383 | + fn count(self) -> usize { |
| 1384 | + self.iter.count() |
| 1385 | + } |
1303 | 1386 | } |
1304 | 1387 |
|
1305 | 1388 | #[inline(always)] |
@@ -1974,6 +2057,27 @@ mod tests { |
1974 | 2057 | for _ in 0..len { |
1975 | 2058 | vob.push(rng.random()); |
1976 | 2059 | } |
| 2060 | + // these tests can later be dialed down, as they noticeable slow down every random vob test. |
| 2061 | + assert_eq!( |
| 2062 | + vob.iter_set_bits(..).count(), |
| 2063 | + vob.iter_set_bits(..).filter(|_| true).count() |
| 2064 | + ); |
| 2065 | + assert_eq!( |
| 2066 | + vob.iter_unset_bits(..).count(), |
| 2067 | + vob.iter_unset_bits(..).filter(|_| true).count() |
| 2068 | + ); |
| 2069 | + if len > 2 { |
| 2070 | + // trigger the edge cases of count_set_bits() |
| 2071 | + let range = 1..len - 1; |
| 2072 | + assert_eq!( |
| 2073 | + vob.iter_set_bits(range.clone()).count(), |
| 2074 | + vob.iter_set_bits(range.clone()).filter(|_| true).count() |
| 2075 | + ); |
| 2076 | + assert_eq!( |
| 2077 | + vob.iter_unset_bits(range.clone()).count(), |
| 2078 | + vob.iter_unset_bits(range.clone()).filter(|_| true).count() |
| 2079 | + ); |
| 2080 | + } |
1977 | 2081 | vob |
1978 | 2082 | } |
1979 | 2083 |
|
@@ -2047,4 +2151,37 @@ mod tests { |
2047 | 2151 | v.push(true); |
2048 | 2152 | assert_eq!(v.vec.len(), 1); |
2049 | 2153 | } |
| 2154 | + |
| 2155 | + #[test] |
| 2156 | + fn test_count() { |
| 2157 | + let mut rng = rand::rng(); |
| 2158 | + |
| 2159 | + for test_len in 1..128 { |
| 2160 | + let vob = random_vob(test_len); |
| 2161 | + assert_eq!( |
| 2162 | + vob.iter_storage().count(), |
| 2163 | + vob.iter_storage().filter(|_| true).count() |
| 2164 | + ); |
| 2165 | + assert_eq!(vob.iter().count(), vob.iter().filter(|_| true).count()); |
| 2166 | + for i in 1..test_len - 1 { |
| 2167 | + let from = rng.random_range(0..i); |
| 2168 | + let to = rng.random_range(from..i); |
| 2169 | + assert_eq!( |
| 2170 | + vob.iter_set_bits(from..to).count(), |
| 2171 | + vob.iter_set_bits(from..to).filter(|_| true).count() |
| 2172 | + ); |
| 2173 | + assert_eq!( |
| 2174 | + vob.iter_unset_bits(from..to).count(), |
| 2175 | + vob.iter_unset_bits(from..to).filter(|_| true).count() |
| 2176 | + ); |
| 2177 | + } |
| 2178 | + } |
| 2179 | + } |
| 2180 | + |
| 2181 | + #[test] |
| 2182 | + fn test_collect_capacity() { |
| 2183 | + // a test to make sure that iter_set_bits().collect() does not always allocate .len() elements |
| 2184 | + let vec: Vec<usize> = Vob::from_elem(false, 100).iter_set_bits(..).collect(); |
| 2185 | + assert_eq!(vec.capacity(), 0); |
| 2186 | + } |
2050 | 2187 | } |
0 commit comments