diff --git a/libnail/src/align/mod.rs b/libnail/src/align/mod.rs index d95325e..bfa2937 100644 --- a/libnail/src/align/mod.rs +++ b/libnail/src/align/mod.rs @@ -17,7 +17,8 @@ pub use optimal_accuracy::optimal_accuracy; mod scoring; pub use scoring::{ - cloud_score, e_value, null_one_score, null_two_score, p_value, Bits, Nats, Score, + cloud_score, e_value, null_one_score, null_two_score, p_value, Bits, Nats, NullTwoScratch, + Score, }; mod traceback; diff --git a/libnail/src/align/scoring.rs b/libnail/src/align/scoring.rs index e71cc75..e02c52e 100644 --- a/libnail/src/align/scoring.rs +++ b/libnail/src/align/scoring.rs @@ -217,18 +217,52 @@ pub fn null_one_score(target_length: usize) -> Nats { Nats(target_length as f32 * p_null_loop.ln() + p_null_exit.ln()) } +/// Reusable scratch buffers for `null_two_score`. Allocate once per thread via +/// `NullTwoScratch::default()` and pass `&mut scratch` on every call to avoid +/// per-call heap allocation. +#[derive(Default, Clone)] +pub struct NullTwoScratch { + pub expected_prob_ratios: Vec, + pub match_sums: Vec, + pub insert_sums: Vec, + pub core_posteriors: Vec, +} + +impl NullTwoScratch { + fn prepare(&mut self, profile_length: usize, target_length: usize) { + // Grow each buffer if the current call needs more capacity than we've + // seen before; otherwise reuse the existing allocation. Either way, + // zero only the prefix that this call will actually read/write — + // leaving stale values beyond the active range is intentional. + Self::zero_prefix(&mut self.expected_prob_ratios, Profile::MAX_DEGENERATE_ALPHABET_SIZE); + Self::zero_prefix(&mut self.match_sums, profile_length + 1); + Self::zero_prefix(&mut self.insert_sums, profile_length + 1); + Self::zero_prefix(&mut self.core_posteriors, target_length + 1); + } + + #[inline] + fn zero_prefix(v: &mut Vec, n: usize) { + if v.len() < n { + v.resize(n, 0.0); // allocates only when a new high-water mark is reached + } else { + v[..n].fill(0.0); // only touches the n elements this call will use + } + } +} + /// Compute the null two score adjustment: the composition bias. pub fn null_two_score( posterior_matrix: &impl DpMatrix, profile: &Profile, target: &Sequence, row_bounds: &RowBounds, + scratch: &mut NullTwoScratch, ) -> Nats { - // TODO: prevent these allocations? - let mut expected_prob_ratios: Vec = vec![0.0; Profile::MAX_DEGENERATE_ALPHABET_SIZE]; - let mut match_sums: Vec = vec![0.0; profile.length + 1]; - let mut insert_sums: Vec = vec![0.0; profile.length + 1]; - let mut core_posteriors: Vec = vec![0.0; target.length + 1]; + scratch.prepare(profile.length, target.length); + let expected_prob_ratios = &mut scratch.expected_prob_ratios; + let match_sums = &mut scratch.match_sums; + let insert_sums = &mut scratch.insert_sums; + let core_posteriors = &mut scratch.core_posteriors; let mut core_state_sum: f32 = 0.0; // what: for each position in the model, take the sum of diff --git a/nail/src/pipeline/align_stage.rs b/nail/src/pipeline/align_stage.rs index 01c7478..b2732c9 100644 --- a/nail/src/pipeline/align_stage.rs +++ b/nail/src/pipeline/align_stage.rs @@ -6,7 +6,7 @@ use libnail::{ align::{ backward, forward, null_one_score, null_two_score, optimal_accuracy, p_value, posterior, structs::{Alignment, AlignmentBuilder, DpMatrixSparse, RowBounds, Trace}, - traceback, Bits, + traceback, Bits, NullTwoScratch, }, structs::{Profile, Sequence}, }; @@ -72,6 +72,7 @@ pub struct DefaultAlignStage { backward_matrix: DpMatrixSparse, posterior_matrix: DpMatrixSparse, optimal_matrix: DpMatrixSparse, + null_two_scratch: NullTwoScratch, forward_p_value_threshold: f64, target_count: usize, config: AlignConfig, @@ -184,6 +185,7 @@ impl AlignStage for DefaultAlignStage { profile, target, bounds, + &mut self.null_two_scratch, )); stats.null_two_time(now.elapsed()); score diff --git a/nail/src/stats.rs b/nail/src/stats.rs index 5aba9cc..29ec6e1 100644 --- a/nail/src/stats.rs +++ b/nail/src/stats.rs @@ -97,6 +97,7 @@ pub enum ThreadedTimed { Forward, Backward, Posterior, + OptimalAccuracy, Traceback, NullTwo, OutputWrite, @@ -115,6 +116,7 @@ impl Debug for ThreadedTimed { ThreadedTimed::Forward => "forward", ThreadedTimed::Backward => "backward", ThreadedTimed::Posterior => "posterior", + ThreadedTimed::OptimalAccuracy => "optimal accuracy", ThreadedTimed::Traceback => "traceback", ThreadedTimed::NullTwo => "null two", }; @@ -264,6 +266,7 @@ impl Stats { self.add_threaded_time(ThreadedTimed::Backward, stats.backward_time); self.add_threaded_time(ThreadedTimed::Posterior, stats.posterior_time); + self.add_threaded_time(ThreadedTimed::OptimalAccuracy, stats.optimal_accuracy_time); self.add_threaded_time(ThreadedTimed::Traceback, stats.traceback_time); self.add_threaded_time(ThreadedTimed::NullTwo, stats.null_two_time); }