diff --git a/Cargo.toml b/Cargo.toml index 2d00641..6982994 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,9 +111,14 @@ harness = false name = "product_zipper" harness = false +[[bench]] +name = "catamorphism" +harness = false + [[bench]] name = "sla" harness = false +required-features = ["viz"] [workspace] members = ["pathmap-derive"] diff --git a/benches/binary_keys.rs b/benches/binary_keys.rs index 2d8ded7..82ffe24 100644 --- a/benches/binary_keys.rs +++ b/benches/binary_keys.rs @@ -61,7 +61,7 @@ fn binary_get(bencher: Bencher, n: u64) { }); } -#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000])] +#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000, 100000])] fn binary_val_count_bench(bencher: Bencher, n: u64) { let keys = make_keys(n as usize, 1); @@ -77,6 +77,22 @@ fn binary_val_count_bench(bencher: Bencher, n: u64) { assert_eq!(sink, n as usize); } +#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000, 100000])] +fn binary_goat_val_count_bench(bencher: Bencher, n: u64) { + + let keys = make_keys(n as usize, 1); + + let mut map: PathMap = PathMap::new(); + for i in 0..n { map.set_val_at(&keys[i as usize], i); } + + //Benchmark the time taken to count the number of values in the map + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count() + }); + assert_eq!(sink, n as usize); +} + #[divan::bench(args = [50, 100, 200, 400, 800, 1600])] fn binary_drop_head(bencher: Bencher, n: u64) { diff --git a/benches/catamorphism.rs b/benches/catamorphism.rs new file mode 100644 index 0000000..11d2fb5 --- /dev/null +++ b/benches/catamorphism.rs @@ -0,0 +1,105 @@ +use divan::{Divan, Bencher, black_box}; +use pathmap::morphisms::{Catamorphism, Summarization}; +use pathmap::utils::ByteMask; +use pathmap::utils::ints::gen_int_range; +use pathmap::PathMap; + +fn main() { + // Run registered benchmarks. + let divan = Divan::from_args() + .sample_count(4000); + + divan.main(); +} + +fn build_map(count: u64) -> PathMap<()> { + // Dense range of u64 keys encoded as paths; sized to keep benches fast and stable. 
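+    // Assumption from the type/value parameters only: `gen_int_range::<(), 8, u64>(0, count, 1, ())`
+    // should yield one ()-valued entry per integer in `0..count` (stride 1), each keyed by an 8-byte
+    // encoding of the u64; the exact byte encoding is up to `pathmap::utils::ints`.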
+ gen_int_range::<(), 8, u64>(0, count, 1, ()) +} + +const MAP_COUNT: u64 = 20_000_000; + +#[divan::bench()] +fn recursive_cata_jumping_val_count(bencher: Bencher) { + let map = build_map(MAP_COUNT); + let mut sink = 0usize; + bencher.bench_local(|| { + let rz = map.read_zipper(); + *black_box(&mut sink) = rz.recursive_cata::<_, _, _, _, _, false>( + |v, w, _| (v.is_some() as usize) + w.unwrap_or(0), + |_mask, w: usize, total| { *total += w }, + |_mask, total: usize| { total }, + ); + }); + assert_eq!(sink, MAP_COUNT as usize); +} + +#[divan::bench()] +fn cached_jumping_cata_val_count(bencher: Bencher) { + let map = build_map(MAP_COUNT); + let mut sink = 0usize; + bencher.bench_local(|| { + let rz = map.read_zipper(); + *black_box(&mut sink) = rz.into_cata_jumping_cached(|_mask: &ByteMask, children: &mut [usize], val, _sub_path| { + let mut sum: usize = children.iter().sum(); + if val.is_some() { + sum += 1; + } + sum + }); + }); + assert_eq!(sink, MAP_COUNT as usize); +} + +#[divan::bench()] +fn recursive_cata_jumping_total_len(bencher: Bencher) { + let map = build_map(MAP_COUNT); + let mut sink = (0usize, 0usize); + bencher.bench_local(|| { + let rz = map.read_zipper(); + *black_box(&mut sink) = rz.recursive_cata::<_, _, _, _, _, true>( + |val, downstream, prefix| { + let (mut count, mut total_len) = downstream.unwrap_or((0, 0)); + total_len += count * prefix.len(); + if val.is_some() { + count += 1; + total_len += prefix.len(); + } + (count, total_len) + }, + |mask: &ByteMask, w: (usize, usize), acc: &mut (usize, usize, usize)| { + let byte = mask.indexed_bit::(acc.0).unwrap(); + let _ = byte; // byte value unused; only length matters + acc.0 += 1; + acc.1 += w.0; + acc.2 += w.1; + }, + |_mask: &ByteMask, acc: (usize, usize, usize)| { (acc.1, acc.2) }, + ); + }); + assert_eq!(sink.0, MAP_COUNT as usize); +} + +#[divan::bench()] +fn cached_jumping_cata_total_len(bencher: Bencher) { + let map = build_map(MAP_COUNT); + let mut sink = (0usize, 0usize); + bencher.bench_local(|| { + let rz = map.read_zipper(); + *black_box(&mut sink) = rz.into_cata_jumping_cached(|mask: &ByteMask, children: &mut [(usize, usize)], val, sub_path| { + let mut count = 0usize; + let mut total_len = 0usize; + let prefix_len = sub_path.len(); + if val.is_some() { + count += 1; + total_len += prefix_len; + } + for (_byte, child) in mask.iter().zip(children.iter_mut()) { + count += child.0; + total_len += child.1 + child.0 * prefix_len; + } + (count, total_len) + }); + }); + assert_eq!(sink.0, MAP_COUNT as usize); +} diff --git a/benches/cities.rs b/benches/cities.rs index 231cb6e..cc5cbc9 100644 --- a/benches/cities.rs +++ b/benches/cities.rs @@ -168,6 +168,25 @@ fn cities_val_count(bencher: Bencher) { assert_eq!(sink, unique_count); } +#[divan::bench()] +fn cities_goat_val_count(bencher: Bencher) { + + let pairs = read_data(); + let mut map = PathMap::new(); + let mut unique_count = 0; + for (k, v) in pairs.iter() { + if map.set_val_at(k, *v).is_none() { + unique_count += 1; + } + } + + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count(); + }); + assert_eq!(sink, unique_count); +} + #[cfg(feature="arena_compact")] #[divan::bench()] fn cities_val_count_act(bencher: Bencher) { diff --git a/benches/shakespeare.rs b/benches/shakespeare.rs index 2040ba5..5528f73 100644 --- a/benches/shakespeare.rs +++ b/benches/shakespeare.rs @@ -113,6 +113,25 @@ fn shakespeare_words_val_count(bencher: Bencher) { assert_eq!(sink, unique_count); } +#[divan::bench()] +fn 
shakespeare_words_goat_val_count(bencher: Bencher) { + + let strings = read_data(true); + let mut map = PathMap::new(); + let mut unique_count = 0; + for (v, k) in strings.iter().enumerate() { + if map.set_val_at(k, v).is_none() { + unique_count += 1; + } + } + + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count(); + }); + assert_eq!(sink, unique_count); +} + #[divan::bench()] fn shakespeare_sentences_insert(bencher: Bencher) { @@ -168,6 +187,25 @@ fn shakespeare_sentences_val_count(bencher: Bencher) { assert_eq!(sink, unique_count); } +#[divan::bench()] +fn shakespeare_sentences_goat_val_count(bencher: Bencher) { + + let strings = read_data(false); + let mut map = PathMap::new(); + let mut unique_count = 0; + for (v, k) in strings.iter().enumerate() { + if map.set_val_at(k, v).is_none() { + unique_count += 1; + } + } + + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count(); + }); + assert_eq!(sink, unique_count); +} + #[cfg(feature="arena_compact")] #[divan::bench()] fn shakespeare_sentences_val_count_act(bencher: Bencher) { diff --git a/benches/sla.rs b/benches/sla.rs index 5d84dbe..92d89f0 100644 --- a/benches/sla.rs +++ b/benches/sla.rs @@ -388,7 +388,7 @@ fn tipover_attention_weave() { // let res = rtq.vF_mut().merkleize(); // println!("{:?}", res.hash); let t0 = Instant::now(); - println!("{:?} {:?}", rtq.vF().read_zipper().into_cata_cached(morphisms::alg::hash), t0.elapsed().as_micros()); + // println!("{:?} {:?}", rtq.vF().read_zipper().into_cata_cached(morphisms::alg::hash), t0.elapsed().as_micros()); return; // rtk.vF_mut().merkleize(); diff --git a/benches/sparse_keys.rs b/benches/sparse_keys.rs index 8489314..42c4d42 100644 --- a/benches/sparse_keys.rs +++ b/benches/sparse_keys.rs @@ -92,6 +92,26 @@ fn sparse_val_count_bench(bencher: Bencher, n: u64) { assert_eq!(sink, n as usize); } +#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000])] +fn sparse_goat_val_count_bench(bencher: Bencher, n: u64) { + + let mut r = StdRng::seed_from_u64(1); + let keys: Vec> = (0..n).into_iter().map(|_| { + let len = (r.random::() % 18) + 3; //length between 3 and 20 chars + (0..len).into_iter().map(|_| r.random::()).collect() + }).collect(); + + let mut map: PathMap = PathMap::new(); + for i in 0..n { map.set_val_at(&keys[i as usize], i); } + + //Benchmark the time taken to count the number of values in the map + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count() + }); + assert_eq!(sink, n as usize); +} + #[divan::bench(args = [50, 100, 200, 400, 800, 1600])] fn binary_drop_head(bencher: Bencher, n: u64) { diff --git a/benches/superdense_keys.rs b/benches/superdense_keys.rs index 597954d..34a09ce 100644 --- a/benches/superdense_keys.rs +++ b/benches/superdense_keys.rs @@ -253,6 +253,21 @@ fn superdense_val_count_bench(bencher: Bencher, n: u64) { assert_eq!(sink, n as usize); } +#[divan::bench(sample_size = 1, args = [100, 200, 400, 800, 1600, 3200, 20_000])] +fn superdense_goat_val_count_bench(bencher: Bencher, n: u64) { + + let mut map: PathMap = PathMap::new(); + for i in 0..n { map.set_val_at(prefix_key(&i), i); } + + //Benchmark the time taken to count the number of values in the map + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count() + }); + assert_eq!(sink, n as usize); +} + + #[cfg(feature="arena_compact")] #[divan::bench(sample_size = 1, args = [100, 200, 400, 800, 1600, 3200, 20_000])] fn 
superdense_val_count_bench_act(bencher: Bencher, n: u64) { diff --git a/src/dense_byte_node.rs b/src/dense_byte_node.rs index dd846a2..5c4144c 100644 --- a/src/dense_byte_node.rs +++ b/src/dense_byte_node.rs @@ -1,7 +1,6 @@ use core::fmt::{Debug, Formatter}; use core::ptr; -use std::collections::HashMap; use std::hint::unreachable_unchecked; use crate::alloc::Allocator; @@ -12,6 +11,8 @@ use crate::utils::BitMask; use crate::trie_node::*; use crate::line_list_node::LineListNode; +use crate::gxhash::HashMap; + //NOTE: This: `core::array::from_fn(|i| i as u8);` ought to work, but https://github.com/rust-lang/rust/issues/109341 const ALL_BYTES: [u8; 256] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]; @@ -29,7 +30,7 @@ pub struct ByteNode { #[cfg(feature = "nightly")] values: Vec, #[cfg(not(feature = "nightly"))] - values: Vec, + pub(crate) values: Vec, alloc: A, } @@ -310,6 +311,10 @@ impl> ByteNode /// Iterates the entries in `self`, calling `func` for each entry /// The arguments to `func` are: `func(self, key_byte, n)`, where `n` is the number of times /// `func` has been called prior. This corresponds to index of the `CoFree` in the `values` vec + /// + /// PERF GOAT: this method is useful if you know you need the corresponding path byte. If, however, the + /// byte might be unneeded, it's better to let self.values govern the loop, because using the mask + /// in the loop means the bitwise ops to manipulate the mask are always required. #[inline] fn for_each_item(&self, mut func: F) { let mut n = 0; @@ -326,6 +331,63 @@ impl> ByteNode } } } + + #[inline(always)] + pub fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF, cache: &mut HashMap) -> W + where + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + let mut mask_idx = 0; + let mut lm = unsafe{ *self.mask.0.get_unchecked(0) }; + let mut ws = Some(Acc::default()); + for cf in self.values.iter() { + //Compute the key byte. Hopefully this will all be stripped away by the compiler if the path isn't used + //UPDATE: alas, my hopes were dashed. No amount of reorganizing this code, eliminating all traps, + // unrolling the loop, etc., could convince LLVM to elide it. So we have to hit it with the const hammer. 
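+        //With `COMPUTE_PATH = false` the `if COMPUTE_PATH` arm below is statically dead, so the mask
+        //scanning and bit math drop out and `path` is always the empty slice. Callers that never look
+        //at path bytes (e.g. `goat_val_count` in trie_map.rs) pass `false` for exactly this reason.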
+ let key_byte; + let path = if COMPUTE_PATH { + while lm == 0 { + mask_idx += 1; + lm = unsafe{ *self.mask.0.get_unchecked(mask_idx) }; + } + let byte_index = lm.trailing_zeros(); + lm ^= 1u64 << byte_index; + key_byte = 64*(mask_idx as u8) + (byte_index as u8); + core::slice::from_ref(&key_byte) + } else { + &[] + }; + + //Do the recursive calling + //PERF NOTE: The reason we have four code paths around the call to `branch_f` instead of just doing + // `let w = cf.rec().map(|rec| recursive_cata(...))` and then calling branch_f with the appropriate + // pair of `Option<&V>` and `Option` is that the compiler won't optimize the implemntation of + // `branch_f` around whether the options are none or not. That means we often pay for two dependent + // branches instead of one, and the difference was 25% to the val_count benchmark. Breaking out the calls + // like this gives the optimizer a site to specialize for each permutation + match (cf.rec(), cf.val()) { + (Some(rec), Some(val)) => { + let w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f, cache); + branch_f(&self.mask, collapse_f(Some(val), Some(w), path), unsafe { ws.as_mut().unwrap_unchecked() }); + }, + (Some(rec), None) => { + let w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f, cache); + branch_f(&self.mask, collapse_f(None, Some(w), path), unsafe { ws.as_mut().unwrap_unchecked() }); + }, + (None, Some(val)) => { + branch_f(&self.mask, collapse_f(Some(val), None, path), unsafe { ws.as_mut().unwrap_unchecked() }); + }, + (None, None) => { + branch_f(&self.mask, collapse_f(None, None, path), unsafe { ws.as_mut().unwrap_unchecked() }); + }, + } + } + finalize_f(&self.mask, unsafe { std::mem::take(&mut ws).unwrap_unchecked() }) + } } impl> ByteNode where Self: TrieNodeDowncast { @@ -967,7 +1029,7 @@ impl> TrieNode } } } - fn node_val_count(&self, cache: &mut HashMap) -> usize { + fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize { //Discussion: These two implementations do the same thing but with a slightly different ordering of // the operations. In `all_dense_nodes`, the "Branchy" impl wins. But in a mixed-node setting, the // IMPL B is the winner. 
My suspicion is that the ListNode's heavily branching structure leads to @@ -991,10 +1053,18 @@ impl> TrieNode t + cf.has_val() as usize + cf.rec().map(|r| val_count_below_node(r, cache)).unwrap_or(0) }); } - fn node_goat_val_count(&self) -> usize { +/* fn node_goat_val_count(&self) -> usize { return self.values.iter().rfold(0, |t, cf| { - t + cf.has_val() as usize + t + cf.has_val() as usize + cf.rec().map(|r| r.as_tagged().node_goat_val_count()).unwrap_or(0) }); + }*/ + #[inline] + fn node_goat_val_count(&self) -> usize { + let mut result = 0; + for cf in self.values.iter() { + result += cf.has_val() as usize + } + result } fn node_child_iter_start(&self) -> (u64, Option<&TrieNodeODRc>) { for (pos, cf) in self.values.iter().enumerate() { diff --git a/src/line_list_node.rs b/src/line_list_node.rs index 040f714..1e7844e 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -1,6 +1,6 @@ use core::hint::unreachable_unchecked; use core::mem::{ManuallyDrop, MaybeUninit}; -use std::collections::HashMap; +use crate::gxhash::HashMap; use fast_slice_utils::{find_prefix_overlap, starts_with}; use local_or_heap::LocalOrHeap; @@ -403,7 +403,7 @@ impl LineListNode { } } #[inline] - unsafe fn child_in_slot(&self) -> &TrieNodeODRc { + pub(crate) unsafe fn child_in_slot(&self) -> &TrieNodeODRc { match SLOT { 0 => unsafe{ &*self.val_or_child0.child }, 1 => unsafe{ &*self.val_or_child1.child }, @@ -419,7 +419,7 @@ impl LineListNode { } } #[inline] - unsafe fn val_in_slot(&self) -> &V { + pub(crate) unsafe fn val_in_slot(&self) -> &V { match SLOT { 0 => unsafe{ &**self.val_or_child0.val }, 1 => unsafe{ &**self.val_or_child1.val }, @@ -1968,7 +1968,7 @@ impl TrieNode for LineListNode } } #[inline] - fn node_val_count(&self, cache: &mut HashMap) -> usize { + fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize { let mut result = 0; if self.is_used_value_0() { result += 1; @@ -1986,6 +1986,25 @@ impl TrieNode for LineListNode } result } +/* #[inline] + fn node_goat_val_count(&self) -> usize { + let mut result = 0; + if self.is_used_value_0() { + result += 1; + } + if self.is_used_value_1() { + result += 1; + } + if self.is_used_child_0() { + let child_node = unsafe{ self.child_in_slot::<0>() }; + result += child_node.as_tagged().node_goat_val_count(); + } + if self.is_used_child_1() { + let child_node = unsafe{ self.child_in_slot::<1>() }; + result += child_node.as_tagged().node_goat_val_count(); + } + result + }*/ #[inline] fn node_goat_val_count(&self) -> usize { //Here are 3 alternative implementations. 
They're basically the same in perf, with a slight edge to the @@ -2722,6 +2741,237 @@ impl LineListNode { } } } + + #[inline(always)] + pub fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF, cache: &mut HashMap) -> W + where + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + //Pair node can have the following permutations: (Slot0, Slot1) + // + // - Case 1 (Empty, Empty) + // Run only `finalize_f` on default `Acc` + // - Case 2 (Child, Empty) + // Recursively call on child, then run only `collapse_f` on the result, specifying the path + // - Case 3 (Child, Val), 1-byte key, same key byte + // Recursively call on child, then run only `collapse_f` on the result, specifying the value and the 1-byte path + // - Case 4 (Child, Val), different key bytes + // Recursively call on child, run `branch_f(collapse_f())` on the result. Then run `branch_f(collapse_f())` again + // with the value's path. And finally, run `finalize_f()` + // - Case 5 (Child, Child) + // Recursively call on child0, run `branch_f(collapse_f())` on the result. Do the same for child1. Finally, run + // `finalize_f()` + // - Case 6 (Val, Empty) + // Run only `collapse_f` on the val + // - Case 7 (Val, Val), common first byte (meaning slot0 is a 1-byte path) + // run `collapse_f` on the slot1 val, specifying the path, then run `collapse_f` again on the slot0 val, specifying + // the common prefix byte + // - Case 8 (Val, Val), different first bytes + // Run `branch_f(collapse_f())` on each val, then `finalize_f` at the end + // - Case 9 (Val, Child), 1-byte key, same key byte (We could eliminate this case by requiring a canonical ordering for identical one-byte keys, but currently we don't) + // See "Case 3 (Child, Val), 1-byte key, same key byte" + // - Case 10 (Val, Child), different key bytes + // See "Case 4 (Child, Val), different key bytes" + // + //GOAT, It would be nice to refactor the pair_node in order to express each of these permutations as a unique value for the 4 header bits, so we could take the appropriate code path without looking at any path bytes + + match self.header >> 12 { + //Case 1 (Empty, Empty) + 0 => finalize_f(&ByteMask::new(), Acc::default()), + //Case 2 (Child, Empty) = (1 << 3) + (1 << 1) | (1 << 3) + (1 << 1) + 1 + 10 | 11 => { + let child_node = unsafe{ self.child_in_slot::<0>() }; + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); + let path = if COMPUTE_PATH { + unsafe{ self.key_unchecked::<0>() } + } else { + &[] + }; + collapse_f(None, Some(child_w), path) + }, + //(Child, Val) = (1 << 3) + (1 << 2) + (1 << 1) + 14 => { + let child_node = unsafe{ self.child_in_slot::<0>() }; + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); + let key0 = unsafe{ self.key_unchecked::<0>() }; + let key1 = unsafe{ self.key_unchecked::<1>() }; + let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; + if key0_byte == key1_byte { + //Case 3 + debug_assert_eq!(key0.len(), 1); + debug_assert_eq!(key1.len(), 1); + let val = unsafe { self.val_in_slot::<1>() }; + let path = if COMPUTE_PATH { + key0 + } else { + &[] + }; + collapse_f(Some(val), Some(child_w), path) + } else { + //Case 4 + let mut acc = Acc::default(); + let (path, mask) = if COMPUTE_PATH { + 
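+                    //`ByteMask::from((key0_byte, key1_byte))` is the two-byte constructor this change adds
+                    //in utils/mod.rs: it sets exactly the two bits for the diverging key bytes, so the mask
+                    //passed to `branch_f`/`finalize_f` describes this two-way branch.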
(&key0[1..], ByteMask::from((key0_byte, key1_byte))) + } else { + (&[] as &[u8], ByteMask::new()) + }; + branch_f(&mask, collapse_f(None, Some(child_w), path), &mut acc); + + let val = unsafe { self.val_in_slot::<1>() }; + let path = if COMPUTE_PATH { + &key1[1..] + } else { + &[] + }; + branch_f(&mask, collapse_f(Some(val), None, path), &mut acc); + + finalize_f(&mask, acc) + } + }, + //Case 5 (Child, Child) = (1 << 3) + (1 << 2) + (1 << 1) + 1 + 15 => { + let mut acc = Acc::default(); + let child_node = unsafe{ self.child_in_slot::<0>() }; + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); + let (path0, path1, mask) = if COMPUTE_PATH { + let key0 = unsafe{ self.key_unchecked::<0>() }; + let key1 = unsafe{ self.key_unchecked::<1>() }; + let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; + (&key0[1..], &key1[1..], ByteMask::from((key0_byte, key1_byte))) + } else { + (&[] as &[u8], &[] as &[u8], ByteMask::new()) + }; + branch_f(&mask, collapse_f(None, Some(child_w), path0), &mut acc); + + let child_node = unsafe{ self.child_in_slot::<1>() }; + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); + branch_f(&mask, collapse_f(None, Some(child_w), path1), &mut acc); + + finalize_f(&mask, acc) + }, + //Case 6 (Val, Empty) = (1 << 3) | (1 << 3) + 1 + 8 | 9 => { + let val = unsafe { self.val_in_slot::<0>() }; + let path = if COMPUTE_PATH { + unsafe{ self.key_unchecked::<0>() } + } else { + &[] + }; + collapse_f(Some(val), None, path) + }, + //(Val, Val) = (1 << 3) + (1 << 2) + 12 => { + let key0 = unsafe{ self.key_unchecked::<0>() }; + let key1 = unsafe{ self.key_unchecked::<1>() }; + let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; + if key0_byte == key1_byte { + //Case 7 (Val, Val), common first byte (meaning slot0 is a 1-byte path) + debug_assert_eq!(key0.len(), 1); + debug_assert!(key1.len() > 1); + let val = unsafe { self.val_in_slot::<1>() }; + let path = if COMPUTE_PATH { + &key1[1..] 
+ } else { + &[] + }; + let w1 = collapse_f(Some(val), None, path); + let val = unsafe { self.val_in_slot::<0>() }; + let path = if COMPUTE_PATH { + &key1[0..1] + } else { + &[] + }; + collapse_f(Some(val), Some(w1), path) + } else { + //Case 8 (Val, Val), different first bytes + let mut acc = Acc::default(); + let val = unsafe{ self.val_in_slot::<0>() }; + let (path0, path1, mask) = if COMPUTE_PATH { + (&key0[1..], &key1[1..], ByteMask::from((key0_byte, key1_byte))) + } else { + (&[] as &[u8], &[] as &[u8], ByteMask::new()) + }; + branch_f(&mask, collapse_f(Some(val), None, path0), &mut acc); + + let val = unsafe{ self.val_in_slot::<1>() }; + branch_f(&mask, collapse_f(Some(val), None, path1), &mut acc); + + finalize_f(&mask, acc) + } + }, + //(Val, Child) = (1 << 3) + (1 << 2) + 1 + 13 => { + let child_node = unsafe{ self.child_in_slot::<1>() }; + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); + let key0 = unsafe{ self.key_unchecked::<0>() }; + let key1 = unsafe{ self.key_unchecked::<1>() }; + let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; + if key0_byte == key1_byte { + //Case 9 (Val, Child), 1-byte key, same key byte (We could eliminate this case by requiring a canonical ordering for identical one-byte keys, but currently we don't) + debug_assert_eq!(key0.len(), 1); + debug_assert_eq!(key1.len(), 1); + let val = unsafe { self.val_in_slot::<0>() }; + let path = if COMPUTE_PATH { + key0 + } else { + &[] + }; + collapse_f(Some(val), Some(child_w), path) + } else { + //Case 10 (Val, Child), different key bytes + let mut acc = Acc::default(); + + let (path, mask) = if COMPUTE_PATH { + (&key1[1..], ByteMask::from((key0_byte, key1_byte))) + } else { + (&[] as &[u8], ByteMask::new()) + }; + branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), path), &mut acc); + + let val = unsafe { self.val_in_slot::<0>() }; + let path = if COMPUTE_PATH { + &key0[1..] 
+ } else { + &[] + }; + branch_f(&mask, collapse_f(Some(val), None, path), &mut acc); + + finalize_f(&mask, acc) + } + }, + _ => { unsafe { unreachable_unchecked() } } + } + + // let mut ws = Some(Acc::default()); + + // if self.is_used_value_0() { + // let downstream = collapse_f(Some(unsafe { self.val_in_slot::<0>() }), None, &[]); + // branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + // } + // if self.is_used_value_1() { + // let downstream = collapse_f(Some(unsafe { self.val_in_slot::<1>() }), None, &[]); + // branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + // } + // if self.is_used_child_0() { + // let child_node = unsafe{ self.child_in_slot::<0>() }; + // let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + // let downstream = collapse_f(None, Some(w), &[]); + // branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + + // } + // if self.is_used_child_1() { + // let child_node = unsafe{ self.child_in_slot::<1>() }; + // let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + // let downstream = collapse_f(None, Some(w), &[]); + // branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + // } + + // finalize_f(&ByteMask::new(), unsafe { std::mem::take(&mut ws).unwrap_unchecked() }) + } } impl TrieNodeDowncast for LineListNode { @@ -3337,4 +3587,4 @@ mod tests { // and ZipperMoving // 2. implement a val_count convenience on top of 1. -//GOAT, Paths in caching Cata: https://github.com/Adam-Vandervorst/PathMap/pull/8#discussion_r2004828957 \ No newline at end of file +//GOAT, Paths in caching Cata: https://github.com/Adam-Vandervorst/PathMap/pull/8#discussion_r2004828957 diff --git a/src/morphisms.rs b/src/morphisms.rs index cc5b1f0..9fe0c88 100644 --- a/src/morphisms.rs +++ b/src/morphisms.rs @@ -72,8 +72,10 @@ use crate::utils::*; use crate::alloc::Allocator; use crate::PathMap; use crate::trie_node::TrieNodeODRc; +use crate::trie_node::recursive_cata_cached; use crate::zipper; use crate::zipper::*; +use crate::zipper::zipper_priv::ZipperPriv; use crate::gxhash::{HashMap, HashMapExt}; @@ -213,8 +215,8 @@ pub trait Catamorphism { /// See [into_cata_cached](Catamorphism::into_cata_cached) for explanation of other arguments and behavior fn into_cata_jumping_cached(self, alg_f: AlgF) -> W where - W: Clone, - AlgF: Fn(&ByteMask, &mut [W], Option<&V>, &[u8]) -> W, + W: Clone, + AlgF: Fn(&ByteMask, &mut [W], Option<&V>, &[u8]) -> W, Self: Sized { self.into_cata_jumping_cached_fallible(|mask, children, val, sub_path| -> Result { @@ -231,6 +233,55 @@ pub trait Catamorphism { AlgF: Fn(&ByteMask, &mut [W], Option<&V>, &[u8]) -> Result; } +/// Provides faster catamorphism methods for types backed by an in-memory trie, such as [`PathMap`] +/// and some zipper implementations +pub trait Summarization { + + /// GOAT recursive cached cata. 
If this dev branch is successful this should replace the caching cata flavors in the public API + /// This is the JUMPING cata + /// + /// Closures: + /// + /// `CollapseF`: Folds a possible value and a possible downstream continuation, prefixed by a linear sub-path into a single `W` + /// `fn(val: Option<&V>, downstream: Option, prefix: &[u8]) -> W` + /// + /// `BranchF`: Accumulates the `W` representing a downstream branch into an `Acc` accumulator type + /// `fn(branch_mask: &ByteMask, downstream: W, accumulator: &mut Acc)` + /// + /// `FinalizeF`: Converts an `Acc` accumulator into a `W` representing the logical node + /// `fn(branch_mask: &ByteMask, accumulator: Acc) -> W` + /// + /// GOAT: The `COMPUTE_PATH` parameter shouldn't be necessary in a perfect world, but unfortunately the compiler + /// isn't very good at getting rid of the dead code, so passing `COMPUTE_PATH=false` gives a considerable speedup + /// at the expense of providing paths and reliable child_masks to the closures. + fn recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + Self: Sized; + + /// A stepping (non-jumping) catamorphism for the trie. + /// + /// Use this when you need the cata to evaluate once per path byte, even across non-branching sub-paths. + /// Unlike the jumping version, `branch_f` and `finalize_f` will be called for every path byte. + /// + /// See [`Catamorphism::recursive_cata`] for closure semantics; this stepping variant omits the prefix argument from `collapse_f`. + fn recursive_cata_stepping(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + Self: Sized; + +} + //TODO GOAT!!: It would be nice to get rid of this Default bound on all morphism Ws. In this case, the plan // for doing that would be to create a new type called a TakableSlice. It would be able to deref // into a regular mutable slice of `T` so it would work just like an ordinary slice. 
Additionally @@ -420,6 +471,99 @@ impl Catamorph } } +impl<'a, Z, V> Summarization for Z where Z: Zipper + ZipperReadOnlyConditionalValues<'a, V> + ZipperConcrete + ZipperAbsolutePath + ZipperPathBuffer + ZipperPriv { + fn recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + let focus = self.get_focus(); + let w = match focus.borrow() { + Some(node) => { + let mut cache = HashMap::new(); + recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, &mut cache) + }, + None => finalize_f(&ByteMask::EMPTY, Acc::default()), + }; + collapse_f(self.val(), Some(w), &[]) + } + + fn recursive_cata_stepping(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + self.recursive_cata::<_, _, _, _, _, true>( + |val, downstream, prefix| { + let mut w = collapse_f(val, downstream); + for byte in prefix.iter().rev() { + let mask = ByteMask::from(*byte); + let mut acc = Acc::default(); + branch_f(&mask, w, &mut acc); + w = finalize_f(&mask, acc); + } + w + }, + branch_f, + finalize_f, + ) + } +} + +impl Summarization for PathMap { + fn recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + let w = match self.root() { + Some(node) => { + let mut cache = HashMap::new(); + recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, &mut cache) + }, + None => finalize_f(&ByteMask::EMPTY, Acc::default()), + }; + collapse_f(self.root_val(), Some(w), &[]) + } + + fn recursive_cata_stepping(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + self.recursive_cata::<_, _, _, _, _, true>( + |val, downstream, prefix| { + let mut w = collapse_f(val, downstream); + for byte in prefix.iter().rev() { + let mask = ByteMask::from(*byte); + let mut acc = Acc::default(); + branch_f(&mask, w, &mut acc); + w = finalize_f(&mask, acc); + } + w + }, + branch_f, + finalize_f, + ) + } +} + #[inline] fn cata_side_effect_body<'a, Z, V: 'a, W, Err, AlgF, const JUMPING: bool>(mut z: Z, mut alg_f: AlgF) -> Result where @@ -1956,6 +2100,121 @@ mod tests { eprintln!("calls_cached: {calls_cached}\ncalls_side: {calls_side}"); } + /// Adapted from morphisms::cata_test1 for recursive_cata (jumping). 
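+    /// The fold carries `W = (needs_branch_byte, sum)`: the jumping cata hands `collapse_f` the whole
+    /// non-branching prefix, so a value that sits directly on a branch byte arrives with an empty prefix,
+    /// and the flag tells the parent `branch_f` to add that branch byte's digit instead.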
+ #[test] + fn recursive_cata_jumping_sum_digits() { + let tests = [ + (vec![], 0), + (vec!["1"], 1), + (vec!["1", "2"], 3), + (vec!["1", "2", "3", "4", "5", "6"], 21), + (vec!["a1", "a2"], 3), + (vec!["a1", "a2", "a3", "a4", "a5", "a6"], 21), + (vec!["12345"], 5), + (vec!["1", "12", "123", "1234", "12345"], 15), + (vec!["123", "123456", "123789"], 18), + (vec!["12", "123", "123456", "123789"], 20), + (vec!["1", "2", "123", "123765", "1234", "12345", "12349"], 29), + ]; + + #[derive(Default)] + struct SumAcc { + idx: usize, + sum: u32, + } + + for (keys, expected_sum) in tests { + let map: PathMap<()> = keys.into_iter().map(|v| (v, ())).collect(); + let sum = map.recursive_cata::<_, _, _, _, _, true>( + |val, downstream, prefix| { + let mut sum = downstream.map(|w: (bool, u32)| w.1).unwrap_or(0); + if val.is_some() { + if let Some(byte) = prefix.last() { + sum += (*byte as char).to_digit(10).unwrap(); + } + } + (val.is_some() && prefix.is_empty(), sum) + }, + |mask: &ByteMask, w: (bool, u32), acc: &mut SumAcc| { + let byte = mask.indexed_bit::(acc.idx).unwrap(); + acc.idx += 1; + if w.0 { + acc.sum += (byte as char).to_digit(10).unwrap(); + } + acc.sum += w.1; + }, + |_mask: &ByteMask, acc: SumAcc| { (false, acc.sum) }, + ).1; + assert_eq!(sum, expected_sum); + } + } + + /// Adapted from morphisms::cata_test1 for recursive_cata_stepping (non-jumping). + #[test] + fn recursive_cata_stepping_sum_digits() { + let tests = [ + (vec![], 0), + (vec!["1"], 1), + (vec!["1", "2"], 3), + (vec!["1", "2", "3", "4", "5", "6"], 21), + (vec!["a1", "a2"], 3), + (vec!["a1", "a2", "a3", "a4", "a5", "a6"], 21), + (vec!["12345"], 5), + (vec!["1", "12", "123", "1234", "12345"], 15), + (vec!["123", "123456", "123789"], 18), + (vec!["12", "123", "123456", "123789"], 20), + (vec!["1", "2", "123", "123765", "1234", "12345", "12349"], 29), + ]; + + #[derive(Default)] + struct SumAcc { + idx: usize, + sum: u32, + } + + for (keys, expected_sum) in tests { + let map: PathMap<()> = keys.into_iter().map(|v| (v, ())).collect(); + let sum = map.recursive_cata_stepping::( + |val, downstream| { + let sum = downstream.map(|w| w.1).unwrap_or(0); + (val.is_some(), sum) + }, + |mask: &ByteMask, w: (bool, u32), acc: &mut SumAcc| { + let byte = mask.iter().nth(acc.idx).unwrap(); + acc.idx += 1; + if w.0 { + acc.sum += (byte as char).to_digit(10).unwrap(); + } + acc.sum += w.1; + }, + |_mask: &ByteMask, acc: SumAcc| { + (false, acc.sum) + }, + ).1; + assert_eq!(sum, expected_sum); + } + } + + /// Finds the path_depth at which the recursive cata hits a stack overflow + /// + /// Empirically seems to be somewhere between 8 and 10 KBytes. But more branching, and thus fewer + /// bytes-per-node, will mean it will fail on shorter paths. 
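+    /// As written this is a fixed-depth smoke test: `PATH_LEN` is pinned at 8_000 bytes, at the low end
+    /// of the 8-10 KB range noted above, rather than probing for the exact overflow threshold.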
+ #[test] + fn recursive_cata_stack_overflow_smoke() { + const PATH_LEN: usize = 8_000; + + let mut map = PathMap::<()>::new(); + let path = vec![b'a'; PATH_LEN]; + map.set_val_at(&path, ()); + + let count = map.recursive_cata::<_, _, _, _, _, false>( + |v, w, _| (v.is_some() as usize) + w.unwrap_or(0), + |_mask, w: usize, total| { *total += w }, + |_mask, total: usize| { total }, + ); + assert_eq!(count, 1); + } + /// Generate some basic tries using the [TrieBuilder::push_byte] API #[test] fn ana_test1() { diff --git a/src/tiny_node.rs b/src/tiny_node.rs index 08d61fc..1c93f1a 100644 --- a/src/tiny_node.rs +++ b/src/tiny_node.rs @@ -9,7 +9,7 @@ use core::mem::MaybeUninit; use core::fmt::{Debug, Formatter}; -use std::collections::HashMap; +use crate::gxhash::HashMap; use fast_slice_utils::{find_prefix_overlap, starts_with}; use crate::utils::ByteMask; @@ -121,6 +121,17 @@ impl<'a, V: Clone + Send + Sync, A: Allocator> TinyRefNode<'a, V, A> { fn key(&self) -> &[u8] { unsafe{ core::slice::from_raw_parts(self.key_bytes.as_ptr().cast(), self.key_len()) } } + + pub(crate) fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF, cache: &mut HashMap) -> W + where + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + self.into_full().unwrap().node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) + } } impl<'a, V: Clone + Send + Sync, A: Allocator> TrieNode for TinyRefNode<'a, V, A> { @@ -217,7 +228,7 @@ impl<'a, V: Clone + Send + Sync, A: Allocator> TrieNode for TinyRefNode<'a fn new_iter_token(&self) -> u128 { unreachable!() } fn iter_token_for_path(&self, _key: &[u8]) -> u128 { unreachable!() } fn next_items(&self, _token: u128) -> (u128, &'a[u8], Option<&TrieNodeODRc>, Option<&V>) { unreachable!() } - fn node_val_count(&self, cache: &mut HashMap) -> usize { + fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize { let temp_node = self.into_full().unwrap(); temp_node.node_val_count(cache) } diff --git a/src/trie_map.rs b/src/trie_map.rs index 3c5b0f3..7631149 100644 --- a/src/trie_map.rs +++ b/src/trie_map.rs @@ -1,7 +1,7 @@ use core::cell::UnsafeCell; use std::ptr::slice_from_raw_parts; use crate::alloc::{Allocator, GlobalAlloc, global_alloc}; -use crate::morphisms::{new_map_from_ana_in, Catamorphism, TrieBuilder}; +use crate::morphisms::{new_map_from_ana_in, Catamorphism, Summarization, TrieBuilder}; use crate::trie_node::*; use crate::zipper::*; use crate::merkleization::{MerkleizeResult, merkleize_impl}; @@ -510,10 +510,11 @@ impl PathMap { pub fn goat_val_count(&self) -> usize { let root_val = unsafe{ &*self.root_val.get() }.is_some() as usize; match self.root() { - Some(root) => { - traverse_physical(root, - |node, ctx: usize| { ctx + node.node_goat_val_count() }, - |ctx, child_ctx| { ctx + child_ctx }, + Some(_root) => { + self.recursive_cata::<_, _, _, _, _, false>( + |v, w, _| { (v.is_some() as usize) + w.unwrap_or(0) }, // on values amongst a path + |_mask, w: usize, total| { *total += w }, // on merging children into a node + |_mask, total: usize| { total } // finalizing a node ) + root_val }, None => root_val diff --git a/src/trie_node.rs b/src/trie_node.rs index 111bd5e..fb05bd2 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -2,7 +2,6 @@ use core::hint::unreachable_unchecked; use core::mem::ManuallyDrop; use core::ptr::NonNull; -use 
std::collections::HashMap; use dyn_clone::*; use local_or_heap::LocalOrHeap; use arrayvec::ArrayVec; @@ -14,6 +13,8 @@ use crate::ring::*; use crate::tiny_node::TinyRefNode; use crate::line_list_node::LineListNode; +use crate::gxhash::HashMap; + #[cfg(feature = "bridge_nodes")] use crate::bridge_node::BridgeNode; @@ -220,7 +221,7 @@ pub(crate) trait TrieNode: TrieNodeDowncas /// Returns the total number of leaves contained within the whole subtree defined by the node /// GOAT, this should be deprecated - fn node_val_count(&self, cache: &mut HashMap) -> usize; + fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize; /// Returns the number of values contained within the node itself, irrespective of the positions within /// the node; does not include onward links @@ -1222,7 +1223,7 @@ mod tagged_node_ref { } #[inline] - pub fn node_val_count(&self, cache: &mut HashMap) -> usize { + pub fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize { match self { Self::DenseByteNode(node) => node.node_val_count(cache), Self::LineListNode(node) => node.node_val_count(cache), @@ -2353,7 +2354,7 @@ pub(crate) fn val_count_below_root(node: T node.node_val_count(&mut cache) } -pub(crate) fn val_count_below_node(node: &TrieNodeODRc, cache: &mut HashMap) -> usize { +pub(crate) fn val_count_below_node(node: &TrieNodeODRc, cache: &mut std::collections::HashMap) -> usize { if node.is_empty() { return 0 } @@ -2372,64 +2373,62 @@ pub(crate) fn val_count_below_node(node: & } } -/// Recursively traverses a trie descending from `node`, visiting every physical non-empty node once -pub(crate) fn traverse_physical(node: &TrieNodeODRc, node_f: NodeF, fold_f: FoldF) -> Ctx - where - V: Clone + Send + Sync, - A: Allocator, - Ctx: Clone + Default, - NodeF: Fn(TaggedNodeRef, Ctx) -> Ctx + Copy, - FoldF: Fn(Ctx, Ctx) -> Ctx + Copy -{ - let mut cache = std::collections::HashMap::new(); - traverse_physical_internal(node, node_f, fold_f, &mut cache) -} - -fn traverse_physical_internal(node: &TrieNodeODRc, node_f: NodeF, fold_f: FoldF, cache: &mut HashMap) -> Ctx - where +/// Internal implementation of recursive_cata +pub(crate) fn recursive_cata_cached( + node: &TrieNodeODRc, + collapse_f: CollapseF, + branch_f: BranchF, + finalize_f: FinalizeF, + cache: &mut HashMap, +) -> W +where V: Clone + Send + Sync, A: Allocator, - Ctx: Clone + Default, - NodeF: Fn(TaggedNodeRef, Ctx) -> Ctx + Copy, - FoldF: Fn(Ctx, Ctx) -> Ctx + Copy + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, { - if node.is_empty() { - return Ctx::default() - } - - if node.refcount() > 1 { + if !node.is_empty() && node.refcount() > 1 { let hash = node.shared_node_id(); match cache.get(&hash) { Some(cached) => cached.clone(), None => { - let ctx = traverse_physical_children_internal(node.as_tagged(), node_f, fold_f, cache); - cache.insert(hash, ctx.clone()); - ctx + let w = recursive_cata_dispatch::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, cache); + cache.insert(hash, w.clone()); + w }, } } else { - traverse_physical_children_internal(node.as_tagged(), node_f, fold_f, cache) + recursive_cata_dispatch::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, cache) } } -fn traverse_physical_children_internal(node: TaggedNodeRef, node_f: NodeF, fold_f: FoldF, cache: &mut HashMap) -> Ctx - where +#[inline(always)] +fn recursive_cata_dispatch( + node: 
&TrieNodeODRc, + collapse_f: CollapseF, + branch_f: BranchF, + finalize_f: FinalizeF, + cache: &mut HashMap, +) -> W +where V: Clone + Send + Sync, A: Allocator, - Ctx: Clone + Default, - NodeF: Fn(TaggedNodeRef, Ctx) -> Ctx + Copy, - FoldF: Fn(Ctx, Ctx) -> Ctx + Copy + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, { - let mut ctx = Ctx::default(); - - let (mut tok, mut child) = node.node_child_iter_start(); - while let Some(child_node) = child { - let child_ctx = traverse_physical_internal(child_node, node_f, fold_f, cache); - ctx = fold_f(ctx, child_ctx); - (tok, child) = node.node_child_iter_next(tok); + match node.as_tagged() { + TaggedNodeRef::DenseByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } + TaggedNodeRef::LineListNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } + TaggedNodeRef::CellByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } + TaggedNodeRef::TinyRefNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } + TaggedNodeRef::EmptyNode => { finalize_f(&ByteMask::EMPTY, Acc::default()) } } - - node_f(node, ctx) } /// Internal function to walk a mut TrieNodeODRc ref along a path @@ -2483,6 +2482,7 @@ pub(crate) fn make_cell_node(node: &mut Tr // module come from the visibility of the trait it is derived on. In this case, `TrieNode` //Credit to QuineDot for his ideas on this pattern here: https://users.rust-lang.org/t/inferred-lifetime-for-dyn-trait/112116/7 pub(crate) use opaque_dyn_rc_trie_node::TrieNodeODRc; + #[cfg(not(feature = "slim_ptrs"))] mod opaque_dyn_rc_trie_node { use std::sync::Arc; @@ -2973,6 +2973,7 @@ mod opaque_dyn_rc_trie_node { pub(crate) fn new_empty() -> Self { Self { ptr: SlimNodePtr::new_empty(), alloc: MaybeUninit::uninit() } } + #[inline(always)] pub(crate) fn is_empty(&self) -> bool { self.tag() == EMPTY_NODE_TAG } @@ -3241,4 +3242,4 @@ mod tests { node_ref.make_unique(); drop(cloned); } -} \ No newline at end of file +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index a5c277c..4f74b45 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -255,6 +255,26 @@ impl From for ByteMask { } } +impl From<(u8, u8)> for ByteMask { + #[inline] + fn from(byte_pair: (u8, u8)) -> Self { + let mut new_mask = Self::new(); + new_mask.set_bit(byte_pair.0); + new_mask.set_bit(byte_pair.1); + new_mask + } +} + +impl From<[u8; 2]> for ByteMask { + #[inline] + fn from(byte_pair: [u8; 2]) -> Self { + let mut new_mask = Self::new(); + new_mask.set_bit(byte_pair[0]); + new_mask.set_bit(byte_pair[1]); + new_mask + } +} + impl From<[u64; 4]> for ByteMask { #[inline] fn from(mask: [u64; 4]) -> Self {