|
5 | 5 |
|
6 | 6 | use std::{ |
7 | 7 | cmp::{min, PartialEq}, |
| 8 | + convert::TryFrom, |
8 | 9 | fmt::{self, Debug}, |
9 | 10 | hash::{Hash, Hasher}, |
10 | 11 | iter::{DoubleEndedIterator, FromIterator, FusedIterator}, |
@@ -585,6 +586,58 @@ impl<T: Debug + PrimInt> Vob<T> { |
585 | 586 | } |
586 | 587 | } |
587 | 588 |
|
| 589 | + /// Counts the number of set bits. |
| 590 | + /// This method assumes the range is processed with process_range() |
| 591 | + fn count_set_bits(&self, range: Range<usize>) -> usize { |
| 592 | + // Early return for empty ranges |
| 593 | + if range.is_empty() { |
| 594 | + return 0; |
| 595 | + } |
| 596 | + let start_off = block_offset::<T>(range.start); |
| 597 | + |
| 598 | + // this -1 arithmetic is safe since we already tested for range.start & range.end equality |
| 599 | + let end_off = blocks_required::<T>(range.end) - 1; |
| 600 | + |
| 601 | + if start_off == end_off { |
| 602 | + // Range entirely within one word |
| 603 | + let b = self.vec[start_off]; |
| 604 | + let start_bit = range.start % bits_per_block::<T>(); |
| 605 | + let end_bit = range.end % bits_per_block::<T>(); |
| 606 | + |
| 607 | + // Remove bits before start_bit and bits after end_bit |
| 608 | + let count = if end_bit == 0 { |
| 609 | + // end_bit = 0 means we want everything from start_bit to end of word |
| 610 | + // After the right shift, we have what we want |
| 611 | + b >> start_bit |
| 612 | + } else { |
| 613 | + // We want bits from start_bit to end_bit |
| 614 | + // After right shift, we need to remove the high bits |
| 615 | + (b >> start_bit) << (start_bit + bits_per_block::<T>() - end_bit) |
| 616 | + } |
| 617 | + .count_ones(); |
| 618 | + return usize::try_from(count).unwrap(); |
| 619 | + } |
| 620 | + |
| 621 | + // First word: shift out bits before start_bit |
| 622 | + let start_bit = range.start % bits_per_block::<T>(); |
| 623 | + let mut count = usize::try_from((self.vec[start_off] >> start_bit).count_ones()).unwrap(); |
| 624 | + |
| 625 | + // Middle words |
| 626 | + for word_idx in (start_off + 1)..end_off { |
| 627 | + count += usize::try_from(self.vec[word_idx].count_ones()).unwrap(); |
| 628 | + } |
| 629 | + |
| 630 | + // Last word: shift out bits after end_bit |
| 631 | + let end_bit = range.end % bits_per_block::<T>(); |
| 632 | + let count_ones = if end_bit == 0 { |
| 633 | + // end_bit = 0 means we want to count the entire end_off word |
| 634 | + self.vec[end_off].count_ones() |
| 635 | + } else { |
| 636 | + (self.vec[end_off] << (bits_per_block::<T>() - end_bit)).count_ones() |
| 637 | + }; |
| 638 | + count + usize::try_from(count_ones).unwrap() |
| 639 | + } |
| 640 | + |
588 | 641 | /// Returns an iterator which efficiently produces the index of each unset bit in the specified |
589 | 642 | /// range. Assuming appropriate support from your CPU, this is much more efficient than |
590 | 643 | /// checking each bit individually. |
@@ -1082,6 +1135,10 @@ impl<T: Debug + PrimInt> Iterator for Iter<'_, T> { |
1082 | 1135 | fn size_hint(&self) -> (usize, Option<usize>) { |
1083 | 1136 | self.range.size_hint() |
1084 | 1137 | } |
| 1138 | + |
| 1139 | + fn count(self) -> usize { |
| 1140 | + self.range.count() |
| 1141 | + } |
1085 | 1142 | } |
1086 | 1143 |
|
1087 | 1144 | impl<T: Debug + PrimInt> DoubleEndedIterator for Iter<'_, T> { |
@@ -1148,6 +1205,10 @@ impl<T: Debug + PrimInt> Iterator for IterSetBits<'_, T> { |
1148 | 1205 | fn size_hint(&self) -> (usize, Option<usize>) { |
1149 | 1206 | self.range.size_hint() |
1150 | 1207 | } |
| 1208 | + |
| 1209 | + fn count(self) -> usize { |
| 1210 | + self.vob.count_set_bits(self.range) |
| 1211 | + } |
1151 | 1212 | } |
1152 | 1213 |
|
1153 | 1214 | impl<T: Debug + PrimInt> DoubleEndedIterator for IterSetBits<'_, T> { |
@@ -1228,6 +1289,12 @@ impl<T: Debug + PrimInt> Iterator for IterUnsetBits<'_, T> { |
1228 | 1289 | fn size_hint(&self) -> (usize, Option<usize>) { |
1229 | 1290 | self.range.size_hint() |
1230 | 1291 | } |
| 1292 | + |
| 1293 | + fn count(self) -> usize { |
| 1294 | + // This arithmetic is safe because (self.range.end - self.range.start) is the total number of bits, |
| 1295 | + // and self.vob.count_set_bits() always returns a value less than or equal to that. |
| 1296 | + (self.range.end - self.range.start) - self.vob.count_set_bits(self.range) |
| 1297 | + } |
1231 | 1298 | } |
1232 | 1299 |
|
1233 | 1300 | impl<T: Debug + PrimInt> DoubleEndedIterator for IterUnsetBits<'_, T> { |
@@ -1300,6 +1367,10 @@ impl<T: Debug + PrimInt> Iterator for StorageIter<'_, T> { |
1300 | 1367 | fn size_hint(&self) -> (usize, Option<usize>) { |
1301 | 1368 | self.iter.size_hint() |
1302 | 1369 | } |
| 1370 | + |
| 1371 | + fn count(self) -> usize { |
| 1372 | + self.iter.count() |
| 1373 | + } |
1303 | 1374 | } |
1304 | 1375 |
|
1305 | 1376 | #[inline(always)] |
@@ -1974,6 +2045,27 @@ mod tests { |
1974 | 2045 | for _ in 0..len { |
1975 | 2046 | vob.push(rng.random()); |
1976 | 2047 | } |
| 2048 | + // these tests can later be dialed down, as they noticeable slow down every random vob test. |
| 2049 | + assert_eq!( |
| 2050 | + vob.iter_set_bits(..).count(), |
| 2051 | + vob.iter_set_bits(..).filter(|_| true).count() |
| 2052 | + ); |
| 2053 | + assert_eq!( |
| 2054 | + vob.iter_unset_bits(..).count(), |
| 2055 | + vob.iter_unset_bits(..).filter(|_| true).count() |
| 2056 | + ); |
| 2057 | + if len > 2 { |
| 2058 | + // trigger the edge cases of count_set_bits() |
| 2059 | + let range = 1..len - 1; |
| 2060 | + assert_eq!( |
| 2061 | + vob.iter_set_bits(range.clone()).count(), |
| 2062 | + vob.iter_set_bits(range.clone()).filter(|_| true).count() |
| 2063 | + ); |
| 2064 | + assert_eq!( |
| 2065 | + vob.iter_unset_bits(range.clone()).count(), |
| 2066 | + vob.iter_unset_bits(range.clone()).filter(|_| true).count() |
| 2067 | + ); |
| 2068 | + } |
1977 | 2069 | vob |
1978 | 2070 | } |
1979 | 2071 |
|
@@ -2047,4 +2139,37 @@ mod tests { |
2047 | 2139 | v.push(true); |
2048 | 2140 | assert_eq!(v.vec.len(), 1); |
2049 | 2141 | } |
| 2142 | + |
| 2143 | + #[test] |
| 2144 | + fn test_count() { |
| 2145 | + let mut rng = rand::rng(); |
| 2146 | + |
| 2147 | + for test_len in 1..128 { |
| 2148 | + let vob = random_vob(test_len); |
| 2149 | + assert_eq!( |
| 2150 | + vob.iter_storage().count(), |
| 2151 | + vob.iter_storage().filter(|_| true).count() |
| 2152 | + ); |
| 2153 | + assert_eq!(vob.iter().count(), vob.iter().filter(|_| true).count()); |
| 2154 | + for i in 1..test_len - 1 { |
| 2155 | + let from = rng.random_range(0..i); |
| 2156 | + let to = rng.random_range(from..i); |
| 2157 | + assert_eq!( |
| 2158 | + vob.iter_set_bits(from..to).count(), |
| 2159 | + vob.iter_set_bits(from..to).filter(|_| true).count() |
| 2160 | + ); |
| 2161 | + assert_eq!( |
| 2162 | + vob.iter_unset_bits(from..to).count(), |
| 2163 | + vob.iter_unset_bits(from..to).filter(|_| true).count() |
| 2164 | + ); |
| 2165 | + } |
| 2166 | + } |
| 2167 | + } |
| 2168 | + |
| 2169 | + #[test] |
| 2170 | + fn test_collect_capacity() { |
| 2171 | + // a test to make sure that iter_set_bits().collect() does not always allocate .len() elements |
| 2172 | + let vec: Vec<usize> = Vob::from_elem(false, 100).iter_set_bits(..).collect(); |
| 2173 | + assert_eq!(vec.capacity(), 0); |
| 2174 | + } |
2050 | 2175 | } |
0 commit comments