From 981b041bcf3edab0a0234d003abaeaca782afbfb Mon Sep 17 00:00:00 2001 From: Adam Vandervorst Date: Fri, 9 Jan 2026 21:19:02 +0100 Subject: [PATCH 01/15] Add node-based catamorphism POC --- Cargo.toml | 1 + benches/binary_keys.rs | 16 ++++ benches/cities.rs | 19 +++++ benches/shakespeare.rs | 38 ++++++++++ benches/sla.rs | 2 +- benches/sparse_keys.rs | 20 +++++ benches/superdense_keys.rs | 15 ++++ src/dense_byte_node.rs | 14 +++- src/line_list_node.rs | 23 +++++- src/trie_map.rs | 22 +++++- src/trie_node.rs | 146 +++++++++++++++++++++++++++++++++++-- 11 files changed, 301 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2d006419..0cf1cecc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -114,6 +114,7 @@ harness = false [[bench]] name = "sla" harness = false +required-features = ["viz"] [workspace] members = ["pathmap-derive"] diff --git a/benches/binary_keys.rs b/benches/binary_keys.rs index 2d8ded79..8f3c1e0a 100644 --- a/benches/binary_keys.rs +++ b/benches/binary_keys.rs @@ -77,6 +77,22 @@ fn binary_val_count_bench(bencher: Bencher, n: u64) { assert_eq!(sink, n as usize); } +#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000])] +fn binary_goat_val_count_bench(bencher: Bencher, n: u64) { + + let keys = make_keys(n as usize, 1); + + let mut map: PathMap = PathMap::new(); + for i in 0..n { map.set_val_at(&keys[i as usize], i); } + + //Benchmark the time taken to count the number of values in the map + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count() + }); + assert_eq!(sink, n as usize); +} + #[divan::bench(args = [50, 100, 200, 400, 800, 1600])] fn binary_drop_head(bencher: Bencher, n: u64) { diff --git a/benches/cities.rs b/benches/cities.rs index 231cb6e0..cc5cbc94 100644 --- a/benches/cities.rs +++ b/benches/cities.rs @@ -168,6 +168,25 @@ fn cities_val_count(bencher: Bencher) { assert_eq!(sink, unique_count); } +#[divan::bench()] +fn cities_goat_val_count(bencher: Bencher) { + + let pairs = read_data(); + let mut map = PathMap::new(); + let mut unique_count = 0; + for (k, v) in pairs.iter() { + if map.set_val_at(k, *v).is_none() { + unique_count += 1; + } + } + + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count(); + }); + assert_eq!(sink, unique_count); +} + #[cfg(feature="arena_compact")] #[divan::bench()] fn cities_val_count_act(bencher: Bencher) { diff --git a/benches/shakespeare.rs b/benches/shakespeare.rs index 2040ba54..5528f739 100644 --- a/benches/shakespeare.rs +++ b/benches/shakespeare.rs @@ -113,6 +113,25 @@ fn shakespeare_words_val_count(bencher: Bencher) { assert_eq!(sink, unique_count); } +#[divan::bench()] +fn shakespeare_words_goat_val_count(bencher: Bencher) { + + let strings = read_data(true); + let mut map = PathMap::new(); + let mut unique_count = 0; + for (v, k) in strings.iter().enumerate() { + if map.set_val_at(k, v).is_none() { + unique_count += 1; + } + } + + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count(); + }); + assert_eq!(sink, unique_count); +} + #[divan::bench()] fn shakespeare_sentences_insert(bencher: Bencher) { @@ -168,6 +187,25 @@ fn shakespeare_sentences_val_count(bencher: Bencher) { assert_eq!(sink, unique_count); } +#[divan::bench()] +fn shakespeare_sentences_goat_val_count(bencher: Bencher) { + + let strings = read_data(false); + let mut map = PathMap::new(); + let mut unique_count = 0; + for (v, k) in strings.iter().enumerate() { + if map.set_val_at(k, v).is_none() { + unique_count += 1; + } + } + 
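+    //Setup mirrors the val_count benches above; only the call under measurement differs,
+    //so this pairs with shakespeare_sentences_val_count to compare val_count() against goat_val_count().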
+ let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count(); + }); + assert_eq!(sink, unique_count); +} + #[cfg(feature="arena_compact")] #[divan::bench()] fn shakespeare_sentences_val_count_act(bencher: Bencher) { diff --git a/benches/sla.rs b/benches/sla.rs index 5d84dbee..92d89f07 100644 --- a/benches/sla.rs +++ b/benches/sla.rs @@ -388,7 +388,7 @@ fn tipover_attention_weave() { // let res = rtq.vF_mut().merkleize(); // println!("{:?}", res.hash); let t0 = Instant::now(); - println!("{:?} {:?}", rtq.vF().read_zipper().into_cata_cached(morphisms::alg::hash), t0.elapsed().as_micros()); + // println!("{:?} {:?}", rtq.vF().read_zipper().into_cata_cached(morphisms::alg::hash), t0.elapsed().as_micros()); return; // rtk.vF_mut().merkleize(); diff --git a/benches/sparse_keys.rs b/benches/sparse_keys.rs index 84893144..42c4d422 100644 --- a/benches/sparse_keys.rs +++ b/benches/sparse_keys.rs @@ -92,6 +92,26 @@ fn sparse_val_count_bench(bencher: Bencher, n: u64) { assert_eq!(sink, n as usize); } +#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000])] +fn sparse_goat_val_count_bench(bencher: Bencher, n: u64) { + + let mut r = StdRng::seed_from_u64(1); + let keys: Vec> = (0..n).into_iter().map(|_| { + let len = (r.random::() % 18) + 3; //length between 3 and 20 chars + (0..len).into_iter().map(|_| r.random::()).collect() + }).collect(); + + let mut map: PathMap = PathMap::new(); + for i in 0..n { map.set_val_at(&keys[i as usize], i); } + + //Benchmark the time taken to count the number of values in the map + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count() + }); + assert_eq!(sink, n as usize); +} + #[divan::bench(args = [50, 100, 200, 400, 800, 1600])] fn binary_drop_head(bencher: Bencher, n: u64) { diff --git a/benches/superdense_keys.rs b/benches/superdense_keys.rs index 597954df..34a09ce4 100644 --- a/benches/superdense_keys.rs +++ b/benches/superdense_keys.rs @@ -253,6 +253,21 @@ fn superdense_val_count_bench(bencher: Bencher, n: u64) { assert_eq!(sink, n as usize); } +#[divan::bench(sample_size = 1, args = [100, 200, 400, 800, 1600, 3200, 20_000])] +fn superdense_goat_val_count_bench(bencher: Bencher, n: u64) { + + let mut map: PathMap = PathMap::new(); + for i in 0..n { map.set_val_at(prefix_key(&i), i); } + + //Benchmark the time taken to count the number of values in the map + let mut sink = 0; + bencher.bench_local(|| { + *black_box(&mut sink) = map.goat_val_count() + }); + assert_eq!(sink, n as usize); +} + + #[cfg(feature="arena_compact")] #[divan::bench(sample_size = 1, args = [100, 200, 400, 800, 1600, 3200, 20_000])] fn superdense_val_count_bench_act(bencher: Bencher, n: u64) { diff --git a/src/dense_byte_node.rs b/src/dense_byte_node.rs index dd846a29..defaccff 100644 --- a/src/dense_byte_node.rs +++ b/src/dense_byte_node.rs @@ -29,7 +29,7 @@ pub struct ByteNode { #[cfg(feature = "nightly")] values: Vec, #[cfg(not(feature = "nightly"))] - values: Vec, + pub(crate) values: Vec, alloc: A, } @@ -991,10 +991,18 @@ impl> TrieNode t + cf.has_val() as usize + cf.rec().map(|r| val_count_below_node(r, cache)).unwrap_or(0) }); } - fn node_goat_val_count(&self) -> usize { +/* fn node_goat_val_count(&self) -> usize { return self.values.iter().rfold(0, |t, cf| { - t + cf.has_val() as usize + t + cf.has_val() as usize + cf.rec().map(|r| r.as_tagged().node_goat_val_count()).unwrap_or(0) }); + }*/ + #[inline] + fn node_goat_val_count(&self) -> usize { + let mut result = 0; + for cf in self.values.iter() { + result += 
cf.has_val() as usize + } + result } fn node_child_iter_start(&self) -> (u64, Option<&TrieNodeODRc>) { for (pos, cf) in self.values.iter().enumerate() { diff --git a/src/line_list_node.rs b/src/line_list_node.rs index 040f7142..66120318 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -403,7 +403,7 @@ impl LineListNode { } } #[inline] - unsafe fn child_in_slot(&self) -> &TrieNodeODRc { + pub(crate) unsafe fn child_in_slot(&self) -> &TrieNodeODRc { match SLOT { 0 => unsafe{ &*self.val_or_child0.child }, 1 => unsafe{ &*self.val_or_child1.child }, @@ -419,7 +419,7 @@ impl LineListNode { } } #[inline] - unsafe fn val_in_slot(&self) -> &V { + pub(crate) unsafe fn val_in_slot(&self) -> &V { match SLOT { 0 => unsafe{ &**self.val_or_child0.val }, 1 => unsafe{ &**self.val_or_child1.val }, @@ -1986,6 +1986,25 @@ impl TrieNode for LineListNode } result } +/* #[inline] + fn node_goat_val_count(&self) -> usize { + let mut result = 0; + if self.is_used_value_0() { + result += 1; + } + if self.is_used_value_1() { + result += 1; + } + if self.is_used_child_0() { + let child_node = unsafe{ self.child_in_slot::<0>() }; + result += child_node.as_tagged().node_goat_val_count(); + } + if self.is_used_child_1() { + let child_node = unsafe{ self.child_in_slot::<1>() }; + result += child_node.as_tagged().node_goat_val_count(); + } + result + }*/ #[inline] fn node_goat_val_count(&self) -> usize { //Here are 3 alternative implementations. They're basically the same in perf, with a slight edge to the diff --git a/src/trie_map.rs b/src/trie_map.rs index 3c5b0f37..5d211868 100644 --- a/src/trie_map.rs +++ b/src/trie_map.rs @@ -511,9 +511,25 @@ impl PathMap { let root_val = unsafe{ &*self.root_val.get() }.is_some() as usize; match self.root() { Some(root) => { - traverse_physical(root, - |node, ctx: usize| { ctx + node.node_goat_val_count() }, - |ctx, child_ctx| { ctx + child_ctx }, + // root.as_tagged().node_goat_val_count() + root_val + // traverse_physical(root, + // |node, ctx: usize| { ctx + node.node_goat_val_count() }, + // |ctx, child_ctx| { ctx + child_ctx }, + // ) + root_val + + // traverse_split_cata( + // root, + // |v, _| { 1usize }, + // |_, w, _| { 1 + w }, + // |bm, ws: &mut [usize], _| { ws.iter().sum() } + // ) + root_val + // Adam: this doesn't need to be called "traverse_osplit_cata" or be exposed under this interface; it can just live in morphisms + traverse_osplit_cata( + root, + |v, _| { 1usize }, // on leaf values + |_, w, _| { 1 + w }, // on values amongst a path + |bm, w: usize, _, total| { *total += w }, // on merging children into a node + |bm, total: usize, _| { total } // finalizing a node ) + root_val }, None => root_val diff --git a/src/trie_node.rs b/src/trie_node.rs index 111bd5e4..0687c4fa 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -7,7 +7,7 @@ use dyn_clone::*; use local_or_heap::LocalOrHeap; use arrayvec::ArrayVec; -use crate::utils::ByteMask; +use crate::utils::{BitMask, ByteMask}; use crate::alloc::Allocator; use crate::dense_byte_node::*; use crate::ring::*; @@ -2422,16 +2422,147 @@ fn traverse_physical_children_internal(node: TaggedNode { let mut ctx = Ctx::default(); - let (mut tok, mut child) = node.node_child_iter_start(); - while let Some(child_node) = child { - let child_ctx = traverse_physical_internal(child_node, node_f, fold_f, cache); - ctx = fold_f(ctx, child_ctx); - (tok, child) = node.node_child_iter_next(tok); + match node { + TaggedNodeRef::DenseByteNode(n) => { + for cf in n.values.iter() { + if let Some(rec) = cf.rec() { + let 
child_ctx = traverse_physical_internal(rec, node_f, fold_f, cache); + ctx = fold_f(ctx, child_ctx); + } + } + } + TaggedNodeRef::LineListNode(n) => { + if n.is_used_child_0() { + let child_node = unsafe{ n.child_in_slot::<0>() }; + let child_ctx = traverse_physical_internal(child_node, node_f, fold_f, cache); + ctx = fold_f(ctx, child_ctx); + } + if n.is_used_child_1() { + let child_node = unsafe{ n.child_in_slot::<1>() }; + let child_ctx = traverse_physical_internal(child_node, node_f, fold_f, cache); + ctx = fold_f(ctx, child_ctx); + } + } + TaggedNodeRef::CellByteNode(_) => { todo!() } + TaggedNodeRef::TinyRefNode(_) => { todo!() } + TaggedNodeRef::EmptyNode => { todo!() } } node_f(node, ctx) } +// This experiment is still OK, but the `&mut [W]` is awkward to instantiate if you don't actually have +/*pub fn traverse_split_cata<'a, A : Allocator, V : TrieValue, W, MapF, CollapseF, AlgF>(node: &TrieNodeODRc, mut map_f: MapF, mut collapse_f: CollapseF, alg_f: AlgF) -> W +where + MapF: Copy + FnMut(&V, &[u8]) -> W + 'a, + CollapseF: Copy + FnMut(&V, W, &[u8]) -> W + 'a, + AlgF: Copy + Fn(&ByteMask, &mut [W], &[u8]) -> W + 'a, +{ + match node.as_tagged() { + TaggedNodeRef::DenseByteNode(n) => { + let mut ws = [const { std::mem::MaybeUninit::::uninit() }; 256]; + // let mut ws: Vec> = Vec::with_capacity(n.mask.count_bits()); + // unsafe { ws.set_len(n.mask.count_bits()) }; + let mut c = 0; + for cf in n.values.iter() { + if let Some(rec) = cf.rec() { + let w = traverse_split_cata(rec, map_f, collapse_f, alg_f); + if let Some(v) = cf.val() { + ws[c].write(collapse_f(v, w, &[])); + } else { + ws[c].write(w); + } + } else if let Some(v) = cf.val() { + ws[c].write(map_f(v, &[])); + } + c += 1; + } + alg_f(&n.mask, unsafe { std::mem::transmute(&mut ws[..c]) }, &[]) + } + TaggedNodeRef::LineListNode(n) => { + // let mut ws = vec![]; + // if n.is_used_value_0() { + // ws.append(map_f(unsafe { n.val_in_slot::<0>() }, &[])); + // } + // if n.is_used_value_1() { + // ws.append(map_f(unsafe { n.val_in_slot::<1>() }, &[])); + // } + // if n.is_used_child_0() { + // let child_node = unsafe{ n.child_in_slot::<0>() }; + // let child_ctx = traverse_split_cata(child_node, map_f, collapse_f, alg_f); + // + // } + // if n.is_used_child_1() { + // let child_node = unsafe{ n.child_in_slot::<1>() }; + // let child_ctx = traverse_physical_internal(child_node, node_f, fold_f, cache); + // ctx = fold_f(ctx, child_ctx); + // } + alg_f(&ByteMask::new(), &mut [], &[]) + } + TaggedNodeRef::CellByteNode(_) => { todo!() } + TaggedNodeRef::TinyRefNode(_) => { todo!() } + TaggedNodeRef::EmptyNode => { todo!() } + } +} +*/ + +// Adam: This seems to be a winner, though it needs some work, the split alg gives us the opportunity to nicely compose the different calls for the different node types without introducing overhead +pub fn traverse_osplit_cata<'a, A : Allocator, V : TrieValue, Alg : Default, W, MapF, CollapseF, InAlgF, OutAlgF>(node: &TrieNodeODRc, mut map_f: MapF, mut collapse_f: CollapseF, in_alg_f: InAlgF, out_alg_f: OutAlgF) -> W +where + MapF: Copy + FnMut(&V, &[u8]) -> W + 'a, + CollapseF: Copy + FnMut(&V, W, &[u8]) -> W + 'a, + InAlgF: Copy + Fn(&ByteMask, W, &[u8], &mut Alg), + OutAlgF: Copy + Fn(&ByteMask, Alg, &[u8]) -> W + 'a, +{ + match node.as_tagged() { + TaggedNodeRef::DenseByteNode(n) => { + let mut ws = Some(Alg::default()); + for cf in n.values.iter() { + if let Some(rec) = cf.rec() { + let w = traverse_osplit_cata(rec, map_f, collapse_f, in_alg_f, out_alg_f); + if let Some(v) = cf.val() { + 
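+                        //Slot holds both a value and a child: fold the child's result through
+                        //the value with `collapse_f` before accumulating it via `in_alg_f`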
in_alg_f(&n.mask, collapse_f(v, w, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } else { + in_alg_f(&n.mask, w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } + } else if let Some(v) = cf.val() { + in_alg_f(&n.mask, map_f(v, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } + } + out_alg_f(&n.mask, unsafe { std::mem::take(&mut ws).unwrap_unchecked() }, &[]) + } + TaggedNodeRef::LineListNode(n) => { + // Adam: I skimped out on the collapse logic here, I assume there are some built-in LineListNode functions I can use for prefixes, or another way to organize the branching based on the mask directly + let mut ws = Some(Alg::default()); + + if n.is_used_value_0() { + in_alg_f(&ByteMask::new(), map_f(unsafe { n.val_in_slot::<0>() }, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } + if n.is_used_value_1() { + in_alg_f(&ByteMask::new(), map_f(unsafe { n.val_in_slot::<1>() }, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } + if n.is_used_child_0() { + let child_node = unsafe{ n.child_in_slot::<0>() }; + let w = traverse_osplit_cata(child_node, map_f, collapse_f, in_alg_f, out_alg_f); + in_alg_f(&ByteMask::new(), w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); + + } + if n.is_used_child_1() { + let child_node = unsafe{ n.child_in_slot::<1>() }; + let w = traverse_osplit_cata(child_node, map_f, collapse_f, in_alg_f, out_alg_f); + in_alg_f(&ByteMask::new(), w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } + + out_alg_f(&ByteMask::new(), unsafe { std::mem::take(&mut ws).unwrap_unchecked() }, &[]) + } + TaggedNodeRef::CellByteNode(_) => { todo!() } + TaggedNodeRef::TinyRefNode(_) => { todo!() } + TaggedNodeRef::EmptyNode => { + out_alg_f(&ByteMask::new(), Alg::default(), &[]) + } + } +} + /// Internal function to walk a mut TrieNodeODRc ref along a path /// /// If `stop_early` is `true`, this function will return the parent node of the path and will never return @@ -2483,6 +2614,9 @@ pub(crate) fn make_cell_node(node: &mut Tr // module come from the visibility of the trait it is derived on. In this case, `TrieNode` //Credit to QuineDot for his ideas on this pattern here: https://users.rust-lang.org/t/inferred-lifetime-for-dyn-trait/112116/7 pub(crate) use opaque_dyn_rc_trie_node::TrieNodeODRc; +use crate::morphisms::SplitCata; +use crate::TrieValue; + #[cfg(not(feature = "slim_ptrs"))] mod opaque_dyn_rc_trie_node { use std::sync::Arc; From 949d2d6b27b2ee6ab90b29d229e52a81e496004b Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Tue, 20 Jan 2026 17:19:11 -0700 Subject: [PATCH 02/15] Thrashing towards fully functional recursive caching cata. 
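Moves the traversal into per-node-type methods: ByteNode::node_recursive_cata and
LineListNode::node_recursive_cata drive their own loops, and the free recursive_cata
function dispatches on the tagged node type once per node. Key-byte reconstruction is
gated behind a COMPUTE_PATH const generic so callers that ignore paths don't pay for
the mask walk.

For reference, this is how PathMap::goat_val_count drives the renamed entry point after
this change (reproduced from src/trie_map.rs; nothing new beyond the rename):

    recursive_cata::<_, _, _, _, _, _, _, _, false>(
        root,
        |v, _| { 1usize },                         // on leaf values
        |_, w, _| { 1 + w },                       // on values amongst a path
        |bm, w: usize, _, total| { *total += w },  // on merging children into a node
        |bm, total: usize, _| { total }            // finalizing a node
    ) + root_val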
--- benches/binary_keys.rs | 4 +- src/dense_byte_node.rs | 49 +++++++++++++++++++++ src/line_list_node.rs | 68 +++++++++++++++++++++++++++++ src/trie_map.rs | 2 +- src/trie_node.rs | 99 +++++++++++++++++++++--------------------- 5 files changed, 170 insertions(+), 52 deletions(-) diff --git a/benches/binary_keys.rs b/benches/binary_keys.rs index 8f3c1e0a..82ffe247 100644 --- a/benches/binary_keys.rs +++ b/benches/binary_keys.rs @@ -61,7 +61,7 @@ fn binary_get(bencher: Bencher, n: u64) { }); } -#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000])] +#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000, 100000])] fn binary_val_count_bench(bencher: Bencher, n: u64) { let keys = make_keys(n as usize, 1); @@ -77,7 +77,7 @@ fn binary_val_count_bench(bencher: Bencher, n: u64) { assert_eq!(sink, n as usize); } -#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000])] +#[divan::bench(args = [125, 250, 500, 1000, 2000, 4000, 100000])] fn binary_goat_val_count_bench(bencher: Bencher, n: u64) { let keys = make_keys(n as usize, 1); diff --git a/src/dense_byte_node.rs b/src/dense_byte_node.rs index defaccff..78b2e176 100644 --- a/src/dense_byte_node.rs +++ b/src/dense_byte_node.rs @@ -310,6 +310,10 @@ impl> ByteNode /// Iterates the entries in `self`, calling `func` for each entry /// The arguments to `func` are: `func(self, key_byte, n)`, where `n` is the number of times /// `func` has been called prior. This corresponds to index of the `CoFree` in the `values` vec + /// + /// PERF GOAT: this method is useful if you know you need the corresponding path byte. If, however, the + /// byte might be unneeded, it's better to let self.values govern the loop, because using the mask + /// in the loop means the bitwise ops to manipulate the mask are always required. #[inline] fn for_each_item(&self, mut func: F) { let mut n = 0; @@ -326,6 +330,51 @@ impl> ByteNode } } } + + #[inline(always)] + pub fn node_recursive_cata(&self, map_f: MapF, collapse_f: CollapseF, in_alg_f: InAlgF, out_alg_f: OutAlgF) -> W + where + Acc: Default, + MapF: Copy + Fn(&V, &[u8]) -> W, + CollapseF: Copy + Fn(&V, W, &[u8]) -> W, + InAlgF: Copy + Fn(&ByteMask, W, &[u8], &mut Acc), + OutAlgF: Copy + Fn(&ByteMask, Acc, &[u8]) -> W, + { + let mut mask_idx = 0; + let mut lm = unsafe{ *self.mask.0.get_unchecked(0) }; + let mut ws = Some(Acc::default()); + for cf in self.values.iter() { + //Compute the key byte. Hopefully this will all be stripped away by the compiler if the path isn't used + //UPDATE: alas, my hopes were dashed. No amount of reorganizing this code, eliminating all traps, + // unrolling the loop, etc., could convince LLVM to elide it. So we have to hit it with the const hammer. 
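+            //(The hammer: `COMPUTE_PATH` is a const generic, so the `if` below is resolved at
+            // monomorphization time. With `COMPUTE_PATH == false` the branch is just `&[]`,
+            // and `mask_idx`/`lm` above become dead code the optimizer can finally drop.)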
+ let key_byte; + let path = if COMPUTE_PATH { + while lm == 0 { + mask_idx += 1; + lm = unsafe{ *self.mask.0.get_unchecked(mask_idx) }; + } + let byte_index = lm.trailing_zeros(); + lm ^= 1u64 << byte_index; + key_byte = 64*(mask_idx as u8) + (byte_index as u8); + core::slice::from_ref(&key_byte) + } else { + &[] + }; + + //Do the recursive calling + if let Some(rec) = cf.rec() { + let w = recursive_cata::<_, _, _, _, _, _, _, _, COMPUTE_PATH>(rec, map_f, collapse_f, in_alg_f, out_alg_f); + if let Some(v) = cf.val() { + in_alg_f(&self.mask, collapse_f(v, w, path), path, unsafe { ws.as_mut().unwrap_unchecked() }); + } else { + in_alg_f(&self.mask, w, path, unsafe { ws.as_mut().unwrap_unchecked() }); + } + } else if let Some(v) = cf.val() { + in_alg_f(&self.mask, map_f(v, path), path, unsafe { ws.as_mut().unwrap_unchecked() }); + } + } + out_alg_f(&self.mask, unsafe { std::mem::take(&mut ws).unwrap_unchecked() }, &[]) + } } impl> ByteNode where Self: TrieNodeDowncast { diff --git a/src/line_list_node.rs b/src/line_list_node.rs index 66120318..3a0fd466 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -2741,6 +2741,74 @@ impl LineListNode { } } } + + #[inline(always)] + pub fn node_recursive_cata(&self, map_f: MapF, collapse_f: CollapseF, in_alg_f: InAlgF, out_alg_f: OutAlgF) -> W + where + Acc: Default, + MapF: Copy + Fn(&V, &[u8]) -> W, + CollapseF: Copy + Fn(&V, W, &[u8]) -> W, + InAlgF: Copy + Fn(&ByteMask, W, &[u8], &mut Acc), + OutAlgF: Copy + Fn(&ByteMask, Acc, &[u8]) -> W, + { + let mut ws = Some(Acc::default()); + +//GOAT, should we remove the path from out_alg?? I can't see when it's ever used... +// A: Either the path doesn't belong on the out_alg or on the in_alg. + +//GOAT, check out whether in_alg should get the path on the ByteNode + +//Cases: +// * There is only one val. Run map_f only +// * There is only one child. Recurse, then Run the in_alg -> out_alg combo +// * Both slots are filled +// GOAT, there is a problem where we have to care about whether it's a Val or child! +// - Is the first byte the same? +// N: +// - Call map_f on slot0 with the whole path +// - Call map_f on slot1 with the whole path +// - Call in_alg (branch_f) on each of the Ws +// - Call out_alg (fold_f) on the context with a composed mask from the first bytes +// Y: +// - Call map_f on slot0 with everything after first byte +// - Call map_f on slot1 with everything after first byte +// - Call in_alg (branch_f) on each of the Ws +// - Call out_alg (fold_f) on the context with a composed mask from the second bytes + + + +//New plan for API. +// 3 closures. +// - closure to deal with straight paths and values. Fn(Option<&V>, Option, &[u8]) -> W +// - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) +// - closure to fold accumulator back into a W for the logical node. Fn(&ByteMask, Acc) -> W +// +//Then, the non-jumping API would take: +// 2 closures. +// - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) +// - closure to deal with each path byte / logical node. 
Fn(&ByteMask, Option<&V>, Acc) -> W + + + if self.is_used_value_0() { + in_alg_f(&ByteMask::new(), map_f(unsafe { self.val_in_slot::<0>() }, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } + if self.is_used_value_1() { + in_alg_f(&ByteMask::new(), map_f(unsafe { self.val_in_slot::<1>() }, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } + if self.is_used_child_0() { + let child_node = unsafe{ self.child_in_slot::<0>() }; + let w = recursive_cata::<_, _, _, _, _, _, _, _, COMPUTE_PATH>(child_node, map_f, collapse_f, in_alg_f, out_alg_f); + in_alg_f(&ByteMask::new(), w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); + + } + if self.is_used_child_1() { + let child_node = unsafe{ self.child_in_slot::<1>() }; + let w = recursive_cata::<_, _, _, _, _, _, _, _, COMPUTE_PATH>(child_node, map_f, collapse_f, in_alg_f, out_alg_f); + in_alg_f(&ByteMask::new(), w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); + } + + out_alg_f(&ByteMask::new(), unsafe { std::mem::take(&mut ws).unwrap_unchecked() }, &[]) + } } impl TrieNodeDowncast for LineListNode { diff --git a/src/trie_map.rs b/src/trie_map.rs index 5d211868..9a3d93ab 100644 --- a/src/trie_map.rs +++ b/src/trie_map.rs @@ -524,7 +524,7 @@ impl PathMap { // |bm, ws: &mut [usize], _| { ws.iter().sum() } // ) + root_val // Adam: this doesn't need to be called "traverse_osplit_cata" or be exposed under this interface; it can just live in morphisms - traverse_osplit_cata( + recursive_cata::<_, _, _, _, _, _, _, _, false>( root, |v, _| { 1usize }, // on leaf values |_, w, _| { 1 + w }, // on values amongst a path diff --git a/src/trie_node.rs b/src/trie_node.rs index 0687c4fa..1068a998 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -2507,59 +2507,60 @@ where */ // Adam: This seems to be a winner, though it needs some work, the split alg gives us the opportunity to nicely compose the different calls for the different node types without introducing overhead -pub fn traverse_osplit_cata<'a, A : Allocator, V : TrieValue, Alg : Default, W, MapF, CollapseF, InAlgF, OutAlgF>(node: &TrieNodeODRc, mut map_f: MapF, mut collapse_f: CollapseF, in_alg_f: InAlgF, out_alg_f: OutAlgF) -> W +/// Traverse a trie with a split catamorphism and caller-provided aggregation. +/// +/// Closure argument meanings: +/// - `map_f(&V, &[u8]) -> W`: map a leaf value to a result. Args are the value and +/// a path slice (currently always `&[]` in this implementation). +/// - `collapse_f(&V, W, &[u8]) -> W`: combine a value stored along a path with a +/// child result. Args are the value, the child result, and a path slice +/// (currently always `&[]`). +/// - `in_alg_f(&ByteMask, W, &[u8], &mut Alg)`: fold a single slot's result into +/// the node accumulator. Args are the node's mask, the slot result, a path +/// slice (currently always `&[]`), and the mutable accumulator. +/// - `out_alg_f(&ByteMask, Alg, &[u8]) -> W`: finalize a node accumulator into the +/// node's result. Args are the node's mask, the accumulator, and a path slice +/// (currently always `&[]`). + +//GOAT issues +// - The reason it's faster than the other abstraction because it branches on node type once, rather than twice per node. 
+// +// There is more stuff on the stack, meaning we're more likely to blow the stack +// +// * Needed to add caching +// * The logic in pair node was wrong, because there is no guarantee both sides are at the same level; fixing that added another branch +// +// Observations: +// If we want to count path-ends, MapF wouldn't work, but could pass Option<&V>, which would be fine +// It might be possible to unify MapF and CollapseF, but +// +//GOAT, TODO: Make a test to hit the stack overflow failure case +// +//GOAT: +// * Look at the callbacks in line node +// * Send partial paths in pair node too +// * Come up with new names for in_alg and out_alg... +// * Put caching back +// * See if I can get closer to the other cata API without sacrificing performance +// +pub fn recursive_cata(node: &TrieNodeODRc, map_f: MapF, collapse_f: CollapseF, in_alg_f: InAlgF, out_alg_f: OutAlgF) -> W where - MapF: Copy + FnMut(&V, &[u8]) -> W + 'a, - CollapseF: Copy + FnMut(&V, W, &[u8]) -> W + 'a, - InAlgF: Copy + Fn(&ByteMask, W, &[u8], &mut Alg), - OutAlgF: Copy + Fn(&ByteMask, Alg, &[u8]) -> W + 'a, + V: Clone + Send + Sync, + A: Allocator, + Acc: Default, + MapF: Copy + Fn(&V, &[u8]) -> W, + CollapseF: Copy + Fn(&V, W, &[u8]) -> W, + // InAlgF: called for each + InAlgF: Copy + Fn(&ByteMask, W, &[u8], &mut Acc), + // OutAlgF: collapses all children at the same level + OutAlgF: Copy + Fn(&ByteMask, Acc, &[u8]) -> W, { match node.as_tagged() { - TaggedNodeRef::DenseByteNode(n) => { - let mut ws = Some(Alg::default()); - for cf in n.values.iter() { - if let Some(rec) = cf.rec() { - let w = traverse_osplit_cata(rec, map_f, collapse_f, in_alg_f, out_alg_f); - if let Some(v) = cf.val() { - in_alg_f(&n.mask, collapse_f(v, w, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); - } else { - in_alg_f(&n.mask, w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); - } - } else if let Some(v) = cf.val() { - in_alg_f(&n.mask, map_f(v, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); - } - } - out_alg_f(&n.mask, unsafe { std::mem::take(&mut ws).unwrap_unchecked() }, &[]) - } - TaggedNodeRef::LineListNode(n) => { - // Adam: I skimped out on the collapse logic here, I assume there are some built-in LineListNode functions I can use for prefixes, or another way to organize the branching based on the mask directly - let mut ws = Some(Alg::default()); - - if n.is_used_value_0() { - in_alg_f(&ByteMask::new(), map_f(unsafe { n.val_in_slot::<0>() }, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); - } - if n.is_used_value_1() { - in_alg_f(&ByteMask::new(), map_f(unsafe { n.val_in_slot::<1>() }, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); - } - if n.is_used_child_0() { - let child_node = unsafe{ n.child_in_slot::<0>() }; - let w = traverse_osplit_cata(child_node, map_f, collapse_f, in_alg_f, out_alg_f); - in_alg_f(&ByteMask::new(), w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); - - } - if n.is_used_child_1() { - let child_node = unsafe{ n.child_in_slot::<1>() }; - let w = traverse_osplit_cata(child_node, map_f, collapse_f, in_alg_f, out_alg_f); - in_alg_f(&ByteMask::new(), w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); - } - - out_alg_f(&ByteMask::new(), unsafe { std::mem::take(&mut ws).unwrap_unchecked() }, &[]) - } + TaggedNodeRef::DenseByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, _, COMPUTE_PATH>(map_f, collapse_f, in_alg_f, out_alg_f) } + TaggedNodeRef::LineListNode(node) => { node.node_recursive_cata::<_, _, _, _, _, _, COMPUTE_PATH>(map_f, collapse_f, in_alg_f, out_alg_f) } 
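+        //Only the two main node kinds are wired up in this POC; the remaining variants still fall through to todo!()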
TaggedNodeRef::CellByteNode(_) => { todo!() } TaggedNodeRef::TinyRefNode(_) => { todo!() } - TaggedNodeRef::EmptyNode => { - out_alg_f(&ByteMask::new(), Alg::default(), &[]) - } + TaggedNodeRef::EmptyNode => { out_alg_f(&ByteMask::new(), Acc::default(), &[]) } } } @@ -3375,4 +3376,4 @@ mod tests { node_ref.make_unique(); drop(cloned); } -} \ No newline at end of file +} From 54b8daa160134b6ef698ddd6ac9d6fe903565ec0 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Tue, 20 Jan 2026 18:49:47 -0700 Subject: [PATCH 03/15] Implementing new API structure for recursive cata. No loss in perf so far. --- src/dense_byte_node.rs | 27 ++++++++++---------- src/line_list_node.rs | 39 +++++++++++------------------ src/trie_map.rs | 9 +++---- src/trie_node.rs | 56 ++++++++++++++++++++++++------------------ 4 files changed, 64 insertions(+), 67 deletions(-) diff --git a/src/dense_byte_node.rs b/src/dense_byte_node.rs index 78b2e176..362f2d21 100644 --- a/src/dense_byte_node.rs +++ b/src/dense_byte_node.rs @@ -332,13 +332,12 @@ impl> ByteNode } #[inline(always)] - pub fn node_recursive_cata(&self, map_f: MapF, collapse_f: CollapseF, in_alg_f: InAlgF, out_alg_f: OutAlgF) -> W + pub fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W where Acc: Default, - MapF: Copy + Fn(&V, &[u8]) -> W, - CollapseF: Copy + Fn(&V, W, &[u8]) -> W, - InAlgF: Copy + Fn(&ByteMask, W, &[u8], &mut Acc), - OutAlgF: Copy + Fn(&ByteMask, Acc, &[u8]) -> W, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, { let mut mask_idx = 0; let mut lm = unsafe{ *self.mask.0.get_unchecked(0) }; @@ -362,18 +361,18 @@ impl> ByteNode }; //Do the recursive calling + //PERF NOTE: The reason we have two code paths around the call to `branch_f` instead of just doing + // `let w = cf.rec().map(|rec| recursive_cata(...))` is that the compiler won't optimize the implemntation + // of `branch_f` around whether `w` is none or not, if we go with one call to `branch_f`. That means we + // often pay for two dependent branches instead of one, and the difference was 25% to the val_count benchmark. 
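+        // (Concretely: once the counting algebra from goat_val_count is inlined here, its
+        // `w.unwrap_or(0)` folds to the child total on the first path and to `0` on the
+        // second, instead of testing an Option discriminant at runtime.)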
if let Some(rec) = cf.rec() { - let w = recursive_cata::<_, _, _, _, _, _, _, _, COMPUTE_PATH>(rec, map_f, collapse_f, in_alg_f, out_alg_f); - if let Some(v) = cf.val() { - in_alg_f(&self.mask, collapse_f(v, w, path), path, unsafe { ws.as_mut().unwrap_unchecked() }); - } else { - in_alg_f(&self.mask, w, path, unsafe { ws.as_mut().unwrap_unchecked() }); - } - } else if let Some(v) = cf.val() { - in_alg_f(&self.mask, map_f(v, path), path, unsafe { ws.as_mut().unwrap_unchecked() }); + let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f); + branch_f(&self.mask, collapse_f(cf.val(), Some(w), path), unsafe { ws.as_mut().unwrap_unchecked() }); + } else { + branch_f(&self.mask, collapse_f(cf.val(), None, path), unsafe { ws.as_mut().unwrap_unchecked() }); } } - out_alg_f(&self.mask, unsafe { std::mem::take(&mut ws).unwrap_unchecked() }, &[]) + finalize_f(&self.mask, unsafe { std::mem::take(&mut ws).unwrap_unchecked() }) } } diff --git a/src/line_list_node.rs b/src/line_list_node.rs index 3a0fd466..155218e9 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -2743,13 +2743,12 @@ impl LineListNode { } #[inline(always)] - pub fn node_recursive_cata(&self, map_f: MapF, collapse_f: CollapseF, in_alg_f: InAlgF, out_alg_f: OutAlgF) -> W + pub fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W where Acc: Default, - MapF: Copy + Fn(&V, &[u8]) -> W, - CollapseF: Copy + Fn(&V, W, &[u8]) -> W, - InAlgF: Copy + Fn(&ByteMask, W, &[u8], &mut Acc), - OutAlgF: Copy + Fn(&ByteMask, Acc, &[u8]) -> W, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, { let mut ws = Some(Acc::default()); @@ -2777,37 +2776,29 @@ impl LineListNode { -//New plan for API. -// 3 closures. -// - closure to deal with straight paths and values. Fn(Option<&V>, Option, &[u8]) -> W -// - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) -// - closure to fold accumulator back into a W for the logical node. Fn(&ByteMask, Acc) -> W -// -//Then, the non-jumping API would take: -// 2 closures. -// - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) -// - closure to deal with each path byte / logical node. 
Fn(&ByteMask, Option<&V>, Acc) -> W - - if self.is_used_value_0() { - in_alg_f(&ByteMask::new(), map_f(unsafe { self.val_in_slot::<0>() }, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); + let downstream = collapse_f(Some(unsafe { self.val_in_slot::<0>() }), None, &[]); + branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); } if self.is_used_value_1() { - in_alg_f(&ByteMask::new(), map_f(unsafe { self.val_in_slot::<1>() }, &[]), &[], unsafe { ws.as_mut().unwrap_unchecked() }); + let downstream = collapse_f(Some(unsafe { self.val_in_slot::<1>() }), None, &[]); + branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); } if self.is_used_child_0() { let child_node = unsafe{ self.child_in_slot::<0>() }; - let w = recursive_cata::<_, _, _, _, _, _, _, _, COMPUTE_PATH>(child_node, map_f, collapse_f, in_alg_f, out_alg_f); - in_alg_f(&ByteMask::new(), w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); + let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let downstream = collapse_f(None, Some(w), &[]); + branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); } if self.is_used_child_1() { let child_node = unsafe{ self.child_in_slot::<1>() }; - let w = recursive_cata::<_, _, _, _, _, _, _, _, COMPUTE_PATH>(child_node, map_f, collapse_f, in_alg_f, out_alg_f); - in_alg_f(&ByteMask::new(), w, &[], unsafe { ws.as_mut().unwrap_unchecked() }); + let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let downstream = collapse_f(None, Some(w), &[]); + branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); } - out_alg_f(&ByteMask::new(), unsafe { std::mem::take(&mut ws).unwrap_unchecked() }, &[]) + finalize_f(&ByteMask::new(), unsafe { std::mem::take(&mut ws).unwrap_unchecked() }) } } diff --git a/src/trie_map.rs b/src/trie_map.rs index 9a3d93ab..3b09d143 100644 --- a/src/trie_map.rs +++ b/src/trie_map.rs @@ -524,12 +524,11 @@ impl PathMap { // |bm, ws: &mut [usize], _| { ws.iter().sum() } // ) + root_val // Adam: this doesn't need to be called "traverse_osplit_cata" or be exposed under this interface; it can just live in morphisms - recursive_cata::<_, _, _, _, _, _, _, _, false>( + recursive_cata::<_, _, _, _, _, _, _, false>( root, - |v, _| { 1usize }, // on leaf values - |_, w, _| { 1 + w }, // on values amongst a path - |bm, w: usize, _, total| { *total += w }, // on merging children into a node - |bm, total: usize, _| { total } // finalizing a node + |v, w, _| { (v.is_some() as usize) + w.unwrap_or(0) }, // on values amongst a path + |bm, w: usize, total| { *total += w }, // on merging children into a node + |bm, total: usize| { total } // finalizing a node ) + root_val }, None => root_val diff --git a/src/trie_node.rs b/src/trie_node.rs index 1068a998..f083bef1 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -2506,21 +2506,32 @@ where } */ -// Adam: This seems to be a winner, though it needs some work, the split alg gives us the opportunity to nicely compose the different calls for the different node types without introducing overhead -/// Traverse a trie with a split catamorphism and caller-provided aggregation. +/// GOAT recursive caching cata. If this dev branch is successful this should replace the caching cata flavors in the public API +/// This is the JUMPING cata /// -/// Closure argument meanings: -/// - `map_f(&V, &[u8]) -> W`: map a leaf value to a result. 
Args are the value and -/// a path slice (currently always `&[]` in this implementation). -/// - `collapse_f(&V, W, &[u8]) -> W`: combine a value stored along a path with a -/// child result. Args are the value, the child result, and a path slice -/// (currently always `&[]`). -/// - `in_alg_f(&ByteMask, W, &[u8], &mut Alg)`: fold a single slot's result into -/// the node accumulator. Args are the node's mask, the slot result, a path -/// slice (currently always `&[]`), and the mutable accumulator. -/// - `out_alg_f(&ByteMask, Alg, &[u8]) -> W`: finalize a node accumulator into the -/// node's result. Args are the node's mask, the accumulator, and a path slice -/// (currently always `&[]`). +/// Closures: +/// +/// `CollapseF`: Folds a possible value and a possible downstream continuation, prefixed by a linear sub-path into a single `W` +/// `fn(val: Option<&V>, downstream: Option, prefix: &[u8]) -> W` +/// +/// `BranchF`: Accumulates the `W` representing a downstream branch into an `Acc` accumulator type +/// `fn(branch_mask: &ByteMask, downstream: W, accumulator: &mut Acc)` +/// +/// `FinalizeF`: Converts an `Acc` accumulator into a `W` representing the logical node +/// `fn(branch_mask: &ByteMask, accumulator: Acc) -> W` + +// +//New plan for API. +// 3 closures. +// - closure to deal with straight paths and values. Fn(Option<&V>, Option, &[u8]) -> W +// - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) +// - closure to fold accumulator back into a W for the logical node. Fn(&ByteMask, Acc) -> W +// +//Then, the non-jumping API would take: +// 2 closures. +// - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) +// - closure to deal with each path byte / logical node. Fn(&ByteMask, Option<&V>, Acc) -> W + //GOAT issues // - The reason it's faster than the other abstraction because it branches on node type once, rather than twice per node. 
@@ -2543,24 +2554,21 @@ where // * Put caching back // * See if I can get closer to the other cata API without sacrificing performance // -pub fn recursive_cata(node: &TrieNodeODRc, map_f: MapF, collapse_f: CollapseF, in_alg_f: InAlgF, out_alg_f: OutAlgF) -> W +pub fn recursive_cata(node: &TrieNodeODRc, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W where V: Clone + Send + Sync, A: Allocator, Acc: Default, - MapF: Copy + Fn(&V, &[u8]) -> W, - CollapseF: Copy + Fn(&V, W, &[u8]) -> W, - // InAlgF: called for each - InAlgF: Copy + Fn(&ByteMask, W, &[u8], &mut Acc), - // OutAlgF: collapses all children at the same level - OutAlgF: Copy + Fn(&ByteMask, Acc, &[u8]) -> W, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, { match node.as_tagged() { - TaggedNodeRef::DenseByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, _, COMPUTE_PATH>(map_f, collapse_f, in_alg_f, out_alg_f) } - TaggedNodeRef::LineListNode(node) => { node.node_recursive_cata::<_, _, _, _, _, _, COMPUTE_PATH>(map_f, collapse_f, in_alg_f, out_alg_f) } + TaggedNodeRef::DenseByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f) } + TaggedNodeRef::LineListNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f) } TaggedNodeRef::CellByteNode(_) => { todo!() } TaggedNodeRef::TinyRefNode(_) => { todo!() } - TaggedNodeRef::EmptyNode => { out_alg_f(&ByteMask::new(), Acc::default(), &[]) } + TaggedNodeRef::EmptyNode => { finalize_f(&ByteMask::EMPTY, Acc::default()) } } } From cd6ea5cdd728af6d8f990b243ce92d5bd4d0fd32 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Tue, 20 Jan 2026 19:25:46 -0700 Subject: [PATCH 04/15] Slight tweak to node_recursive_cata for byte node, to give the optimizer more to work with. ~5% improvement to saturated val_count benchmark --- src/dense_byte_node.rs | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/dense_byte_node.rs b/src/dense_byte_node.rs index 362f2d21..9fb1d821 100644 --- a/src/dense_byte_node.rs +++ b/src/dense_byte_node.rs @@ -361,15 +361,27 @@ impl> ByteNode }; //Do the recursive calling - //PERF NOTE: The reason we have two code paths around the call to `branch_f` instead of just doing - // `let w = cf.rec().map(|rec| recursive_cata(...))` is that the compiler won't optimize the implemntation - // of `branch_f` around whether `w` is none or not, if we go with one call to `branch_f`. That means we - // often pay for two dependent branches instead of one, and the difference was 25% to the val_count benchmark. - if let Some(rec) = cf.rec() { - let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f); - branch_f(&self.mask, collapse_f(cf.val(), Some(w), path), unsafe { ws.as_mut().unwrap_unchecked() }); - } else { - branch_f(&self.mask, collapse_f(cf.val(), None, path), unsafe { ws.as_mut().unwrap_unchecked() }); + //PERF NOTE: The reason we have four code paths around the call to `branch_f` instead of just doing + // `let w = cf.rec().map(|rec| recursive_cata(...))` and then calling branch_f with the appropriate + // pair of `Option<&V>` and `Option` is that the compiler won't optimize the implemntation of + // `branch_f` around whether the options are none or not. 
That means we often pay for two dependent + // branches instead of one, and the difference was 25% to the val_count benchmark. Breaking out the calls + // like this gives the optimizer a site to specialize for each permutation + match (cf.rec(), cf.val()) { + (Some(rec), Some(val)) => { + let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f); + branch_f(&self.mask, collapse_f(Some(val), Some(w), path), unsafe { ws.as_mut().unwrap_unchecked() }); + }, + (Some(rec), None) => { + let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f); + branch_f(&self.mask, collapse_f(None, Some(w), path), unsafe { ws.as_mut().unwrap_unchecked() }); + }, + (None, Some(val)) => { + branch_f(&self.mask, collapse_f(Some(val), None, path), unsafe { ws.as_mut().unwrap_unchecked() }); + }, + (None, None) => { + branch_f(&self.mask, collapse_f(None, None, path), unsafe { ws.as_mut().unwrap_unchecked() }); + }, } } finalize_f(&self.mask, unsafe { std::mem::take(&mut ws).unwrap_unchecked() }) From 3fb75fe74f413462bfb315fabc66e91b94f77bb2 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Tue, 20 Jan 2026 20:00:34 -0700 Subject: [PATCH 05/15] Getting rid of dead code --- src/trie_map.rs | 17 +----- src/trie_node.rs | 140 +---------------------------------------------- 2 files changed, 3 insertions(+), 154 deletions(-) diff --git a/src/trie_map.rs b/src/trie_map.rs index 3b09d143..aa835b06 100644 --- a/src/trie_map.rs +++ b/src/trie_map.rs @@ -511,24 +511,11 @@ impl PathMap { let root_val = unsafe{ &*self.root_val.get() }.is_some() as usize; match self.root() { Some(root) => { - // root.as_tagged().node_goat_val_count() + root_val - // traverse_physical(root, - // |node, ctx: usize| { ctx + node.node_goat_val_count() }, - // |ctx, child_ctx| { ctx + child_ctx }, - // ) + root_val - - // traverse_split_cata( - // root, - // |v, _| { 1usize }, - // |_, w, _| { 1 + w }, - // |bm, ws: &mut [usize], _| { ws.iter().sum() } - // ) + root_val - // Adam: this doesn't need to be called "traverse_osplit_cata" or be exposed under this interface; it can just live in morphisms recursive_cata::<_, _, _, _, _, _, _, false>( root, |v, w, _| { (v.is_some() as usize) + w.unwrap_or(0) }, // on values amongst a path - |bm, w: usize, total| { *total += w }, // on merging children into a node - |bm, total: usize| { total } // finalizing a node + |_mask, w: usize, total| { *total += w }, // on merging children into a node + |_mask, total: usize| { total } // finalizing a node ) + root_val }, None => root_val diff --git a/src/trie_node.rs b/src/trie_node.rs index f083bef1..dc5aa2f6 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -7,7 +7,7 @@ use dyn_clone::*; use local_or_heap::LocalOrHeap; use arrayvec::ArrayVec; -use crate::utils::{BitMask, ByteMask}; +use crate::utils::ByteMask; use crate::alloc::Allocator; use crate::dense_byte_node::*; use crate::ring::*; @@ -2372,140 +2372,6 @@ pub(crate) fn val_count_below_node(node: & } } -/// Recursively traverses a trie descending from `node`, visiting every physical non-empty node once -pub(crate) fn traverse_physical(node: &TrieNodeODRc, node_f: NodeF, fold_f: FoldF) -> Ctx - where - V: Clone + Send + Sync, - A: Allocator, - Ctx: Clone + Default, - NodeF: Fn(TaggedNodeRef, Ctx) -> Ctx + Copy, - FoldF: Fn(Ctx, Ctx) -> Ctx + Copy -{ - let mut cache = std::collections::HashMap::new(); - traverse_physical_internal(node, node_f, fold_f, &mut cache) -} - -fn traverse_physical_internal(node: &TrieNodeODRc, 
node_f: NodeF, fold_f: FoldF, cache: &mut HashMap) -> Ctx - where - V: Clone + Send + Sync, - A: Allocator, - Ctx: Clone + Default, - NodeF: Fn(TaggedNodeRef, Ctx) -> Ctx + Copy, - FoldF: Fn(Ctx, Ctx) -> Ctx + Copy -{ - if node.is_empty() { - return Ctx::default() - } - - if node.refcount() > 1 { - let hash = node.shared_node_id(); - match cache.get(&hash) { - Some(cached) => cached.clone(), - None => { - let ctx = traverse_physical_children_internal(node.as_tagged(), node_f, fold_f, cache); - cache.insert(hash, ctx.clone()); - ctx - }, - } - } else { - traverse_physical_children_internal(node.as_tagged(), node_f, fold_f, cache) - } -} - -fn traverse_physical_children_internal(node: TaggedNodeRef, node_f: NodeF, fold_f: FoldF, cache: &mut HashMap) -> Ctx - where - V: Clone + Send + Sync, - A: Allocator, - Ctx: Clone + Default, - NodeF: Fn(TaggedNodeRef, Ctx) -> Ctx + Copy, - FoldF: Fn(Ctx, Ctx) -> Ctx + Copy -{ - let mut ctx = Ctx::default(); - - match node { - TaggedNodeRef::DenseByteNode(n) => { - for cf in n.values.iter() { - if let Some(rec) = cf.rec() { - let child_ctx = traverse_physical_internal(rec, node_f, fold_f, cache); - ctx = fold_f(ctx, child_ctx); - } - } - } - TaggedNodeRef::LineListNode(n) => { - if n.is_used_child_0() { - let child_node = unsafe{ n.child_in_slot::<0>() }; - let child_ctx = traverse_physical_internal(child_node, node_f, fold_f, cache); - ctx = fold_f(ctx, child_ctx); - } - if n.is_used_child_1() { - let child_node = unsafe{ n.child_in_slot::<1>() }; - let child_ctx = traverse_physical_internal(child_node, node_f, fold_f, cache); - ctx = fold_f(ctx, child_ctx); - } - } - TaggedNodeRef::CellByteNode(_) => { todo!() } - TaggedNodeRef::TinyRefNode(_) => { todo!() } - TaggedNodeRef::EmptyNode => { todo!() } - } - - node_f(node, ctx) -} - -// This experiment is still OK, but the `&mut [W]` is awkward to instantiate if you don't actually have -/*pub fn traverse_split_cata<'a, A : Allocator, V : TrieValue, W, MapF, CollapseF, AlgF>(node: &TrieNodeODRc, mut map_f: MapF, mut collapse_f: CollapseF, alg_f: AlgF) -> W -where - MapF: Copy + FnMut(&V, &[u8]) -> W + 'a, - CollapseF: Copy + FnMut(&V, W, &[u8]) -> W + 'a, - AlgF: Copy + Fn(&ByteMask, &mut [W], &[u8]) -> W + 'a, -{ - match node.as_tagged() { - TaggedNodeRef::DenseByteNode(n) => { - let mut ws = [const { std::mem::MaybeUninit::::uninit() }; 256]; - // let mut ws: Vec> = Vec::with_capacity(n.mask.count_bits()); - // unsafe { ws.set_len(n.mask.count_bits()) }; - let mut c = 0; - for cf in n.values.iter() { - if let Some(rec) = cf.rec() { - let w = traverse_split_cata(rec, map_f, collapse_f, alg_f); - if let Some(v) = cf.val() { - ws[c].write(collapse_f(v, w, &[])); - } else { - ws[c].write(w); - } - } else if let Some(v) = cf.val() { - ws[c].write(map_f(v, &[])); - } - c += 1; - } - alg_f(&n.mask, unsafe { std::mem::transmute(&mut ws[..c]) }, &[]) - } - TaggedNodeRef::LineListNode(n) => { - // let mut ws = vec![]; - // if n.is_used_value_0() { - // ws.append(map_f(unsafe { n.val_in_slot::<0>() }, &[])); - // } - // if n.is_used_value_1() { - // ws.append(map_f(unsafe { n.val_in_slot::<1>() }, &[])); - // } - // if n.is_used_child_0() { - // let child_node = unsafe{ n.child_in_slot::<0>() }; - // let child_ctx = traverse_split_cata(child_node, map_f, collapse_f, alg_f); - // - // } - // if n.is_used_child_1() { - // let child_node = unsafe{ n.child_in_slot::<1>() }; - // let child_ctx = traverse_physical_internal(child_node, node_f, fold_f, cache); - // ctx = fold_f(ctx, child_ctx); - // } - 
alg_f(&ByteMask::new(), &mut [], &[]) - } - TaggedNodeRef::CellByteNode(_) => { todo!() } - TaggedNodeRef::TinyRefNode(_) => { todo!() } - TaggedNodeRef::EmptyNode => { todo!() } - } -} -*/ - /// GOAT recursive caching cata. If this dev branch is successful this should replace the caching cata flavors in the public API /// This is the JUMPING cata /// @@ -2550,9 +2416,7 @@ where //GOAT: // * Look at the callbacks in line node // * Send partial paths in pair node too -// * Come up with new names for in_alg and out_alg... // * Put caching back -// * See if I can get closer to the other cata API without sacrificing performance // pub fn recursive_cata(node: &TrieNodeODRc, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W where @@ -2623,8 +2487,6 @@ pub(crate) fn make_cell_node(node: &mut Tr // module come from the visibility of the trait it is derived on. In this case, `TrieNode` //Credit to QuineDot for his ideas on this pattern here: https://users.rust-lang.org/t/inferred-lifetime-for-dyn-trait/112116/7 pub(crate) use opaque_dyn_rc_trie_node::TrieNodeODRc; -use crate::morphisms::SplitCata; -use crate::TrieValue; #[cfg(not(feature = "slim_ptrs"))] mod opaque_dyn_rc_trie_node { From 69bce4139aeb309c589136bb478da554f27fd7cf Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Tue, 20 Jan 2026 23:30:20 -0700 Subject: [PATCH 06/15] Working through all the cases in pair node (on paper) --- src/line_list_node.rs | 49 ++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/line_list_node.rs b/src/line_list_node.rs index 155218e9..08e1ca9f 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -2752,28 +2752,33 @@ impl LineListNode { { let mut ws = Some(Acc::default()); -//GOAT, should we remove the path from out_alg?? I can't see when it's ever used... -// A: Either the path doesn't belong on the out_alg or on the in_alg. - -//GOAT, check out whether in_alg should get the path on the ByteNode - -//Cases: -// * There is only one val. Run map_f only -// * There is only one child. Recurse, then Run the in_alg -> out_alg combo -// * Both slots are filled -// GOAT, there is a problem where we have to care about whether it's a Val or child! -// - Is the first byte the same? -// N: -// - Call map_f on slot0 with the whole path -// - Call map_f on slot1 with the whole path -// - Call in_alg (branch_f) on each of the Ws -// - Call out_alg (fold_f) on the context with a composed mask from the first bytes -// Y: -// - Call map_f on slot0 with everything after first byte -// - Call map_f on slot1 with everything after first byte -// - Call in_alg (branch_f) on each of the Ws -// - Call out_alg (fold_f) on the context with a composed mask from the second bytes - +//Pair node can have the following permutations: (Slot0, Slot1) +// +// - (Empty, Empty) +// Run only `finalize_f` on default `Acc` +// - (Child, Empty) +// Recursively call on child, then run only `collapse_f` on the result, specifying the path +// - (Child, Val), 1-byte key, same key byte +// Recursively call on child, then run only `collapse_f` on the result, specifying the value and the 1-byte path +// - (Child, Val), different key bytes +// Recursively call on child, run `branch_f(collapse_f())` on the result. Then run `branch_f(collapse_f())` again +// with the value's path. And finally, run `finalize_f()` +// - (Child, Child) +// Recursively call on child0, run `branch_f(collapse_f())` on the result. Do the same for child1. 
Finally, run +// `finalize_f()` +// - (Val, Empty) +// Run only `collapse_f` on the val +// - (Val, Val), common first byte (meaning slot0 is a 1-byte path) +// run `collapse_f` on the slot1 val, specifying the path, then run `collapse_f` again on the slot0 val, specifying +// the common prefix byte +// - (Val, Val), different first bytes +// Run `branch_f(collapse_f())` on each val, then `finalize_f` at the end +// - (Val, Child), 1-byte key, same key byte (We could eliminate this case by requiring a canonical ordering for identical one-byte keys, but currently we don't) +// See "(Child, Val), 1-byte key, same key byte" +// - (Val, Child), different key bytes +// See (Child, Val), different key bytes +// +//GOAT, It would be nice to refactor the pair_node in order to express each of these permutations as a unique value for the 4 header bits, so we could take the appropriate code path without looking at any path bytes if self.is_used_value_0() { From 3a8e826fe9e1ccb3fbd42ce274ae70234df492ce Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Wed, 21 Jan 2026 18:12:00 -0700 Subject: [PATCH 07/15] Handling every case in the PairNode for recursive_cata (still no paths) --- src/line_list_node.rs | 194 +++++++++++++++++++++++++++++++----------- 1 file changed, 146 insertions(+), 48 deletions(-) diff --git a/src/line_list_node.rs b/src/line_list_node.rs index 08e1ca9f..1dce7438 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -2750,60 +2750,158 @@ impl LineListNode { BranchF: Copy + Fn(&ByteMask, W, &mut Acc), FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, { - let mut ws = Some(Acc::default()); + //Pair node can have the following permutations: (Slot0, Slot1) + // + // - Case 1 (Empty, Empty) + // Run only `finalize_f` on default `Acc` + // - Case 2 (Child, Empty) + // Recursively call on child, then run only `collapse_f` on the result, specifying the path + // - Case 3 (Child, Val), 1-byte key, same key byte + // Recursively call on child, then run only `collapse_f` on the result, specifying the value and the 1-byte path + // - Case 4 (Child, Val), different key bytes + // Recursively call on child, run `branch_f(collapse_f())` on the result. Then run `branch_f(collapse_f())` again + // with the value's path. And finally, run `finalize_f()` + // - Case 5 (Child, Child) + // Recursively call on child0, run `branch_f(collapse_f())` on the result. Do the same for child1. 
Finally, run + // `finalize_f()` + // - Case 6 (Val, Empty) + // Run only `collapse_f` on the val + // - Case 7 (Val, Val), common first byte (meaning slot0 is a 1-byte path) + // run `collapse_f` on the slot1 val, specifying the path, then run `collapse_f` again on the slot0 val, specifying + // the common prefix byte + // - Case 8 (Val, Val), different first bytes + // Run `branch_f(collapse_f())` on each val, then `finalize_f` at the end + // - Case 9 (Val, Child), 1-byte key, same key byte (We could eliminate this case by requiring a canonical ordering for identical one-byte keys, but currently we don't) + // See "Case 3 (Child, Val), 1-byte key, same key byte" + // - Case 10 (Val, Child), different key bytes + // See "Case 4 (Child, Val), different key bytes" + // + //GOAT, It would be nice to refactor the pair_node in order to express each of these permutations as a unique value for the 4 header bits, so we could take the appropriate code path without looking at any path bytes + + match self.header >> 12 { + //Case 1 (Empty, Empty) + 0 => finalize_f(&ByteMask::new(), Acc::default()), + //Case 2 (Child, Empty) = (1 << 3) + (1 << 1) | (1 << 3) + (1 << 1) + 1 + 10 | 11 => { + let child_node = unsafe{ self.child_in_slot::<0>() }; + let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + collapse_f(None, Some(child_w), &[]) + }, + //(Child, Val) = (1 << 3) + (1 << 2) + (1 << 1) + 14 => { + let child_node = unsafe{ self.child_in_slot::<0>() }; + let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let (key0, key1) = self.get_both_keys(); + //GOAT, we check the length here to short-circuit checking the key bytes, which is likely a lot slower. But maybe not... Try it both ways + if key1.len() == 1 && unsafe{ key0.get_unchecked(0) == key1.get_unchecked(0) } { + //Case 3 + debug_assert_eq!(key0.len(), 1); + debug_assert_eq!(key1.len(), 1); + let val = unsafe { self.val_in_slot::<1>() }; + collapse_f(Some(val), Some(child_w), &[]) + } else { + //Case 4 + let mut acc = Acc::default(); + branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), &[]), &mut acc); -//Pair node can have the following permutations: (Slot0, Slot1) -// -// - (Empty, Empty) -// Run only `finalize_f` on default `Acc` -// - (Child, Empty) -// Recursively call on child, then run only `collapse_f` on the result, specifying the path -// - (Child, Val), 1-byte key, same key byte -// Recursively call on child, then run only `collapse_f` on the result, specifying the value and the 1-byte path -// - (Child, Val), different key bytes -// Recursively call on child, run `branch_f(collapse_f())` on the result. Then run `branch_f(collapse_f())` again -// with the value's path. And finally, run `finalize_f()` -// - (Child, Child) -// Recursively call on child0, run `branch_f(collapse_f())` on the result. Do the same for child1. 
Finally, run -// `finalize_f()` -// - (Val, Empty) -// Run only `collapse_f` on the val -// - (Val, Val), common first byte (meaning slot0 is a 1-byte path) -// run `collapse_f` on the slot1 val, specifying the path, then run `collapse_f` again on the slot0 val, specifying -// the common prefix byte -// - (Val, Val), different first bytes -// Run `branch_f(collapse_f())` on each val, then `finalize_f` at the end -// - (Val, Child), 1-byte key, same key byte (We could eliminate this case by requiring a canonical ordering for identical one-byte keys, but currently we don't) -// See "(Child, Val), 1-byte key, same key byte" -// - (Val, Child), different key bytes -// See (Child, Val), different key bytes -// -//GOAT, It would be nice to refactor the pair_node in order to express each of these permutations as a unique value for the 4 header bits, so we could take the appropriate code path without looking at any path bytes + let val = unsafe { self.val_in_slot::<1>() }; + branch_f(&ByteMask::new(), collapse_f(Some(val), None, &[]), &mut acc); + finalize_f(&ByteMask::new(), acc) + } + }, + //Case 5 (Child, Child) = (1 << 3) + (1 << 2) + (1 << 1) + 1 + 15 => { + let mut acc = Acc::default(); + let child_node = unsafe{ self.child_in_slot::<0>() }; + let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), &[]), &mut acc); + + let child_node = unsafe{ self.child_in_slot::<1>() }; + let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), &[]), &mut acc); + + finalize_f(&ByteMask::new(), acc) + }, + //Case 6 (Val, Empty) = (1 << 3) | (1 << 3) + 1 + 8 | 9 => { + let val = unsafe { self.val_in_slot::<0>() }; + collapse_f(Some(val), None, &[]) + }, + //(Val, Val) = (1 << 3) + (1 << 2) + 12 => { + let (key0, key1) = self.get_both_keys(); + if unsafe{ key0.get_unchecked(0) == key1.get_unchecked(0) } { + //Case 7 (Val, Val), common first byte (meaning slot0 is a 1-byte path) + debug_assert_eq!(key0.len(), 1); + debug_assert!(key1.len() > 1); + let val = unsafe { self.val_in_slot::<1>() }; + let w1 = collapse_f(Some(val), None, &[]); + let val = unsafe { self.val_in_slot::<0>() }; + collapse_f(Some(val), Some(w1), &[]) + } else { + //Case 8 (Val, Val), different first bytes + let mut acc = Acc::default(); + let val = unsafe{ self.val_in_slot::<0>() }; + branch_f(&ByteMask::new(), collapse_f(Some(val), None, &[]), &mut acc); - if self.is_used_value_0() { - let downstream = collapse_f(Some(unsafe { self.val_in_slot::<0>() }), None, &[]); - branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); - } - if self.is_used_value_1() { - let downstream = collapse_f(Some(unsafe { self.val_in_slot::<1>() }), None, &[]); - branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); - } - if self.is_used_child_0() { - let child_node = unsafe{ self.child_in_slot::<0>() }; - let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); - let downstream = collapse_f(None, Some(w), &[]); - branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + let val = unsafe{ self.val_in_slot::<1>() }; + branch_f(&ByteMask::new(), collapse_f(Some(val), None, &[]), &mut acc); + finalize_f(&ByteMask::new(), acc) + } + }, + //(Val, Child) = (1 << 3) + (1 << 2) + 1 + 13 => { + let 
child_node = unsafe{ self.child_in_slot::<1>() }; + let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let (key0, key1) = self.get_both_keys(); + //GOAT, we check the length here to short-circuit checking the key bytes, which is likely a lot slower. But maybe not... Try it both ways + if key1.len() == 1 && unsafe{ key0.get_unchecked(0) == key1.get_unchecked(0) } { + //Case 9 (Val, Child), 1-byte key, same key byte (We could eliminate this case by requiring a canonical ordering for identical one-byte keys, but currently we don't) + debug_assert_eq!(key0.len(), 1); + debug_assert_eq!(key1.len(), 1); + let val = unsafe { self.val_in_slot::<0>() }; + collapse_f(Some(val), Some(child_w), &[]) + } else { + //Case 10 (Val, Child), different key bytes + let mut acc = Acc::default(); + branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), &[]), &mut acc); + + let val = unsafe { self.val_in_slot::<0>() }; + branch_f(&ByteMask::new(), collapse_f(Some(val), None, &[]), &mut acc); + + finalize_f(&ByteMask::new(), acc) + } + }, + _ => { unsafe { unreachable_unchecked() } } } - if self.is_used_child_1() { - let child_node = unsafe{ self.child_in_slot::<1>() }; - let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); - let downstream = collapse_f(None, Some(w), &[]); - branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); - } - finalize_f(&ByteMask::new(), unsafe { std::mem::take(&mut ws).unwrap_unchecked() }) + // let mut ws = Some(Acc::default()); + + // if self.is_used_value_0() { + // let downstream = collapse_f(Some(unsafe { self.val_in_slot::<0>() }), None, &[]); + // branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + // } + // if self.is_used_value_1() { + // let downstream = collapse_f(Some(unsafe { self.val_in_slot::<1>() }), None, &[]); + // branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + // } + // if self.is_used_child_0() { + // let child_node = unsafe{ self.child_in_slot::<0>() }; + // let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + // let downstream = collapse_f(None, Some(w), &[]); + // branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + + // } + // if self.is_used_child_1() { + // let child_node = unsafe{ self.child_in_slot::<1>() }; + // let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + // let downstream = collapse_f(None, Some(w), &[]); + // branch_f(&ByteMask::new(), downstream, unsafe { ws.as_mut().unwrap_unchecked() }); + // } + + // finalize_f(&ByteMask::new(), unsafe { std::mem::take(&mut ws).unwrap_unchecked() }) } } From 42c35878acfa18ee2dd37d7aecad2e6d37f79705 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Wed, 21 Jan 2026 21:16:07 -0700 Subject: [PATCH 08/15] Adding correct path handling to node_recursive_cata for PairNode --- src/line_list_node.rs | 120 +++++++++++++++++++++++++++++++++--------- src/utils/mod.rs | 20 +++++++ 2 files changed, 114 insertions(+), 26 deletions(-) diff --git a/src/line_list_node.rs b/src/line_list_node.rs index 1dce7438..17e809ab 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -2785,29 +2785,50 @@ impl LineListNode { 10 | 11 => { let child_node = unsafe{ self.child_in_slot::<0>() }; let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, 
collapse_f, branch_f, finalize_f); - collapse_f(None, Some(child_w), &[]) + let path = if COMPUTE_PATH { + unsafe{ self.key_unchecked::<0>() } + } else { + &[] + }; + collapse_f(None, Some(child_w), path) }, //(Child, Val) = (1 << 3) + (1 << 2) + (1 << 1) 14 => { let child_node = unsafe{ self.child_in_slot::<0>() }; let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); - let (key0, key1) = self.get_both_keys(); - //GOAT, we check the length here to short-circuit checking the key bytes, which is likely a lot slower. But maybe not... Try it both ways - if key1.len() == 1 && unsafe{ key0.get_unchecked(0) == key1.get_unchecked(0) } { + let key0 = unsafe{ self.key_unchecked::<0>() }; + let key1 = unsafe{ self.key_unchecked::<1>() }; + let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; + if key0_byte == key1_byte { //Case 3 debug_assert_eq!(key0.len(), 1); debug_assert_eq!(key1.len(), 1); let val = unsafe { self.val_in_slot::<1>() }; - collapse_f(Some(val), Some(child_w), &[]) + let path = if COMPUTE_PATH { + key0 + } else { + &[] + }; + collapse_f(Some(val), Some(child_w), path) } else { //Case 4 let mut acc = Acc::default(); - branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), &[]), &mut acc); + let (path, mask) = if COMPUTE_PATH { + (&key0[1..], ByteMask::from((key0_byte, key1_byte))) + } else { + (&[] as &[u8], ByteMask::new()) + }; + branch_f(&mask, collapse_f(None, Some(child_w), path), &mut acc); let val = unsafe { self.val_in_slot::<1>() }; - branch_f(&ByteMask::new(), collapse_f(Some(val), None, &[]), &mut acc); + let path = if COMPUTE_PATH { + &key1[1..] + } else { + &[] + }; + branch_f(&mask, collapse_f(Some(val), None, path), &mut acc); - finalize_f(&ByteMask::new(), acc) + finalize_f(&mask, acc) } }, //Case 5 (Child, Child) = (1 << 3) + (1 << 2) + (1 << 1) + 1 @@ -2815,63 +2836,110 @@ impl LineListNode { let mut acc = Acc::default(); let child_node = unsafe{ self.child_in_slot::<0>() }; let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); - branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), &[]), &mut acc); + let (path0, path1, mask) = if COMPUTE_PATH { + let key0 = unsafe{ self.key_unchecked::<0>() }; + let key1 = unsafe{ self.key_unchecked::<1>() }; + let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; + (&key0[1..], &key1[1..], ByteMask::from((key0_byte, key1_byte))) + } else { + (&[] as &[u8], &[] as &[u8], ByteMask::new()) + }; + branch_f(&mask, collapse_f(None, Some(child_w), path0), &mut acc); let child_node = unsafe{ self.child_in_slot::<1>() }; let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); - branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), &[]), &mut acc); + branch_f(&mask, collapse_f(None, Some(child_w), path1), &mut acc); - finalize_f(&ByteMask::new(), acc) + finalize_f(&mask, acc) }, //Case 6 (Val, Empty) = (1 << 3) | (1 << 3) + 1 8 | 9 => { let val = unsafe { self.val_in_slot::<0>() }; - collapse_f(Some(val), None, &[]) + let path = if COMPUTE_PATH { + unsafe{ self.key_unchecked::<0>() } + } else { + &[] + }; + collapse_f(Some(val), None, path) }, //(Val, Val) = (1 << 3) + (1 << 2) 12 => { - let (key0, key1) = self.get_both_keys(); - if unsafe{ key0.get_unchecked(0) == key1.get_unchecked(0) } { + let key0 = unsafe{ self.key_unchecked::<0>() }; + let key1 = unsafe{ 
self.key_unchecked::<1>() }; + let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; + if key0_byte == key1_byte { //Case 7 (Val, Val), common first byte (meaning slot0 is a 1-byte path) debug_assert_eq!(key0.len(), 1); debug_assert!(key1.len() > 1); let val = unsafe { self.val_in_slot::<1>() }; - let w1 = collapse_f(Some(val), None, &[]); + let path = if COMPUTE_PATH { + &key1[1..] + } else { + &[] + }; + let w1 = collapse_f(Some(val), None, path); let val = unsafe { self.val_in_slot::<0>() }; - collapse_f(Some(val), Some(w1), &[]) + let path = if COMPUTE_PATH { + &key1[0..1] + } else { + &[] + }; + collapse_f(Some(val), Some(w1), path) } else { //Case 8 (Val, Val), different first bytes let mut acc = Acc::default(); let val = unsafe{ self.val_in_slot::<0>() }; - branch_f(&ByteMask::new(), collapse_f(Some(val), None, &[]), &mut acc); + let (path0, path1, mask) = if COMPUTE_PATH { + (&key0[1..], &key1[1..], ByteMask::from((key0_byte, key1_byte))) + } else { + (&[] as &[u8], &[] as &[u8], ByteMask::new()) + }; + branch_f(&mask, collapse_f(Some(val), None, path0), &mut acc); let val = unsafe{ self.val_in_slot::<1>() }; - branch_f(&ByteMask::new(), collapse_f(Some(val), None, &[]), &mut acc); + branch_f(&mask, collapse_f(Some(val), None, path1), &mut acc); - finalize_f(&ByteMask::new(), acc) + finalize_f(&mask, acc) } }, //(Val, Child) = (1 << 3) + (1 << 2) + 1 13 => { let child_node = unsafe{ self.child_in_slot::<1>() }; let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); - let (key0, key1) = self.get_both_keys(); - //GOAT, we check the length here to short-circuit checking the key bytes, which is likely a lot slower. But maybe not... Try it both ways - if key1.len() == 1 && unsafe{ key0.get_unchecked(0) == key1.get_unchecked(0) } { + let key0 = unsafe{ self.key_unchecked::<0>() }; + let key1 = unsafe{ self.key_unchecked::<1>() }; + let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; + if key0_byte == key1_byte { //Case 9 (Val, Child), 1-byte key, same key byte (We could eliminate this case by requiring a canonical ordering for identical one-byte keys, but currently we don't) debug_assert_eq!(key0.len(), 1); debug_assert_eq!(key1.len(), 1); let val = unsafe { self.val_in_slot::<0>() }; - collapse_f(Some(val), Some(child_w), &[]) + let path = if COMPUTE_PATH { + key0 + } else { + &[] + }; + collapse_f(Some(val), Some(child_w), path) } else { //Case 10 (Val, Child), different key bytes let mut acc = Acc::default(); - branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), &[]), &mut acc); + + let (path, mask) = if COMPUTE_PATH { + (&key1[1..], ByteMask::from((key0_byte, key1_byte))) + } else { + (&[] as &[u8], ByteMask::new()) + }; + branch_f(&ByteMask::new(), collapse_f(None, Some(child_w), path), &mut acc); let val = unsafe { self.val_in_slot::<0>() }; - branch_f(&ByteMask::new(), collapse_f(Some(val), None, &[]), &mut acc); + let path = if COMPUTE_PATH { + &key0[1..] 
+ } else { + &[] + }; + branch_f(&mask, collapse_f(Some(val), None, path), &mut acc); - finalize_f(&ByteMask::new(), acc) + finalize_f(&mask, acc) } }, _ => { unsafe { unreachable_unchecked() } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index a5c277cf..4f74b458 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -255,6 +255,26 @@ impl From for ByteMask { } } +impl From<(u8, u8)> for ByteMask { + #[inline] + fn from(byte_pair: (u8, u8)) -> Self { + let mut new_mask = Self::new(); + new_mask.set_bit(byte_pair.0); + new_mask.set_bit(byte_pair.1); + new_mask + } +} + +impl From<[u8; 2]> for ByteMask { + #[inline] + fn from(byte_pair: [u8; 2]) -> Self { + let mut new_mask = Self::new(); + new_mask.set_bit(byte_pair[0]); + new_mask.set_bit(byte_pair[1]); + new_mask + } +} + impl From<[u64; 4]> for ByteMask { #[inline] fn from(mask: [u64; 4]) -> Self { From 542c283e2d5bc850a7484adf66f36ead45b8b969 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Wed, 21 Jan 2026 21:31:13 -0700 Subject: [PATCH 09/15] Re-adding caching to recursive_cata. Slight perf hit, but it's unavoidable, and <5% --- src/dense_byte_node.rs | 7 ++-- src/line_list_node.rs | 15 +++++---- src/trie_node.rs | 75 +++++++++++++++++++++++++++++++----------- 3 files changed, 68 insertions(+), 29 deletions(-) diff --git a/src/dense_byte_node.rs b/src/dense_byte_node.rs index 9fb1d821..87a1f9e0 100644 --- a/src/dense_byte_node.rs +++ b/src/dense_byte_node.rs @@ -332,9 +332,10 @@ impl> ByteNode } #[inline(always)] - pub fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + pub fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF, cache: &mut HashMap) -> W where Acc: Default, + W: Clone, CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, BranchF: Copy + Fn(&ByteMask, W, &mut Acc), FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, @@ -369,11 +370,11 @@ impl> ByteNode // like this gives the optimizer a site to specialize for each permutation match (cf.rec(), cf.val()) { (Some(rec), Some(val)) => { - let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f); + let w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f, cache); branch_f(&self.mask, collapse_f(Some(val), Some(w), path), unsafe { ws.as_mut().unwrap_unchecked() }); }, (Some(rec), None) => { - let w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f); + let w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(rec, collapse_f, branch_f, finalize_f, cache); branch_f(&self.mask, collapse_f(None, Some(w), path), unsafe { ws.as_mut().unwrap_unchecked() }); }, (None, Some(val)) => { diff --git a/src/line_list_node.rs b/src/line_list_node.rs index 17e809ab..afa5ac8d 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -2743,9 +2743,10 @@ impl LineListNode { } #[inline(always)] - pub fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + pub fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF, cache: &mut HashMap) -> W where Acc: Default, + W: Clone, CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, BranchF: Copy + Fn(&ByteMask, W, &mut Acc), FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, @@ -2784,7 +2785,7 @@ impl LineListNode { //Case 2 (Child, Empty) = (1 << 3) + (1 << 1) | (1 << 3) + (1 << 1) + 1 10 | 11 => { let child_node = unsafe{ 
self.child_in_slot::<0>() }; - let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); let path = if COMPUTE_PATH { unsafe{ self.key_unchecked::<0>() } } else { @@ -2795,7 +2796,7 @@ impl LineListNode { //(Child, Val) = (1 << 3) + (1 << 2) + (1 << 1) 14 => { let child_node = unsafe{ self.child_in_slot::<0>() }; - let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); let key0 = unsafe{ self.key_unchecked::<0>() }; let key1 = unsafe{ self.key_unchecked::<1>() }; let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; @@ -2835,7 +2836,7 @@ impl LineListNode { 15 => { let mut acc = Acc::default(); let child_node = unsafe{ self.child_in_slot::<0>() }; - let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); let (path0, path1, mask) = if COMPUTE_PATH { let key0 = unsafe{ self.key_unchecked::<0>() }; let key1 = unsafe{ self.key_unchecked::<1>() }; @@ -2847,7 +2848,7 @@ impl LineListNode { branch_f(&mask, collapse_f(None, Some(child_w), path0), &mut acc); let child_node = unsafe{ self.child_in_slot::<1>() }; - let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); branch_f(&mask, collapse_f(None, Some(child_w), path1), &mut acc); finalize_f(&mask, acc) @@ -2905,7 +2906,7 @@ impl LineListNode { //(Val, Child) = (1 << 3) + (1 << 2) + 1 13 => { let child_node = unsafe{ self.child_in_slot::<1>() }; - let child_w = recursive_cata::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f); + let child_w = recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(child_node, collapse_f, branch_f, finalize_f, cache); let key0 = unsafe{ self.key_unchecked::<0>() }; let key1 = unsafe{ self.key_unchecked::<1>() }; let (key0_byte, key1_byte) = unsafe{ (*key0.get_unchecked(0), *key1.get_unchecked(0)) }; @@ -3586,4 +3587,4 @@ mod tests { // and ZipperMoving // 2. implement a val_count convenience on top of 1. 
-//GOAT, Paths in caching Cata: https://github.com/Adam-Vandervorst/PathMap/pull/8#discussion_r2004828957 \ No newline at end of file +//GOAT, Paths in caching Cata: https://github.com/Adam-Vandervorst/PathMap/pull/8#discussion_r2004828957 diff --git a/src/trie_node.rs b/src/trie_node.rs index dc5aa2f6..005cd535 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -2377,7 +2377,7 @@ pub(crate) fn val_count_below_node(node: & /// /// Closures: /// -/// `CollapseF`: Folds a possible value and a possible downstream continuation, prefixed by a linear sub-path into a single `W` +/// `CollapseF`: Folds a possible value and a possible downstream continuation, prefixed by a linear sub-path, into a single `W` /// `fn(val: Option<&V>, downstream: Option, prefix: &[u8]) -> W` /// /// `BranchF`: Accumulates the `W` representing a downstream branch into an `Acc` accumulator type @@ -2399,37 +2399,74 @@ pub(crate) fn val_count_below_node(node: & // - closure to deal with each path byte / logical node. Fn(&ByteMask, Option<&V>, Acc) -> W -//GOAT issues -// - The reason it's faster than the other abstraction because it branches on node type once, rather than twice per node. -// -// There is more stuff on the stack, meaning we're more likely to blow the stack -// -// * Needed to add caching -// * The logic in pair node was wrong, because there is no guarantee both sides are at the same level; fixing that added another branch -// -// Observations: -// If we want to count path-ends, MapF wouldn't work, but could pass Option<&V>, which would be fine -// It might be possible to unify MapF and CollapseF, but // //GOAT, TODO: Make a test to hit the stack overflow failure case // -//GOAT: -// * Look at the callbacks in line node -// * Send partial paths in pair node too -// * Put caching back -// pub fn recursive_cata(node: &TrieNodeODRc, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W where V: Clone + Send + Sync, A: Allocator, Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, +{ + let mut cache = HashMap::new(); + recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, &mut cache) +} + +pub(crate) fn recursive_cata_cached( + node: &TrieNodeODRc, + collapse_f: CollapseF, + branch_f: BranchF, + finalize_f: FinalizeF, + cache: &mut HashMap, +) -> W +where + V: Clone + Send + Sync, + A: Allocator, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, +{ + if node.refcount() > 1 { + let hash = node.shared_node_id(); + match cache.get(&hash) { + Some(cached) => cached.clone(), + None => { + let w = recursive_cata_dispatch::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, cache); + cache.insert(hash, w.clone()); + w + }, + } + } else { + recursive_cata_dispatch::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, cache) + } +} + +#[inline(always)] +fn recursive_cata_dispatch( + node: &TrieNodeODRc, + collapse_f: CollapseF, + branch_f: BranchF, + finalize_f: FinalizeF, + cache: &mut HashMap, +) -> W +where + V: Clone + Send + Sync, + A: Allocator, + Acc: Default, + W: Clone, CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, BranchF: Copy + Fn(&ByteMask, W, &mut Acc), FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, { match node.as_tagged() { - 
TaggedNodeRef::DenseByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f) } - TaggedNodeRef::LineListNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f) } + TaggedNodeRef::DenseByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } + TaggedNodeRef::LineListNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } TaggedNodeRef::CellByteNode(_) => { todo!() } TaggedNodeRef::TinyRefNode(_) => { todo!() } TaggedNodeRef::EmptyNode => { finalize_f(&ByteMask::EMPTY, Acc::default()) } From 32973bfb3f304463b07dfa855239fe30ab1dd2b9 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Wed, 21 Jan 2026 21:40:39 -0700 Subject: [PATCH 10/15] Putting empty-node check back into recursive cata, to avoid reading bad memory when retrieving the refcount of the empty node --- src/trie_node.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/trie_node.rs b/src/trie_node.rs index 005cd535..ccb67460 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -2377,7 +2377,7 @@ pub(crate) fn val_count_below_node(node: & /// /// Closures: /// -/// `CollapseF`: Folds a possible value and a possible downstream continuation, prefixed by a linear sub-path, into a single `W` +/// `CollapseF`: Folds a possible value and a possible downstream continuation, prefixed by a linear sub-path into a single `W` /// `fn(val: Option<&V>, downstream: Option, prefix: &[u8]) -> W` /// /// `BranchF`: Accumulates the `W` representing a downstream branch into an `Acc` accumulator type @@ -2432,7 +2432,7 @@ where BranchF: Copy + Fn(&ByteMask, W, &mut Acc), FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, { - if node.refcount() > 1 { + if !node.is_empty() && node.refcount() > 1 { let hash = node.shared_node_id(); match cache.get(&hash) { Some(cached) => cached.clone(), @@ -3015,6 +3015,7 @@ mod opaque_dyn_rc_trie_node { pub(crate) fn new_empty() -> Self { Self { ptr: SlimNodePtr::new_empty(), alloc: MaybeUninit::uninit() } } + #[inline(always)] pub(crate) fn is_empty(&self) -> bool { self.tag() == EMPTY_NODE_TAG } From ce6a8f874422ccc7ec31e20b05432052e56a6547 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Wed, 21 Jan 2026 21:50:40 -0700 Subject: [PATCH 11/15] Filling in remaining node type branches for recursive_cata --- src/tiny_node.rs | 11 +++++++++++ src/trie_node.rs | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/tiny_node.rs b/src/tiny_node.rs index 08d61fc3..f12584ca 100644 --- a/src/tiny_node.rs +++ b/src/tiny_node.rs @@ -121,6 +121,17 @@ impl<'a, V: Clone + Send + Sync, A: Allocator> TinyRefNode<'a, V, A> { fn key(&self) -> &[u8] { unsafe{ core::slice::from_raw_parts(self.key_bytes.as_ptr().cast(), self.key_len()) } } + + pub(crate) fn node_recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF, cache: &mut HashMap) -> W + where + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + self.into_full().unwrap().node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) + } } impl<'a, V: Clone + Send + Sync, A: Allocator> TrieNode for TinyRefNode<'a, V, A> { diff --git a/src/trie_node.rs b/src/trie_node.rs index ccb67460..77d4730c 100644 --- a/src/trie_node.rs +++ 
b/src/trie_node.rs @@ -2467,8 +2467,8 @@ where match node.as_tagged() { TaggedNodeRef::DenseByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } TaggedNodeRef::LineListNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } - TaggedNodeRef::CellByteNode(_) => { todo!() } - TaggedNodeRef::TinyRefNode(_) => { todo!() } + TaggedNodeRef::CellByteNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } + TaggedNodeRef::TinyRefNode(node) => { node.node_recursive_cata::<_, _, _, _, _, COMPUTE_PATH>(collapse_f, branch_f, finalize_f, cache) } TaggedNodeRef::EmptyNode => { finalize_f(&ByteMask::EMPTY, Acc::default()) } } } From 73130eaf6774bccd3893b863344d8e7a13354748 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Wed, 21 Jan 2026 22:02:15 -0700 Subject: [PATCH 12/15] Adding recursive_cata_stack_overflow_smoke test --- src/trie_node.rs | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/src/trie_node.rs b/src/trie_node.rs index 77d4730c..54ef9666 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -2398,10 +2398,6 @@ pub(crate) fn val_count_below_node(node: & // - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) // - closure to deal with each path byte / logical node. Fn(&ByteMask, Option<&V>, Acc) -> W - -// -//GOAT, TODO: Make a test to hit the stack overflow failure case -// pub fn recursive_cata(node: &TrieNodeODRc, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W where V: Clone + Send + Sync, @@ -3252,6 +3248,7 @@ mod tests { use crate::alloc::{GlobalAlloc, global_alloc}; use crate::line_list_node::LineListNode; use crate::trie_node::TrieNodeODRc; + use crate::trie_node::recursive_cata; use crate::PathMap; use crate::zipper::*; @@ -3284,4 +3281,26 @@ mod tests { node_ref.make_unique(); drop(cloned); } + + /// Finds the path_depth at which the recursive cata hits a stack overflow + /// + /// Empirically seems to be somewhere between 8 and 10 KBytes. But more branching, and thus fewer + /// bytes-per-node, will mean it will fail on shorter paths. + #[test] + fn recursive_cata_stack_overflow_smoke() { + const PATH_LEN: usize = 8_000; + + let mut map = PathMap::<()>::new(); + let path = vec![b'a'; PATH_LEN]; + map.set_val_at(&path, ()); + + let root = map.root().unwrap(); + let count = recursive_cata::<_, _, _, _, _, _, _, false>( + root, + |v, w, _| (v.is_some() as usize) + w.unwrap_or(0), + |_mask, w: usize, total| { *total += w }, + |_mask, total: usize| { total }, + ); + assert_eq!(count, 1); + } } From 58ea1f1df1d734cf2cc6dd8b3be544ee8039ae53 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Thu, 22 Jan 2026 23:47:38 -0700 Subject: [PATCH 13/15] Adding a test for recursive cata, and adding the stepping version with a test --- src/trie_node.rs | 160 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 154 insertions(+), 6 deletions(-) diff --git a/src/trie_node.rs b/src/trie_node.rs index 54ef9666..45cf21da 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -2393,10 +2393,9 @@ pub(crate) fn val_count_below_node(node: & // - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) // - closure to fold accumulator back into a W for the logical node. Fn(&ByteMask, Acc) -> W // -//Then, the non-jumping API would take: -// 2 closures. 
-// - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) -// - closure to deal with each path byte / logical node. Fn(&ByteMask, Option<&V>, Acc) -> W +// +//The non-jumping API would be the same, but collapse wouldn't take a prefix path, and instead `finalize_f(branch_f())` would be +// called in reverse order for each path byte pub fn recursive_cata(node: &TrieNodeODRc, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W where @@ -2412,6 +2411,44 @@ where recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, &mut cache) } +/// A stepping (non-jumping) catamorphism for the trie. +/// +/// Use this when you need the cata to evaluate once per path byte, even across non-branching sub-paths. +/// Unlike the jumping version, `branch_f` and `finalize_f` will be called for every path byte. +/// +/// See [`recursive_cata`] for closure semantics; this stepping variant omits the prefix argument from `collapse_f`. +pub fn recursive_cata_stepping( + node: &TrieNodeODRc, + collapse_f: CollapseF, + branch_f: BranchF, + finalize_f: FinalizeF, +) -> W +where + V: Clone + Send + Sync, + A: Allocator, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, +{ + recursive_cata::<_, _, _, _, _, _, _, true>( + node, + |val, downstream, prefix| { + let mut w = collapse_f(val, downstream); + for byte in prefix.iter().rev() { + let mask = ByteMask::from(*byte); + let mut acc = Acc::default(); + branch_f(&mask, w, &mut acc); + w = finalize_f(&mask, acc); + } + w + }, + branch_f, + finalize_f, + ) +} + pub(crate) fn recursive_cata_cached( node: &TrieNodeODRc, collapse_f: CollapseF, @@ -3247,8 +3284,8 @@ impl DistributiveLat mod tests { use crate::alloc::{GlobalAlloc, global_alloc}; use crate::line_list_node::LineListNode; - use crate::trie_node::TrieNodeODRc; - use crate::trie_node::recursive_cata; + use crate::trie_node::{TrieNodeODRc, recursive_cata, recursive_cata_stepping}; + use crate::utils::ByteMask; use crate::PathMap; use crate::zipper::*; @@ -3282,6 +3319,117 @@ mod tests { drop(cloned); } + /// Adapted from morphisms::cata_test1 for recursive_cata (jumping). 
+ //GOAT, this should be in morphisms, and only use the public API + #[test] + fn recursive_cata_jumping_sum_digits() { + let tests = [ + (vec![], 0), + (vec!["1"], 1), + (vec!["1", "2"], 3), + (vec!["1", "2", "3", "4", "5", "6"], 21), + (vec!["a1", "a2"], 3), + (vec!["a1", "a2", "a3", "a4", "a5", "a6"], 21), + (vec!["12345"], 5), + (vec!["1", "12", "123", "1234", "12345"], 15), + (vec!["123", "123456", "123789"], 18), + (vec!["12", "123", "123456", "123789"], 20), + (vec!["1", "2", "123", "123765", "1234", "12345", "12349"], 29), + ]; + + #[derive(Default)] + struct SumAcc { + idx: usize, + sum: u32, + } + + for (keys, expected_sum) in tests { + let map: PathMap<()> = keys.into_iter().map(|v| (v, ())).collect(); + let sum = match map.root() { + Some(root) => { + let w = recursive_cata::<_, _, _, _, _, _, _, true>( + root, + |val, downstream, prefix| { + let mut sum = downstream.map(|w: (bool, u32)| w.1).unwrap_or(0); + if val.is_some() { + if let Some(byte) = prefix.last() { + sum += (*byte as char).to_digit(10).unwrap(); + } + } + (val.is_some() && prefix.is_empty(), sum) + }, + |mask: &ByteMask, w: (bool, u32), acc: &mut SumAcc| { + let byte = mask.indexed_bit::(acc.idx).unwrap(); + acc.idx += 1; + if w.0 { + acc.sum += (byte as char).to_digit(10).unwrap(); + } + acc.sum += w.1; + }, + |_mask: &ByteMask, acc: SumAcc| { (false, acc.sum) }, + ); + w.1 + }, + None => 0, + }; + assert_eq!(sum, expected_sum); + } + } + + /// Adapted from morphisms::cata_test1 for recursive_cata_stepping (non-jumping). + //GOAT, this should be in morphisms, and only use the public API + #[test] + fn recursive_cata_stepping_sum_digits() { + let tests = [ + (vec![], 0), + (vec!["1"], 1), + (vec!["1", "2"], 3), + (vec!["1", "2", "3", "4", "5", "6"], 21), + (vec!["a1", "a2"], 3), + (vec!["a1", "a2", "a3", "a4", "a5", "a6"], 21), + (vec!["12345"], 5), + (vec!["1", "12", "123", "1234", "12345"], 15), + (vec!["123", "123456", "123789"], 18), + (vec!["12", "123", "123456", "123789"], 20), + (vec!["1", "2", "123", "123765", "1234", "12345", "12349"], 29), + ]; + + #[derive(Default)] + struct SumAcc { + idx: usize, + sum: u32, + } + + for (keys, expected_sum) in tests { + let map: PathMap<()> = keys.into_iter().map(|v| (v, ())).collect(); + let sum = match map.root() { + Some(root) => { + let w = recursive_cata_stepping::<_, _, SumAcc, (bool, u32), _, _, _>( + root, + |val, downstream| { + let sum = downstream.map(|w| w.1).unwrap_or(0); + (val.is_some(), sum) + }, + |mask: &ByteMask, w: (bool, u32), acc: &mut SumAcc| { + let byte = mask.iter().nth(acc.idx).unwrap(); + acc.idx += 1; + if w.0 { + acc.sum += (byte as char).to_digit(10).unwrap(); + } + acc.sum += w.1; + }, + |_mask: &ByteMask, acc: SumAcc| { + (false, acc.sum) + }, + ); + w.1 + }, + None => 0, + }; + assert_eq!(sum, expected_sum); + } + } + /// Finds the path_depth at which the recursive cata hits a stack overflow /// /// Empirically seems to be somewhere between 8 and 10 KBytes. 
But more branching, and thus fewer From df4e48adabef65050344b2446ebf00e5c0815756 Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Fri, 23 Jan 2026 00:51:46 -0700 Subject: [PATCH 14/15] Reorganizing recursive_cata so it can be part of the public API --- src/dense_byte_node.rs | 5 +- src/line_list_node.rs | 4 +- src/morphisms.rs | 263 ++++++++++++++++++++++++++++++++++++++++- src/tiny_node.rs | 4 +- src/trie_map.rs | 7 +- src/trie_node.rs | 223 ++-------------------------------- 6 files changed, 278 insertions(+), 228 deletions(-) diff --git a/src/dense_byte_node.rs b/src/dense_byte_node.rs index 87a1f9e0..5c4144ce 100644 --- a/src/dense_byte_node.rs +++ b/src/dense_byte_node.rs @@ -1,7 +1,6 @@ use core::fmt::{Debug, Formatter}; use core::ptr; -use std::collections::HashMap; use std::hint::unreachable_unchecked; use crate::alloc::Allocator; @@ -12,6 +11,8 @@ use crate::utils::BitMask; use crate::trie_node::*; use crate::line_list_node::LineListNode; +use crate::gxhash::HashMap; + //NOTE: This: `core::array::from_fn(|i| i as u8);` ought to work, but https://github.com/rust-lang/rust/issues/109341 const ALL_BYTES: [u8; 256] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]; @@ -1028,7 +1029,7 @@ impl> TrieNode } } } - fn node_val_count(&self, cache: &mut HashMap) -> usize { + fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize { //Discussion: These two implementations do the same thing but with a slightly different ordering of // the operations. In `all_dense_nodes`, the "Branchy" impl wins. But in a mixed-node setting, the // IMPL B is the winner. 
My suspicion is that the ListNode's heavily branching structure leads to diff --git a/src/line_list_node.rs b/src/line_list_node.rs index afa5ac8d..1e7844e6 100644 --- a/src/line_list_node.rs +++ b/src/line_list_node.rs @@ -1,6 +1,6 @@ use core::hint::unreachable_unchecked; use core::mem::{ManuallyDrop, MaybeUninit}; -use std::collections::HashMap; +use crate::gxhash::HashMap; use fast_slice_utils::{find_prefix_overlap, starts_with}; use local_or_heap::LocalOrHeap; @@ -1968,7 +1968,7 @@ impl TrieNode for LineListNode } } #[inline] - fn node_val_count(&self, cache: &mut HashMap) -> usize { + fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize { let mut result = 0; if self.is_used_value_0() { result += 1; diff --git a/src/morphisms.rs b/src/morphisms.rs index cc5b1f02..9fe0c88b 100644 --- a/src/morphisms.rs +++ b/src/morphisms.rs @@ -72,8 +72,10 @@ use crate::utils::*; use crate::alloc::Allocator; use crate::PathMap; use crate::trie_node::TrieNodeODRc; +use crate::trie_node::recursive_cata_cached; use crate::zipper; use crate::zipper::*; +use crate::zipper::zipper_priv::ZipperPriv; use crate::gxhash::{HashMap, HashMapExt}; @@ -213,8 +215,8 @@ pub trait Catamorphism { /// See [into_cata_cached](Catamorphism::into_cata_cached) for explanation of other arguments and behavior fn into_cata_jumping_cached(self, alg_f: AlgF) -> W where - W: Clone, - AlgF: Fn(&ByteMask, &mut [W], Option<&V>, &[u8]) -> W, + W: Clone, + AlgF: Fn(&ByteMask, &mut [W], Option<&V>, &[u8]) -> W, Self: Sized { self.into_cata_jumping_cached_fallible(|mask, children, val, sub_path| -> Result { @@ -231,6 +233,55 @@ pub trait Catamorphism { AlgF: Fn(&ByteMask, &mut [W], Option<&V>, &[u8]) -> Result; } +/// Provides faster catamorphism methods for types backed by an in-memory trie, such as [`PathMap`] +/// and some zipper implementations +pub trait Summarization { + + /// GOAT recursive cached cata. If this dev branch is successful this should replace the caching cata flavors in the public API + /// This is the JUMPING cata + /// + /// Closures: + /// + /// `CollapseF`: Folds a possible value and a possible downstream continuation, prefixed by a linear sub-path into a single `W` + /// `fn(val: Option<&V>, downstream: Option, prefix: &[u8]) -> W` + /// + /// `BranchF`: Accumulates the `W` representing a downstream branch into an `Acc` accumulator type + /// `fn(branch_mask: &ByteMask, downstream: W, accumulator: &mut Acc)` + /// + /// `FinalizeF`: Converts an `Acc` accumulator into a `W` representing the logical node + /// `fn(branch_mask: &ByteMask, accumulator: Acc) -> W` + /// + /// GOAT: The `COMPUTE_PATH` parameter shouldn't be necessary in a perfect world, but unfortunately the compiler + /// isn't very good at getting rid of the dead code, so passing `COMPUTE_PATH=false` gives a considerable speedup + /// at the expense of providing paths and reliable child_masks to the closures. + fn recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + Self: Sized; + + /// A stepping (non-jumping) catamorphism for the trie. + /// + /// Use this when you need the cata to evaluate once per path byte, even across non-branching sub-paths. + /// Unlike the jumping version, `branch_f` and `finalize_f` will be called for every path byte. 
+ /// + /// See [`Catamorphism::recursive_cata`] for closure semantics; this stepping variant omits the prefix argument from `collapse_f`. + fn recursive_cata_stepping(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + Self: Sized; + +} + //TODO GOAT!!: It would be nice to get rid of this Default bound on all morphism Ws. In this case, the plan // for doing that would be to create a new type called a TakableSlice. It would be able to deref // into a regular mutable slice of `T` so it would work just like an ordinary slice. Additionally @@ -420,6 +471,99 @@ impl Catamorph } } +impl<'a, Z, V> Summarization for Z where Z: Zipper + ZipperReadOnlyConditionalValues<'a, V> + ZipperConcrete + ZipperAbsolutePath + ZipperPathBuffer + ZipperPriv { + fn recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + let focus = self.get_focus(); + let w = match focus.borrow() { + Some(node) => { + let mut cache = HashMap::new(); + recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, &mut cache) + }, + None => finalize_f(&ByteMask::EMPTY, Acc::default()), + }; + collapse_f(self.val(), Some(w), &[]) + } + + fn recursive_cata_stepping(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + self.recursive_cata::<_, _, _, _, _, true>( + |val, downstream, prefix| { + let mut w = collapse_f(val, downstream); + for byte in prefix.iter().rev() { + let mask = ByteMask::from(*byte); + let mut acc = Acc::default(); + branch_f(&mask, w, &mut acc); + w = finalize_f(&mask, acc); + } + w + }, + branch_f, + finalize_f, + ) + } +} + +impl Summarization for PathMap { + fn recursive_cata(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + let w = match self.root() { + Some(node) => { + let mut cache = HashMap::new(); + recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, &mut cache) + }, + None => finalize_f(&ByteMask::EMPTY, Acc::default()), + }; + collapse_f(self.root_val(), Some(w), &[]) + } + + fn recursive_cata_stepping(&self, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W + where + V: Clone + Send + Sync, + Acc: Default, + W: Clone, + CollapseF: Copy + Fn(Option<&V>, Option) -> W, + BranchF: Copy + Fn(&ByteMask, W, &mut Acc), + FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, + { + self.recursive_cata::<_, _, _, _, _, true>( + |val, downstream, prefix| { + let mut w = collapse_f(val, downstream); + for byte in prefix.iter().rev() { + let mask = ByteMask::from(*byte); + let mut acc = Acc::default(); + branch_f(&mask, w, &mut acc); + w = finalize_f(&mask, acc); + } + w + }, + branch_f, 
+ finalize_f, + ) + } +} + #[inline] fn cata_side_effect_body<'a, Z, V: 'a, W, Err, AlgF, const JUMPING: bool>(mut z: Z, mut alg_f: AlgF) -> Result where @@ -1956,6 +2100,121 @@ mod tests { eprintln!("calls_cached: {calls_cached}\ncalls_side: {calls_side}"); } + /// Adapted from morphisms::cata_test1 for recursive_cata (jumping). + #[test] + fn recursive_cata_jumping_sum_digits() { + let tests = [ + (vec![], 0), + (vec!["1"], 1), + (vec!["1", "2"], 3), + (vec!["1", "2", "3", "4", "5", "6"], 21), + (vec!["a1", "a2"], 3), + (vec!["a1", "a2", "a3", "a4", "a5", "a6"], 21), + (vec!["12345"], 5), + (vec!["1", "12", "123", "1234", "12345"], 15), + (vec!["123", "123456", "123789"], 18), + (vec!["12", "123", "123456", "123789"], 20), + (vec!["1", "2", "123", "123765", "1234", "12345", "12349"], 29), + ]; + + #[derive(Default)] + struct SumAcc { + idx: usize, + sum: u32, + } + + for (keys, expected_sum) in tests { + let map: PathMap<()> = keys.into_iter().map(|v| (v, ())).collect(); + let sum = map.recursive_cata::<_, _, _, _, _, true>( + |val, downstream, prefix| { + let mut sum = downstream.map(|w: (bool, u32)| w.1).unwrap_or(0); + if val.is_some() { + if let Some(byte) = prefix.last() { + sum += (*byte as char).to_digit(10).unwrap(); + } + } + (val.is_some() && prefix.is_empty(), sum) + }, + |mask: &ByteMask, w: (bool, u32), acc: &mut SumAcc| { + let byte = mask.indexed_bit::(acc.idx).unwrap(); + acc.idx += 1; + if w.0 { + acc.sum += (byte as char).to_digit(10).unwrap(); + } + acc.sum += w.1; + }, + |_mask: &ByteMask, acc: SumAcc| { (false, acc.sum) }, + ).1; + assert_eq!(sum, expected_sum); + } + } + + /// Adapted from morphisms::cata_test1 for recursive_cata_stepping (non-jumping). + #[test] + fn recursive_cata_stepping_sum_digits() { + let tests = [ + (vec![], 0), + (vec!["1"], 1), + (vec!["1", "2"], 3), + (vec!["1", "2", "3", "4", "5", "6"], 21), + (vec!["a1", "a2"], 3), + (vec!["a1", "a2", "a3", "a4", "a5", "a6"], 21), + (vec!["12345"], 5), + (vec!["1", "12", "123", "1234", "12345"], 15), + (vec!["123", "123456", "123789"], 18), + (vec!["12", "123", "123456", "123789"], 20), + (vec!["1", "2", "123", "123765", "1234", "12345", "12349"], 29), + ]; + + #[derive(Default)] + struct SumAcc { + idx: usize, + sum: u32, + } + + for (keys, expected_sum) in tests { + let map: PathMap<()> = keys.into_iter().map(|v| (v, ())).collect(); + let sum = map.recursive_cata_stepping::( + |val, downstream| { + let sum = downstream.map(|w| w.1).unwrap_or(0); + (val.is_some(), sum) + }, + |mask: &ByteMask, w: (bool, u32), acc: &mut SumAcc| { + let byte = mask.iter().nth(acc.idx).unwrap(); + acc.idx += 1; + if w.0 { + acc.sum += (byte as char).to_digit(10).unwrap(); + } + acc.sum += w.1; + }, + |_mask: &ByteMask, acc: SumAcc| { + (false, acc.sum) + }, + ).1; + assert_eq!(sum, expected_sum); + } + } + + /// Finds the path_depth at which the recursive cata hits a stack overflow + /// + /// Empirically seems to be somewhere between 8 and 10 KBytes. But more branching, and thus fewer + /// bytes-per-node, will mean it will fail on shorter paths. 
+ #[test] + fn recursive_cata_stack_overflow_smoke() { + const PATH_LEN: usize = 8_000; + + let mut map = PathMap::<()>::new(); + let path = vec![b'a'; PATH_LEN]; + map.set_val_at(&path, ()); + + let count = map.recursive_cata::<_, _, _, _, _, false>( + |v, w, _| (v.is_some() as usize) + w.unwrap_or(0), + |_mask, w: usize, total| { *total += w }, + |_mask, total: usize| { total }, + ); + assert_eq!(count, 1); + } + /// Generate some basic tries using the [TrieBuilder::push_byte] API #[test] fn ana_test1() { diff --git a/src/tiny_node.rs b/src/tiny_node.rs index f12584ca..1c93f1a3 100644 --- a/src/tiny_node.rs +++ b/src/tiny_node.rs @@ -9,7 +9,7 @@ use core::mem::MaybeUninit; use core::fmt::{Debug, Formatter}; -use std::collections::HashMap; +use crate::gxhash::HashMap; use fast_slice_utils::{find_prefix_overlap, starts_with}; use crate::utils::ByteMask; @@ -228,7 +228,7 @@ impl<'a, V: Clone + Send + Sync, A: Allocator> TrieNode for TinyRefNode<'a fn new_iter_token(&self) -> u128 { unreachable!() } fn iter_token_for_path(&self, _key: &[u8]) -> u128 { unreachable!() } fn next_items(&self, _token: u128) -> (u128, &'a[u8], Option<&TrieNodeODRc>, Option<&V>) { unreachable!() } - fn node_val_count(&self, cache: &mut HashMap) -> usize { + fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize { let temp_node = self.into_full().unwrap(); temp_node.node_val_count(cache) } diff --git a/src/trie_map.rs b/src/trie_map.rs index aa835b06..76311499 100644 --- a/src/trie_map.rs +++ b/src/trie_map.rs @@ -1,7 +1,7 @@ use core::cell::UnsafeCell; use std::ptr::slice_from_raw_parts; use crate::alloc::{Allocator, GlobalAlloc, global_alloc}; -use crate::morphisms::{new_map_from_ana_in, Catamorphism, TrieBuilder}; +use crate::morphisms::{new_map_from_ana_in, Catamorphism, Summarization, TrieBuilder}; use crate::trie_node::*; use crate::zipper::*; use crate::merkleization::{MerkleizeResult, merkleize_impl}; @@ -510,9 +510,8 @@ impl PathMap { pub fn goat_val_count(&self) -> usize { let root_val = unsafe{ &*self.root_val.get() }.is_some() as usize; match self.root() { - Some(root) => { - recursive_cata::<_, _, _, _, _, _, _, false>( - root, + Some(_root) => { + self.recursive_cata::<_, _, _, _, _, false>( |v, w, _| { (v.is_some() as usize) + w.unwrap_or(0) }, // on values amongst a path |_mask, w: usize, total| { *total += w }, // on merging children into a node |_mask, total: usize| { total } // finalizing a node diff --git a/src/trie_node.rs b/src/trie_node.rs index 45cf21da..fb05bd2e 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -2,7 +2,6 @@ use core::hint::unreachable_unchecked; use core::mem::ManuallyDrop; use core::ptr::NonNull; -use std::collections::HashMap; use dyn_clone::*; use local_or_heap::LocalOrHeap; use arrayvec::ArrayVec; @@ -14,6 +13,8 @@ use crate::ring::*; use crate::tiny_node::TinyRefNode; use crate::line_list_node::LineListNode; +use crate::gxhash::HashMap; + #[cfg(feature = "bridge_nodes")] use crate::bridge_node::BridgeNode; @@ -220,7 +221,7 @@ pub(crate) trait TrieNode: TrieNodeDowncas /// Returns the total number of leaves contained within the whole subtree defined by the node /// GOAT, this should be deprecated - fn node_val_count(&self, cache: &mut HashMap) -> usize; + fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize; /// Returns the number of values contained within the node itself, irrespective of the positions within /// the node; does not include onward links @@ -1222,7 +1223,7 @@ mod tagged_node_ref { } #[inline] - pub fn 
node_val_count(&self, cache: &mut HashMap) -> usize { + pub fn node_val_count(&self, cache: &mut std::collections::HashMap) -> usize { match self { Self::DenseByteNode(node) => node.node_val_count(cache), Self::LineListNode(node) => node.node_val_count(cache), @@ -2353,7 +2354,7 @@ pub(crate) fn val_count_below_root(node: T node.node_val_count(&mut cache) } -pub(crate) fn val_count_below_node(node: &TrieNodeODRc, cache: &mut HashMap) -> usize { +pub(crate) fn val_count_below_node(node: &TrieNodeODRc, cache: &mut std::collections::HashMap) -> usize { if node.is_empty() { return 0 } @@ -2372,83 +2373,7 @@ pub(crate) fn val_count_below_node(node: & } } -/// GOAT recursive caching cata. If this dev branch is successful this should replace the caching cata flavors in the public API -/// This is the JUMPING cata -/// -/// Closures: -/// -/// `CollapseF`: Folds a possible value and a possible downstream continuation, prefixed by a linear sub-path into a single `W` -/// `fn(val: Option<&V>, downstream: Option, prefix: &[u8]) -> W` -/// -/// `BranchF`: Accumulates the `W` representing a downstream branch into an `Acc` accumulator type -/// `fn(branch_mask: &ByteMask, downstream: W, accumulator: &mut Acc)` -/// -/// `FinalizeF`: Converts an `Acc` accumulator into a `W` representing the logical node -/// `fn(branch_mask: &ByteMask, accumulator: Acc) -> W` - -// -//New plan for API. -// 3 closures. -// - closure to deal with straight paths and values. Fn(Option<&V>, Option, &[u8]) -> W -// - closure to deal with one downstream branch from a logical node. Fn(&ByteMask, W, &mut Acc) -// - closure to fold accumulator back into a W for the logical node. Fn(&ByteMask, Acc) -> W -// -// -//The non-jumping API would be the same, but collapse wouldn't take a prefix path, and instead `finalize_f(branch_f())` would be -// called in reverse order for each path byte - -pub fn recursive_cata(node: &TrieNodeODRc, collapse_f: CollapseF, branch_f: BranchF, finalize_f: FinalizeF) -> W -where - V: Clone + Send + Sync, - A: Allocator, - Acc: Default, - W: Clone, - CollapseF: Copy + Fn(Option<&V>, Option, &[u8]) -> W, - BranchF: Copy + Fn(&ByteMask, W, &mut Acc), - FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, -{ - let mut cache = HashMap::new(); - recursive_cata_cached::<_, _, _, _, _, _, _, COMPUTE_PATH>(node, collapse_f, branch_f, finalize_f, &mut cache) -} - -/// A stepping (non-jumping) catamorphism for the trie. -/// -/// Use this when you need the cata to evaluate once per path byte, even across non-branching sub-paths. -/// Unlike the jumping version, `branch_f` and `finalize_f` will be called for every path byte. -/// -/// See [`recursive_cata`] for closure semantics; this stepping variant omits the prefix argument from `collapse_f`. 
-pub fn recursive_cata_stepping( - node: &TrieNodeODRc, - collapse_f: CollapseF, - branch_f: BranchF, - finalize_f: FinalizeF, -) -> W -where - V: Clone + Send + Sync, - A: Allocator, - Acc: Default, - W: Clone, - CollapseF: Copy + Fn(Option<&V>, Option) -> W, - BranchF: Copy + Fn(&ByteMask, W, &mut Acc), - FinalizeF: Copy + Fn(&ByteMask, Acc) -> W, -{ - recursive_cata::<_, _, _, _, _, _, _, true>( - node, - |val, downstream, prefix| { - let mut w = collapse_f(val, downstream); - for byte in prefix.iter().rev() { - let mask = ByteMask::from(*byte); - let mut acc = Acc::default(); - branch_f(&mask, w, &mut acc); - w = finalize_f(&mask, acc); - } - w - }, - branch_f, - finalize_f, - ) -} - +/// Internal implementation of recursive_cata pub(crate) fn recursive_cata_cached( node: &TrieNodeODRc, collapse_f: CollapseF, @@ -3284,8 +3209,7 @@ impl DistributiveLat mod tests { use crate::alloc::{GlobalAlloc, global_alloc}; use crate::line_list_node::LineListNode; - use crate::trie_node::{TrieNodeODRc, recursive_cata, recursive_cata_stepping}; - use crate::utils::ByteMask; + use crate::trie_node::TrieNodeODRc; use crate::PathMap; use crate::zipper::*; @@ -3318,137 +3242,4 @@ mod tests { node_ref.make_unique(); drop(cloned); } - - /// Adapted from morphisms::cata_test1 for recursive_cata (jumping). - //GOAT, this should be in morphisms, and only use the public API - #[test] - fn recursive_cata_jumping_sum_digits() { - let tests = [ - (vec![], 0), - (vec!["1"], 1), - (vec!["1", "2"], 3), - (vec!["1", "2", "3", "4", "5", "6"], 21), - (vec!["a1", "a2"], 3), - (vec!["a1", "a2", "a3", "a4", "a5", "a6"], 21), - (vec!["12345"], 5), - (vec!["1", "12", "123", "1234", "12345"], 15), - (vec!["123", "123456", "123789"], 18), - (vec!["12", "123", "123456", "123789"], 20), - (vec!["1", "2", "123", "123765", "1234", "12345", "12349"], 29), - ]; - - #[derive(Default)] - struct SumAcc { - idx: usize, - sum: u32, - } - - for (keys, expected_sum) in tests { - let map: PathMap<()> = keys.into_iter().map(|v| (v, ())).collect(); - let sum = match map.root() { - Some(root) => { - let w = recursive_cata::<_, _, _, _, _, _, _, true>( - root, - |val, downstream, prefix| { - let mut sum = downstream.map(|w: (bool, u32)| w.1).unwrap_or(0); - if val.is_some() { - if let Some(byte) = prefix.last() { - sum += (*byte as char).to_digit(10).unwrap(); - } - } - (val.is_some() && prefix.is_empty(), sum) - }, - |mask: &ByteMask, w: (bool, u32), acc: &mut SumAcc| { - let byte = mask.indexed_bit::(acc.idx).unwrap(); - acc.idx += 1; - if w.0 { - acc.sum += (byte as char).to_digit(10).unwrap(); - } - acc.sum += w.1; - }, - |_mask: &ByteMask, acc: SumAcc| { (false, acc.sum) }, - ); - w.1 - }, - None => 0, - }; - assert_eq!(sum, expected_sum); - } - } - - /// Adapted from morphisms::cata_test1 for recursive_cata_stepping (non-jumping). 
- //GOAT, this should be in morphisms, and only use the public API - #[test] - fn recursive_cata_stepping_sum_digits() { - let tests = [ - (vec![], 0), - (vec!["1"], 1), - (vec!["1", "2"], 3), - (vec!["1", "2", "3", "4", "5", "6"], 21), - (vec!["a1", "a2"], 3), - (vec!["a1", "a2", "a3", "a4", "a5", "a6"], 21), - (vec!["12345"], 5), - (vec!["1", "12", "123", "1234", "12345"], 15), - (vec!["123", "123456", "123789"], 18), - (vec!["12", "123", "123456", "123789"], 20), - (vec!["1", "2", "123", "123765", "1234", "12345", "12349"], 29), - ]; - - #[derive(Default)] - struct SumAcc { - idx: usize, - sum: u32, - } - - for (keys, expected_sum) in tests { - let map: PathMap<()> = keys.into_iter().map(|v| (v, ())).collect(); - let sum = match map.root() { - Some(root) => { - let w = recursive_cata_stepping::<_, _, SumAcc, (bool, u32), _, _, _>( - root, - |val, downstream| { - let sum = downstream.map(|w| w.1).unwrap_or(0); - (val.is_some(), sum) - }, - |mask: &ByteMask, w: (bool, u32), acc: &mut SumAcc| { - let byte = mask.iter().nth(acc.idx).unwrap(); - acc.idx += 1; - if w.0 { - acc.sum += (byte as char).to_digit(10).unwrap(); - } - acc.sum += w.1; - }, - |_mask: &ByteMask, acc: SumAcc| { - (false, acc.sum) - }, - ); - w.1 - }, - None => 0, - }; - assert_eq!(sum, expected_sum); - } - } - - /// Finds the path_depth at which the recursive cata hits a stack overflow - /// - /// Empirically seems to be somewhere between 8 and 10 KBytes. But more branching, and thus fewer - /// bytes-per-node, will mean it will fail on shorter paths. - #[test] - fn recursive_cata_stack_overflow_smoke() { - const PATH_LEN: usize = 8_000; - - let mut map = PathMap::<()>::new(); - let path = vec![b'a'; PATH_LEN]; - map.set_val_at(&path, ()); - - let root = map.root().unwrap(); - let count = recursive_cata::<_, _, _, _, _, _, _, false>( - root, - |v, w, _| (v.is_some() as usize) + w.unwrap_or(0), - |_mask, w: usize, total| { *total += w }, - |_mask, total: usize| { total }, - ); - assert_eq!(count, 1); - } } From 2567f33bdce20f0084771de8a8e69e1c399d9d6c Mon Sep 17 00:00:00 2001 From: Luke Peterson Date: Fri, 23 Jan 2026 01:22:59 -0700 Subject: [PATCH 15/15] Adding catamorphism benchmark to put two cata implementations head-to-head --- Cargo.toml | 4 ++ benches/catamorphism.rs | 105 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 benches/catamorphism.rs diff --git a/Cargo.toml b/Cargo.toml index 0cf1cecc..6982994b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,6 +111,10 @@ harness = false name = "product_zipper" harness = false +[[bench]] +name = "catamorphism" +harness = false + [[bench]] name = "sla" harness = false diff --git a/benches/catamorphism.rs b/benches/catamorphism.rs new file mode 100644 index 00000000..11d2fb55 --- /dev/null +++ b/benches/catamorphism.rs @@ -0,0 +1,105 @@ +use divan::{Divan, Bencher, black_box}; +use pathmap::morphisms::{Catamorphism, Summarization}; +use pathmap::utils::ByteMask; +use pathmap::utils::ints::gen_int_range; +use pathmap::PathMap; + +fn main() { + // Run registered benchmarks. + let divan = Divan::from_args() + .sample_count(4000); + + divan.main(); +} + +fn build_map(count: u64) -> PathMap<()> { + // Dense range of u64 keys encoded as paths; sized to keep benches fast and stable. 
+ gen_int_range::<(), 8, u64>(0, count, 1, ()) +} + +const MAP_COUNT: u64 = 20_000_000; + +#[divan::bench()] +fn recursive_cata_jumping_val_count(bencher: Bencher) { + let map = build_map(MAP_COUNT); + let mut sink = 0usize; + bencher.bench_local(|| { + let rz = map.read_zipper(); + *black_box(&mut sink) = rz.recursive_cata::<_, _, _, _, _, false>( + |v, w, _| (v.is_some() as usize) + w.unwrap_or(0), + |_mask, w: usize, total| { *total += w }, + |_mask, total: usize| { total }, + ); + }); + assert_eq!(sink, MAP_COUNT as usize); +} + +#[divan::bench()] +fn cached_jumping_cata_val_count(bencher: Bencher) { + let map = build_map(MAP_COUNT); + let mut sink = 0usize; + bencher.bench_local(|| { + let rz = map.read_zipper(); + *black_box(&mut sink) = rz.into_cata_jumping_cached(|_mask: &ByteMask, children: &mut [usize], val, _sub_path| { + let mut sum: usize = children.iter().sum(); + if val.is_some() { + sum += 1; + } + sum + }); + }); + assert_eq!(sink, MAP_COUNT as usize); +} + +#[divan::bench()] +fn recursive_cata_jumping_total_len(bencher: Bencher) { + let map = build_map(MAP_COUNT); + let mut sink = (0usize, 0usize); + bencher.bench_local(|| { + let rz = map.read_zipper(); + *black_box(&mut sink) = rz.recursive_cata::<_, _, _, _, _, true>( + |val, downstream, prefix| { + let (mut count, mut total_len) = downstream.unwrap_or((0, 0)); + total_len += count * prefix.len(); + if val.is_some() { + count += 1; + total_len += prefix.len(); + } + (count, total_len) + }, + |mask: &ByteMask, w: (usize, usize), acc: &mut (usize, usize, usize)| { + let byte = mask.indexed_bit::(acc.0).unwrap(); + let _ = byte; // byte value unused; only length matters + acc.0 += 1; + acc.1 += w.0; + acc.2 += w.1; + }, + |_mask: &ByteMask, acc: (usize, usize, usize)| { (acc.1, acc.2) }, + ); + }); + assert_eq!(sink.0, MAP_COUNT as usize); +} + +#[divan::bench()] +fn cached_jumping_cata_total_len(bencher: Bencher) { + let map = build_map(MAP_COUNT); + let mut sink = (0usize, 0usize); + bencher.bench_local(|| { + let rz = map.read_zipper(); + *black_box(&mut sink) = rz.into_cata_jumping_cached(|mask: &ByteMask, children: &mut [(usize, usize)], val, sub_path| { + let mut count = 0usize; + let mut total_len = 0usize; + let prefix_len = sub_path.len(); + if val.is_some() { + count += 1; + total_len += prefix_len; + } + for (_byte, child) in mask.iter().zip(children.iter_mut()) { + count += child.0; + total_len += child.1 + child.0 * prefix_len; + } + (count, total_len) + }); + }); + assert_eq!(sink.0, MAP_COUNT as usize); +}
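
As a quick illustration of the three-closure shape the `recursive_cata_jumping_*` benchmarks exercise, here is a minimal sketch computing a different summary: the length in bytes of the longest path in the map. It follows the CollapseF / BranchF / FinalizeF roles described in the doc comment removed from trie_node.rs above, and it assumes the byte that selects a branch is accounted for in the branch closure (hence the `+ 1`) rather than in the child's prefix; check the trait documentation before relying on that detail.

use pathmap::morphisms::Catamorphism;
use pathmap::PathMap;

/// Sketch only: length in bytes of the longest value-bearing path in the map.
fn max_path_len(map: &PathMap<()>) -> usize {
    let rz = map.read_zipper();
    rz.recursive_cata::<_, _, _, _, _, true>(
        // CollapseF: a value at the end of `prefix` sits at depth `prefix.len()`;
        // anything reached through the downstream node is deeper still.
        |_val, downstream: Option<usize>, prefix: &[u8]| {
            downstream.unwrap_or(0) + prefix.len()
        },
        // BranchF: fold one child into the node's accumulator; the branch byte
        // itself is assumed to contribute one byte of depth (see note above).
        |_mask, child_depth: usize, deepest: &mut usize| {
            *deepest = (*deepest).max(child_depth + 1);
        },
        // FinalizeF: the accumulator already holds the logical node's result.
        |_mask, deepest: usize| deepest,
    )
}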
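
The same quantity written against the single-closure `into_cata_jumping_cached` flavor, mirroring how the benchmarks above pair each metric across the two implementations. This sketch assumes the `sub_path` slices seen along a root-to-value route concatenate to the full path, as the accounting in `cached_jumping_cata_total_len` implies, so no separate per-branch byte is added.

use pathmap::morphisms::Catamorphism;
use pathmap::utils::ByteMask;
use pathmap::PathMap;

/// Sketch only: longest value-bearing path length via the cached, single-closure flavor.
fn max_path_len_cached(map: &PathMap<()>) -> usize {
    let rz = map.read_zipper();
    rz.into_cata_jumping_cached(|_mask: &ByteMask, children: &mut [usize], _val, sub_path: &[u8]| {
        // Deepest continuation hanging below this node; a value stored here
        // adds no depth beyond `sub_path` itself.
        let deepest_below = children.iter().copied().max().unwrap_or(0);
        // Every path through this node is prefixed by `sub_path` (assumed to
        // already cover the branch byte, per the note above).
        deepest_below + sub_path.len()
    })
}

With the `[[bench]]` entry added to Cargo.toml above, the two implementations can be compared head-to-head with `cargo bench --bench catamorphism`.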