From e3e1651dfa36faf1d5969a43c48b579e29698b0c Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 10 Feb 2026 10:10:30 +0530 Subject: [PATCH 01/39] Sync changes from CDB_DiskANN repo - Refactored recall utilities in diskann-benchmark - Updated tokio utilities - Added attribute and format parser improvements in label-filter - Updated ground_truth utilities in diskann-tools --- diskann-benchmark/src/utils/recall.rs | 703 +----------------- diskann-benchmark/src/utils/tokio.rs | 20 +- diskann-label-filter/src/attribute.rs | 1 + diskann-label-filter/src/parser/format.rs | 2 + .../src/utils/flatten_utils.rs | 2 +- diskann-tools/Cargo.toml | 18 +- diskann-tools/src/utils/ground_truth.rs | 161 +++- 7 files changed, 196 insertions(+), 711 deletions(-) diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index 5b7fd1594..bfaf46772 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -3,15 +3,13 @@ * Licensed under the MIT license. */ -use std::{collections::HashSet, hash::Hash}; - -use diskann_utils::strided::StridedView; -use diskann_utils::views::MatrixView; +use diskann_benchmark_core as benchmark_core; +pub(crate) use benchmark_core::recall::knn; use serde::Serialize; -use thiserror::Error; -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] +#[non_exhaustive] pub(crate) struct RecallMetrics { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, @@ -25,278 +23,19 @@ pub(crate) struct RecallMetrics { pub(crate) minimum: usize, /// The maximum observed recall (max possible value: `recall_k`). pub(crate) maximum: usize, - /// Recall results by query - pub(crate) by_query: Option>, -} - -// impl RecallMetrics { -// pub(crate) fn num_queries(&self) -> usize { -// self.num_queries -// } - -// pub(crate) fn average(&self) -> f64 { -// self.average -// } -// } - -#[derive(Debug, Error)] -pub(crate) enum ComputeRecallError { - #[error("results matrix has {0} rows but ground truth has {1}")] - RowsMismatch(usize, usize), - #[error("distances matrix has {0} rows but ground truth has {1}")] - DistanceRowsMismatch(usize, usize), - #[error("recall k value {0} must be less than or equal to recall n {1}")] - RecallKAndNError(usize, usize), - #[error("number of results per query {0} must be at least the specified recall k {1}")] - NotEnoughResults(usize, usize), - #[error( - "number of groundtruth values per query {0} must be at least the specified recall n {1}" - )] - NotEnoughGroundTruth(usize, usize), - #[error("number of groundtruth distances {0} does not match groundtruth entries {1}")] - NotEnoughGroundTruthDistances(usize, usize), -} - -pub(crate) trait ComputeKnnRecall { - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result; -} - -impl ComputeKnnRecall for MatrixView<'_, T> -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result { - compute_knn_recall( - self, - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } -} - -impl ComputeKnnRecall for Vec> -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - fn compute_knn_recall( - &self, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result { - compute_knn_recall( - self, - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } -} - -pub(crate) trait KnnRecall { - type Item; - - fn nrows(&self) -> usize; - fn ncols(&self) -> Option; - fn row(&self, i: usize) -> &[Self::Item]; -} - -impl KnnRecall for MatrixView<'_, T> { - type Item = T; - - fn nrows(&self) -> usize { - MatrixView::<'_, T>::nrows(self) - } - fn ncols(&self) -> Option { - Some(MatrixView::<'_, T>::ncols(self)) - } - fn row(&self, i: usize) -> &[Self::Item] { - MatrixView::<'_, T>::row(self, i) - } -} - -impl KnnRecall for Vec> { - type Item = T; - - fn nrows(&self) -> usize { - self.len() - } - fn ncols(&self) -> Option { - None - } - fn row(&self, i: usize) -> &[Self::Item] { - &self[i] - } } -impl KnnRecall for StridedView<'_, T> { - type Item = T; - - fn nrows(&self) -> usize { - StridedView::<'_, T>::nrows(self) - } - fn ncols(&self) -> Option { - Some(StridedView::<'_, T>::ncols(self)) - } - fn row(&self, i: usize) -> &[Self::Item] { - StridedView::<'_, T>::row(self, i) - } -} - -fn compute_knn_recall( - groundtruth: &K, - groundtruth_distances: Option>, - results: StridedView<'_, T>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, -) -> Result -where - T: Eq + Hash + Copy + std::fmt::Debug, - K: KnnRecall, -{ - if recall_k > recall_n { - return Err(ComputeRecallError::RecallKAndNError(recall_k, recall_n)); - } - - let nrows = results.nrows(); - if nrows != groundtruth.nrows() { - return Err(ComputeRecallError::RowsMismatch(nrows, groundtruth.nrows())); - } - - if results.ncols() < recall_n && !allow_insufficient_results { - return Err(ComputeRecallError::NotEnoughResults( - results.ncols(), - recall_n, - )); - } - - // Validate groundtruth size for fixed-size sources - match groundtruth.ncols() { - Some(ncols) if ncols < recall_k => { - return Err(ComputeRecallError::NotEnoughGroundTruth(ncols, recall_k)); - } - _ => {} - } - - if let Some(distances) = groundtruth_distances { - if nrows != distances.nrows() { - return Err(ComputeRecallError::DistanceRowsMismatch( - distances.nrows(), - nrows, - )); - } - - match groundtruth.ncols() { - Some(ncols) if distances.ncols() != ncols => { - return Err(ComputeRecallError::NotEnoughGroundTruthDistances( - distances.ncols(), - ncols, - )); - } - _ => {} +impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { + fn from(m: &benchmark_core::recall::RecallMetrics) -> Self { + Self { + recall_k: m.recall_k, + recall_n: m.recall_n, + num_queries: m.num_queries, + average: m.average, + minimum: m.minimum, + maximum: m.maximum, } } - - // The actual recall computation for fixed-size groundtruth - let mut recall_values: Vec = Vec::new(); - let mut this_groundtruth = HashSet::new(); - let mut this_results = HashSet::new(); - - for (i, result) in results.row_iter().enumerate() { - let gt_row = groundtruth.row(i); - - // Populate the groundtruth using the top-k - this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().copied().take(recall_k)); - - // If we have distances, then continue to append distances as long as the distance - // value is constant - if let Some(distances) = groundtruth_distances { - if recall_k > 0 { - let distances_row = distances.row(i); - if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 { - let last_distance = distances_row[recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) { - if *d == last_distance { - this_groundtruth.insert(*g); - } else { - break; - } - } - } - } - } - - this_results.clear(); - this_results.extend(result.iter().copied().take(recall_n)); - - // Count the overlap - let r = this_groundtruth - .iter() - .filter(|i| this_results.contains(i)) - .count() - .min(recall_k); - - recall_values.push(r); - } - - // Perform post-processing - let total: usize = recall_values.iter().sum(); - let minimum = recall_values.iter().min().unwrap_or(&0); - let maximum = recall_values.iter().max().unwrap_or(&0); - - let div = if groundtruth.ncols().is_some() { - recall_k * nrows - } else { - (0..groundtruth.nrows()) - .map(|i| groundtruth.row(i).len()) - .sum::() - .max(1) - }; - - let average = (total as f64) / (div as f64); - - Ok(RecallMetrics { - recall_k, - recall_n, - num_queries: nrows, - average, - minimum: *minimum, - maximum: *maximum, - by_query: if enhanced_metrics { - Some(recall_values) - } else { - None - }, - }) } /// Compute `k-recall-at-n` for all valid combinations of values in `recall_k` and @@ -309,14 +48,13 @@ where feature = "product-quantization" ))] pub(crate) fn compute_multiple_recalls( - results: StridedView<'_, T>, - groundtruth: StridedView<'_, T>, + results: &dyn benchmark_core::recall::Rows, + groundtruth: &dyn benchmark_core::recall::Rows, recall_k: &[usize], recall_n: &[usize], - enhanced_metrics: bool, -) -> Result, ComputeRecallError> +) -> Result, benchmark_core::recall::ComputeRecallError> where - T: Eq + Hash + Copy + std::fmt::Debug, + T: benchmark_core::recall::RecallCompatible, { let mut result = Vec::new(); for k in recall_k { @@ -325,414 +63,27 @@ where continue; } - result.push(compute_knn_recall( - &groundtruth, - None, - results, - *k, - *n, - false, - enhanced_metrics, - )?); + let recall = benchmark_core::recall::knn(groundtruth, None, results, *k, *n, false)?; + result.push((&recall).into()); } } Ok(result) } -#[derive(Debug, Serialize)] -pub(crate) struct APMetrics { +#[derive(Debug, Clone, Serialize)] +#[non_exhaustive] +pub(crate) struct AveragePrecisionMetrics { /// The number of queries. pub(crate) num_queries: usize, /// The average precision pub(crate) average_precision: f64, } -#[derive(Debug, Error)] -pub(crate) enum ComputeAPError { - #[error("results has {0} elements but ground truth has {1}")] - EntriesMismatch(usize, usize), -} - -/// Compute average precision of a range search result -pub(crate) fn compute_average_precision( - results: Vec>, - groundtruth: &[Vec], -) -> Result -where - T: Eq + Hash + Copy + std::fmt::Debug, -{ - if results.len() != groundtruth.len() { - return Err(ComputeAPError::EntriesMismatch( - results.len(), - groundtruth.len(), - )); - } - - // The actual recall computation. - let mut num_gt_results = 0; - let mut num_reported_results = 0; - - let mut scratch = HashSet::new(); - - std::iter::zip(results.iter(), groundtruth.iter()).for_each(|(result, gt)| { - scratch.clear(); - scratch.extend(result.iter().copied()); - num_reported_results += gt.iter().filter(|i| scratch.contains(i)).count(); - num_gt_results += gt.len(); - }); - - // Perform post-processing. - let average_precision = (num_reported_results as f64) / (num_gt_results as f64); - - Ok(APMetrics { - average_precision, - num_queries: results.len(), - }) -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use diskann_utils::views::Matrix; - - use super::*; - - pub(crate) fn compute_knn_recall( - results: StridedView<'_, u32>, - groundtruth: G, // StridedView - groundtruth_distances: Option>, - recall_k: usize, - recall_n: usize, - allow_insufficient_results: bool, - enhanced_metrics: bool, - ) -> Result - where - G: ComputeKnnRecall + KnnRecall + Clone, - { - groundtruth.compute_knn_recall( - groundtruth_distances, - results, - recall_k, - recall_n, - allow_insufficient_results, - enhanced_metrics, - ) - } - - struct ExpectedRecall { - recall_k: usize, - recall_n: usize, - // Recall for each component. - components: Vec, - } - - impl ExpectedRecall { - fn new(recall_k: usize, recall_n: usize, components: Vec) -> Self { - assert!(recall_k <= recall_n); - components.iter().for_each(|x| { - assert!(*x <= recall_k); - }); - Self { - recall_k, - recall_n, - components, - } - } - - fn compute_recall(&self) -> f64 { - (self.components.iter().sum::() as f64) - / ((self.components.len() * self.recall_k) as f64) - } - } - - #[test] - fn test_happy_path() { - let groundtruth = Matrix::try_from( - vec![ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 0 - 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // row 1 - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 2 - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, // row 3 - ] - .into(), - 4, - 10, - ) - .unwrap(); - - let distances = Matrix::try_from( - vec![ - 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 0 - 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // row 1 - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, // row 2 - 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 6.0, // row 3 - ] - .into(), - 4, - 10, - ) - .unwrap(); - - // Shift row 0 by one and row 1 by two. - let our_results = Matrix::try_from( - vec![ - 100, 0, 1, 2, 5, 6, // row 0 - 100, 101, 7, 8, 9, 10, // row 1 - 0, 1, 2, 3, 4, 5, // row 2 - 0, 1, 2, 3, 4, 5, // row 3 - ] - .into(), - 4, - 6, - ) - .unwrap(); - - //---------// - // No Ties // - //---------// - let expected_no_ties = vec![ - // Equal `k` and `n` - ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]), - ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]), - ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]), - ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]), - ExpectedRecall::new(5, 5, vec![3, 3, 5, 5]), - ExpectedRecall::new(6, 6, vec![4, 4, 6, 6]), - // Unequal `k` and `n`. - ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]), - ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]), - ExpectedRecall::new(2, 3, vec![2, 0, 2, 2]), - ExpectedRecall::new(3, 5, vec![3, 1, 3, 3]), - ]; - let epsilon = 1e-6; // Define a small tolerance - - for (i, expected) in expected_no_ties.iter().enumerate() { - assert_eq!(expected.components.len(), our_results.nrows()); - let recall = compute_knn_recall( - our_results.as_view().into(), - groundtruth.as_view(), - None, - expected.recall_k, - expected.recall_n, - false, - true, - ) - .unwrap(); - - let left = recall.average; - let right = expected.compute_recall(); - assert!( - (left - right).abs() < epsilon, - "left = {}, right = {} on input {}", - left, - right, - i - ); - - assert_eq!(recall.num_queries, our_results.nrows()); - assert_eq!(recall.recall_k, expected.recall_k); - assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); - } - - //-----------// - // With Ties // - //-----------// - let expected_with_ties = vec![ - // Equal `k` and `n` - ExpectedRecall::new(1, 1, vec![0, 0, 1, 1]), - ExpectedRecall::new(2, 2, vec![1, 0, 2, 2]), - ExpectedRecall::new(3, 3, vec![2, 1, 3, 3]), - ExpectedRecall::new(4, 4, vec![3, 2, 4, 4]), - ExpectedRecall::new(5, 5, vec![4, 3, 5, 5]), // tie-breaker kicks in - ExpectedRecall::new(6, 6, vec![5, 4, 6, 6]), // tie-breaker kicks in - // Unequal `k` and `n`. - ExpectedRecall::new(1, 2, vec![1, 0, 1, 1]), - ExpectedRecall::new(1, 3, vec![1, 0, 1, 1]), - ExpectedRecall::new(2, 3, vec![2, 1, 2, 2]), - ExpectedRecall::new(4, 5, vec![4, 3, 4, 4]), - ]; - - for (i, expected) in expected_with_ties.iter().enumerate() { - assert_eq!(expected.components.len(), our_results.nrows()); - let recall = compute_knn_recall( - our_results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - expected.recall_k, - expected.recall_n, - false, - true, - ) - .unwrap(); - - let left = recall.average; - let right = expected.compute_recall(); - assert!( - (left - right).abs() < epsilon, - "left = {}, right = {} on input {}", - left, - right, - i - ); - - assert_eq!(recall.num_queries, our_results.nrows()); - assert_eq!(recall.recall_k, expected.recall_k); - assert_eq!(recall.recall_n, expected.recall_n); - assert_eq!(recall.minimum, *expected.components.iter().min().unwrap()); - assert_eq!(recall.maximum, *expected.components.iter().max().unwrap()); - assert_eq!(recall.by_query, Some(expected.components.clone())); - } - } - - #[test] - fn test_errors() { - // k greater than n - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 11, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::RecallKAndNError(..))); - } - - // Unequal rows - { - let groundtruth = Matrix::::new(0, 11, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::RowsMismatch(..))); - let err_allow_insufficient_results = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - true, - false, - ) - .unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::RowsMismatch(..) - )); - } - - // Not enough results - { - let groundtruth = Matrix::::new(0, 10, 10); - let results = Matrix::::new(0, 10, 5); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 5, - 10, - false, - false, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughResults(..))); - let _ = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 5, - 10, - true, - false, - ); - } - - // Not enough groundtruth - { - let groundtruth = Matrix::::new(0, 10, 5); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..))); - let err_allow_insufficient_results = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - None, - 10, - 10, - true, - false, - ) - .unwrap_err(); - assert!(matches!( - err_allow_insufficient_results, - ComputeRecallError::NotEnoughGroundTruth(..) - )); - } - - // Distance Row Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 9, 10); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!(err, ComputeRecallError::DistanceRowsMismatch(..))); - } - - // Distance Cols Mismatch - { - let groundtruth = Matrix::::new(0, 10, 10); - let distances = Matrix::::new(0.0, 10, 9); - let results = Matrix::::new(0, 10, 10); - let err = compute_knn_recall( - results.as_view().into(), - groundtruth.as_view(), - Some(distances.as_view().into()), - 10, - 10, - false, - true, - ) - .unwrap_err(); - assert!(matches!( - err, - ComputeRecallError::NotEnoughGroundTruthDistances(..) - )); +impl From<&benchmark_core::recall::AveragePrecisionMetrics> for AveragePrecisionMetrics { + fn from(m: &benchmark_core::recall::AveragePrecisionMetrics) -> Self { + Self { + num_queries: m.num_queries, + average_precision: m.average_precision, } } } diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index a21d3f520..21c78abb2 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -/// Create a multi-threaded runtime with `num_threads`. +/// Create a generic multi-threaded runtime with `num_threads`. pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { Ok(tokio::runtime::Builder::new_multi_thread() .worker_threads(num_threads) @@ -18,21 +18,3 @@ pub(crate) fn block_on(future: F) -> F::Output { .expect("current thread runtime initialization failed") .block_on(future) } - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_runtimes() { - for num_threads in [1, 2, 4, 8] { - let rt = runtime(num_threads).unwrap(); - let metrics = rt.metrics(); - assert_eq!(metrics.num_workers(), num_threads); - } - } -} diff --git a/diskann-label-filter/src/attribute.rs b/diskann-label-filter/src/attribute.rs index 9eb7ff500..f0d99bfd9 100644 --- a/diskann-label-filter/src/attribute.rs +++ b/diskann-label-filter/src/attribute.rs @@ -5,6 +5,7 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; +use std::io::Write; use serde_json::Value; use thiserror::Error; diff --git a/diskann-label-filter/src/parser/format.rs b/diskann-label-filter/src/parser/format.rs index c042d8338..5e9e3a9c1 100644 --- a/diskann-label-filter/src/parser/format.rs +++ b/diskann-label-filter/src/parser/format.rs @@ -15,8 +15,10 @@ pub struct Document { /// label in raw json format #[serde(flatten)] pub label: serde_json::Value, + } + /// Represents a query expression as defined in the RFC. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryExpression { diff --git a/diskann-label-filter/src/utils/flatten_utils.rs b/diskann-label-filter/src/utils/flatten_utils.rs index 16404af4b..83c9f80f9 100644 --- a/diskann-label-filter/src/utils/flatten_utils.rs +++ b/diskann-label-filter/src/utils/flatten_utils.rs @@ -154,7 +154,7 @@ fn flatten_json_pointer_inner( } Value::Array(arr) => { for (i, item) in arr.iter().enumerate() { - flatten_recursive(item, stack.push(&i, separator), out, separator); + flatten_recursive(item, stack.push(&String::from(""), separator), out, separator); } } _ => { diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index 7f0cb203a..1b4b3408e 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -5,14 +5,13 @@ version.workspace = true authors.workspace = true description.workspace = true documentation.workspace = true -license.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] byteorder.workspace = true clap = { workspace = true, features = ["derive"] } -diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` +diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` diskann-vector = { workspace = true } diskann-disk = { workspace = true } diskann-utils = { workspace = true } @@ -24,31 +23,24 @@ ordered-float = "4.2.0" rand_distr.workspace = true rand.workspace = true serde = { workspace = true, features = ["derive"] } -toml = "0.8.13" +serde_json.workspace = true bincode.workspace = true opentelemetry.workspace = true -opentelemetry_sdk.workspace = true -csv.workspace = true -tokio = { workspace = true, features = ["full"] } -arc-swap.workspace = true diskann-quantization = { workspace = true } diskann = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } tracing.workspace = true bit-set.workspace = true anyhow.workspace = true -serde_json.workspace = true itertools.workspace = true diskann-label-filter.workspace = true [dev-dependencies] rstest.workspace = true -assert_ok = "1.0.2" -# Use virtual-storage for integration tests -diskann-disk = { path = "../diskann-disk", features = ["virtual_storage"] } vfs = { workspace = true } -ureq = { version = "3.0.11", default-features = false, features = ["native-tls", "gzip"] } -diskann-providers = { path = "../diskann-providers", default-features = false, features = ["testing", "virtual_storage"] } +diskann-providers = { workspace = true, default-features = false, features = [ + "virtual_storage", +] } diskann-utils = { workspace = true, features = ["testing"] } [features] diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index e96f7ae8f..31e69b2b2 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -4,7 +4,7 @@ */ use bit_set::BitSet; -use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels}; +use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels, ASTExpr}; use std::{io::Write, mem::size_of, str::FromStr}; @@ -25,18 +25,97 @@ use diskann_utils::views::Matrix; use diskann_vector::{distance::Metric, DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; +use serde_json::{Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; +/// Expands a JSON object with array-valued fields into multiple objects with scalar values. +/// For example: {"country": ["AU", "NZ"], "year": 2007} +/// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] +/// +/// If multiple fields have arrays, all combinations are generated. +fn expand_array_fields(value: &Value) -> Vec { + match value { + Value::Object(map) => { + // Start with a single empty object + let mut results: Vec> = vec![Map::new()]; + + for (key, val) in map.iter() { + if let Value::Array(arr) = val { + // Expand: for each existing result, create copies for each array element + let mut new_results: Vec> = Vec::new(); + for existing in results.iter() { + for item in arr.iter() { + let mut new_map: Map = existing.clone(); + new_map.insert(key.clone(), item.clone()); + new_results.push(new_map); + } + } + // If array is empty, keep existing results without this key + if !arr.is_empty() { + results = new_results; + } + } else { + // Non-array field: add to all existing results + for existing in results.iter_mut() { + existing.insert(key.clone(), val.clone()); + } + } + } + + results.into_iter().map(Value::Object).collect() + } + // If not an object, return as-is + _ => vec![value.clone()], + } +} + +/// Evaluates a query expression against a label, expanding array fields first. +/// Returns true if any expanded variant matches the query. +fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { + let expanded = expand_array_fields(label); + expanded.iter().any(|item| eval_query_expr(query_expr, item)) +} + pub fn read_labels_and_compute_bitmap( base_label_filename: &str, query_label_filename: &str, ) -> CMDResult> { // Read base labels let base_labels = read_baselabels(base_label_filename)?; + tracing::info!( + "Loaded {} base labels from {}", + base_labels.len(), + base_label_filename + ); + + // Print first few base labels for debugging + for (i, label) in base_labels.iter().take(3).enumerate() { + tracing::debug!( + "Base label sample [{}]: doc_id={}, label={}", + i, + label.doc_id, + label.label + ); + } // Parse queries and evaluate against labels let parsed_queries = read_and_parse_queries(query_label_filename)?; + tracing::info!( + "Loaded {} queries from {}", + parsed_queries.len(), + query_label_filename + ); + + // Print first few queries for debugging + for (i, (query_id, query_expr)) in parsed_queries.iter().take(3).enumerate() { + tracing::debug!( + "Query sample [{}]: query_id={}, expr={:?}", + i, + query_id, + query_expr + ); + } // using the global threadpool is fine here #[allow(clippy::disallowed_methods)] @@ -45,7 +124,15 @@ pub fn read_labels_and_compute_bitmap( .map(|(_query_id, query_expr)| { let mut bitmap = BitSet::new(); for base_label in base_labels.iter() { - if eval_query_expr(query_expr, &base_label.label) { + // Handle case where base_label.label is an array - check if any element matches + // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]}) + let matches = if let Some(array) = base_label.label.as_array() { + array.iter().any(|item| eval_query_with_array_expansion(query_expr, item)) + } else { + eval_query_with_array_expansion(query_expr, &base_label.label) + }; + + if matches { bitmap.insert(base_label.doc_id); } } @@ -53,6 +140,38 @@ pub fn read_labels_and_compute_bitmap( }) .collect(); + // Debug: Print match statistics for each query + let total_matches: usize = query_bitmaps.iter().map(|b| b.len()).sum(); + let queries_with_matches = query_bitmaps.iter().filter(|b| !b.is_empty()).count(); + tracing::info!( + "Filter matching summary: {} total matches across {} queries ({} queries have matches)", + total_matches, + query_bitmaps.len(), + queries_with_matches + ); + + // Print per-query match counts + for (i, bitmap) in query_bitmaps.iter().enumerate() { + if i < 10 || bitmap.is_empty() { + tracing::debug!( + "Query {}: {} base vectors matched the filter", + i, + bitmap.len() + ); + } + } + + // If no matches, print more diagnostic info + if total_matches == 0 { + tracing::warn!("WARNING: No base vectors matched any query filters!"); + tracing::warn!("This could indicate a format mismatch between base labels and query filters."); + + // Try to identify what keys exist in base labels vs queries + if let Some(first_label) = base_labels.first() { + tracing::warn!("First base label (full): doc_id={}, label={}", first_label.doc_id, first_label.label); + } + } + Ok(query_bitmaps) } @@ -195,6 +314,44 @@ pub fn compute_ground_truth_from_datafiles< assert_ne!(ground_truth.len(), 0, "No ground-truth results computed"); + // Debug: Print top K matches for each query + tracing::info!( + "Ground truth computed for {} queries with recall_at={}", + ground_truth.len(), + recall_at + ); + for (query_idx, npq) in ground_truth.iter().enumerate() { + let neighbors: Vec<_> = npq.iter().collect(); + let neighbor_count = neighbors.len(); + + if query_idx < 10 { + // Print top K IDs and distances for first 10 queries + let top_ids: Vec = neighbors.iter().take(10).map(|n| n.id).collect(); + let top_dists: Vec = neighbors.iter().take(10).map(|n| n.distance).collect(); + tracing::debug!( + "Query {}: {} neighbors found. Top IDs: {:?}, Top distances: {:?}", + query_idx, + neighbor_count, + top_ids, + top_dists + ); + } + + if neighbor_count == 0 { + tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); + } + } + + // Summary stats + let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); + let queries_with_neighbors = ground_truth.iter().filter(|npq| npq.iter().count() > 0).count(); + tracing::info!( + "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", + total_neighbors, + queries_with_neighbors, + ground_truth.len() - queries_with_neighbors + ); + if has_vector_filters || has_query_bitmaps { let ground_truth_collection = ground_truth .into_iter() From ec2091ffb510a245970e0fdec83bb46955cbefe7 Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 10 Feb 2026 11:08:49 +0530 Subject: [PATCH 02/39] Before merging with main --- Cargo.lock | 340 ------------------ .../src/utils/flatten_utils.rs | 2 +- 2 files changed, 1 insertion(+), 341 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e80330d7d..665b6c6df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,6 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "adler2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" - [[package]] name = "aho-corasick" version = "1.1.4" @@ -103,30 +97,12 @@ dependencies = [ "rustversion", ] -[[package]] -name = "assert_ok" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c770ef7624541db11cce57929f00e737fef89157d7c1cd1977b20ee74fefd84" - [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "base64ct" -version = "1.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" - [[package]] name = "bf-tree" version = "0.4.5" @@ -225,16 +201,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "cc" -version = "1.2.52" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" -dependencies = [ - "find-msvc-tools", - "shlex", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -327,31 +293,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "core-foundation" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "core-foundation-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" - -[[package]] -name = "crc32fast" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" -dependencies = [ - "cfg-if", -] - [[package]] name = "criterion" version = "0.5.1" @@ -419,43 +360,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" -[[package]] -name = "csv" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" -dependencies = [ - "csv-core", - "itoa", - "ryu", - "serde_core", -] - -[[package]] -name = "csv-core" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" -dependencies = [ - "memchr", -] - [[package]] name = "defer" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "930c7171c8df9fb1782bdf9b918ed9ed2d33d1d22300abb754f9085bc48bf8e8" -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "pem-rfc7468", - "zeroize", -] - [[package]] name = "derive_more" version = "2.1.1" @@ -718,14 +628,11 @@ name = "diskann-tools" version = "0.41.0" dependencies = [ "anyhow", - "arc-swap", - "assert_ok", "bincode", "bit-set", "bytemuck", "byteorder", "clap", - "csv", "diskann", "diskann-disk", "diskann-label-filter", @@ -737,7 +644,6 @@ dependencies = [ "itertools 0.13.0", "num_cpus", "opentelemetry", - "opentelemetry_sdk", "ordered-float", "rand 0.9.2", "rand_distr", @@ -745,11 +651,8 @@ dependencies = [ "rstest", "serde", "serde_json", - "tokio", - "toml 0.8.23", "tracing", "tracing-subscriber", - "ureq", "vfs", ] @@ -956,12 +859,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "find-msvc-tools" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" - [[package]] name = "flatbuffers" version = "25.12.19" @@ -972,16 +869,6 @@ dependencies = [ "rustc_version", ] -[[package]] -name = "flate2" -version = "1.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369" -dependencies = [ - "crc32fast", - "miniz_oxide", -] - [[package]] name = "fnv" version = "1.0.7" @@ -1000,21 +887,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "futures" version = "0.3.31" @@ -1322,22 +1194,6 @@ dependencies = [ "paste", ] -[[package]] -name = "http" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" -dependencies = [ - "bytes", - "itoa", -] - -[[package]] -name = "httparse" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" - [[package]] name = "iai-callgrind" version = "0.14.2" @@ -1559,16 +1415,6 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" -[[package]] -name = "miniz_oxide" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" -dependencies = [ - "adler2", - "simd-adler32", -] - [[package]] name = "mio" version = "1.1.1" @@ -1650,23 +1496,6 @@ dependencies = [ "nano-gemm-core", ] -[[package]] -name = "native-tls" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "never-say-never" version = "6.6.666" @@ -1730,50 +1559,6 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" -[[package]] -name = "openssl" -version = "0.10.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" -dependencies = [ - "bitflags 2.10.0", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.113", -] - -[[package]] -name = "openssl-probe" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" - -[[package]] -name = "openssl-sys" -version = "0.9.111" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "opentelemetry" version = "0.30.0" @@ -1842,15 +1627,6 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - [[package]] name = "percent-encoding" version = "2.3.2" @@ -1889,12 +1665,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "pkg-config" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" - [[package]] name = "plotters" version = "0.3.7" @@ -2326,15 +2096,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "rustls-pki-types" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" -dependencies = [ - "zeroize", -] - [[package]] name = "rustversion" version = "1.0.22" @@ -2384,15 +2145,6 @@ dependencies = [ "sdd", ] -[[package]] -name = "schannel" -version = "0.1.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -2405,29 +2157,6 @@ version = "4.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63d45f3526312c9c90d717aac28d37010e623fbd7ca6f21503e69784e86f40" -[[package]] -name = "security-framework" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" -dependencies = [ - "bitflags 2.10.0", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "1.0.27" @@ -2523,12 +2252,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -2539,12 +2262,6 @@ dependencies = [ "libc", ] -[[package]] -name = "simd-adler32" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" - [[package]] name = "slab" version = "0.4.11" @@ -2922,42 +2639,6 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" -[[package]] -name = "ureq" -version = "3.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" -dependencies = [ - "base64", - "der", - "flate2", - "log", - "native-tls", - "percent-encoding", - "rustls-pki-types", - "ureq-proto", - "utf-8", - "webpki-root-certs", -] - -[[package]] -name = "ureq-proto" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" -dependencies = [ - "base64", - "http", - "httparse", - "log", -] - -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - [[package]] name = "utf8parse" version = "0.2.2" @@ -2970,12 +2651,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - [[package]] name = "version_check" version = "0.9.5" @@ -3090,15 +2765,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki-root-certs" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "winapi-util" version = "0.1.11" @@ -3305,12 +2971,6 @@ dependencies = [ "syn 2.0.113", ] -[[package]] -name = "zeroize" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" - [[package]] name = "zmij" version = "1.0.11" diff --git a/diskann-label-filter/src/utils/flatten_utils.rs b/diskann-label-filter/src/utils/flatten_utils.rs index 83c9f80f9..16404af4b 100644 --- a/diskann-label-filter/src/utils/flatten_utils.rs +++ b/diskann-label-filter/src/utils/flatten_utils.rs @@ -154,7 +154,7 @@ fn flatten_json_pointer_inner( } Value::Array(arr) => { for (i, item) in arr.iter().enumerate() { - flatten_recursive(item, stack.push(&String::from(""), separator), out, separator); + flatten_recursive(item, stack.push(&i, separator), out, separator); } } _ => { From a949024b8283390d49b43061973abbb74653d17d Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Mon, 16 Feb 2026 14:40:55 +0530 Subject: [PATCH 03/39] Working version of inline beta search --- .../example/document-filter.json | 34 + .../src/backend/document_index/benchmark.rs | 1038 ++++++++++++++ .../src/backend/document_index/mod.rs | 13 + diskann-benchmark/src/backend/index/result.rs | 13 + diskann-benchmark/src/backend/mod.rs | 2 + .../src/inputs/document_index.rs | 177 +++ diskann-benchmark/src/inputs/mod.rs | 2 + diskann-benchmark/src/utils/recall.rs | 1 + diskann-benchmark/src/utils/tokio.rs | 7 + diskann-label-filter/src/attribute.rs | 1 - diskann-label-filter/src/document.rs | 4 +- .../ast_label_id_mapper.rs | 15 +- .../document_insert_strategy.rs | 274 ++++ .../document_provider.rs | 2 +- .../encoded_filter_expr.rs | 19 +- .../roaring_attribute_store.rs | 2 +- .../encoded_document_accessor.rs | 14 +- .../inline_beta_search/inline_beta_filter.rs | 67 +- diskann-label-filter/src/lib.rs | 1 + diskann-label-filter/src/parser/format.rs | 2 - .../provider/async_/inmem/full_precision.rs | 1218 +++++++++-------- diskann-tools/src/utils/ground_truth.rs | 37 +- .../disk_index_search/data.256.label.jsonl | 4 +- 23 files changed, 2307 insertions(+), 640 deletions(-) create mode 100644 diskann-benchmark/example/document-filter.json create mode 100644 diskann-benchmark/src/backend/document_index/benchmark.rs create mode 100644 diskann-benchmark/src/backend/document_index/mod.rs create mode 100644 diskann-benchmark/src/inputs/document_index.rs create mode 100644 diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json new file mode 100644 index 000000000..d6e9e13b2 --- /dev/null +++ b/diskann-benchmark/example/document-filter.json @@ -0,0 +1,34 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "document-index-build", + "content": { + "build": { + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "data_labels": "data.256.label.jsonl", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2 + }, + "search": { + "queries": "disk_index_sample_query_10pts.fbin", + "query_predicates": "query.10.label.jsonl", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "beta": 0.5, + "runs": [ + { + "search_n": 20, + "search_l": [20, 30, 40], + "recall_k": 10 + } + ] + } + } + } + ] +} \ No newline at end of file diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs new file mode 100644 index 000000000..dffe669ff --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -0,0 +1,1038 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Benchmark for DocumentInsertStrategy which allows inserting Documents +//! (vector + attributes) into a DiskANN index built with DocumentProvider. +//! Also benchmarks filtered search using InlineBetaStrategy. + +use std::io::Write; +use std::num::NonZeroUsize; +use std::path::Path; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +use anyhow::Result; +use diskann::{ + graph::{ + config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, + search_output_buffer, DiskANNIndex, SearchParams, StartPointStrategy, + }, + provider::DefaultContext, + utils::{async_tools, IntoUsize}, +}; +use diskann_benchmark_runner::{ + dispatcher::{DispatchRule, FailureScore, MatchScore}, + output::Output, + registry::Benchmarks, + utils::{datatype::DataType, percentiles, MicroSeconds}, + Any, Checkpoint, +}; +use diskann_label_filter::{ + attribute::{Attribute, AttributeValue}, + document::Document, + encoded_attribute_provider::{ + document_insert_strategy::DocumentInsertStrategy, document_provider::DocumentProvider, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::inline_beta_filter::InlineBetaStrategy, + query::FilteredQuery, + read_and_parse_queries, read_baselabels, ASTExpr, +}; +use diskann_providers::model::graph::provider::async_::{ + common::{self, NoStore, TableBasedDeletes}, + inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, +}; +use diskann_utils::views::Matrix; +use indicatif::{ProgressBar, ProgressStyle}; +use serde::Serialize; + +use crate::{ + inputs::document_index::DocumentIndexBuild, + utils::{ + self, + datafiles::{self, BinFile}, + recall, + }, +}; + +/// Register the document index benchmarks. +pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { + benchmarks.register::>( + "document-index-build", + |job, checkpoint, out| { + let stats = job.run(checkpoint, out)?; + Ok(serde_json::to_value(stats)?) + }, + ); +} + +/// Document index benchmark job. +pub(super) struct DocumentIndexJob<'a> { + input: &'a DocumentIndexBuild, +} + +impl<'a> DocumentIndexJob<'a> { + fn new(input: &'a DocumentIndexBuild) -> Self { + Self { input } + } +} + +impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static> { + type Type<'a> = DocumentIndexJob<'a>; +} + +// Dispatch from the concrete input type +impl<'a> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a> { + type Error = std::convert::Infallible; + + fn try_match(_from: &&'a DocumentIndexBuild) -> Result { + Ok(MatchScore(1)) + } + + fn convert(from: &'a DocumentIndexBuild) -> Result { + Ok(DocumentIndexJob::new(from)) + } + + fn description( + f: &mut std::fmt::Formatter<'_>, + _from: Option<&&'a DocumentIndexBuild>, + ) -> std::fmt::Result { + writeln!(f, "tag: \"{}\"", DocumentIndexBuild::tag()) + } +} + +// Central dispatch mapping from Any +impl<'a> DispatchRule<&'a Any> for DocumentIndexJob<'a> { + type Error = anyhow::Error; + + fn try_match(from: &&'a Any) -> Result { + from.try_match::() + } + + fn convert(from: &'a Any) -> Result { + from.convert::() + } + + fn description(f: &mut std::fmt::Formatter, from: Option<&&'a Any>) -> std::fmt::Result { + Any::description::(f, from, DocumentIndexBuild::tag()) + } +} +/// Convert a HashMap to Vec +fn hashmap_to_attributes(map: std::collections::HashMap) -> Vec { + map.into_iter() + .map(|(k, v)| Attribute::from_value(k, v)) + .collect() +} + +/// Compute the index of the row closest to the medoid (centroid) of the data. +fn compute_medoid_index(data: &Matrix) -> usize +where + T: bytemuck::Pod + Copy + 'static, +{ + use diskann_vector::{distance::SquaredL2, PureDistanceFunction}; + + let dim = data.ncols(); + if dim == 0 || data.nrows() == 0 { + return 0; + } + + // Compute the centroid (mean of all rows) as f64 for precision + let mut sum = vec![0.0f64; dim]; + for i in 0..data.nrows() { + let row = data.row(i); + for (j, &v) in row.iter().enumerate() { + // Convert T to f64 for summation using bytemuck + let f64_val: f64 = if std::any::TypeId::of::() == std::any::TypeId::of::() { + let f32_val: f32 = bytemuck::cast(v); + f32_val as f64 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let u8_val: u8 = bytemuck::cast(v); + u8_val as f64 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let i8_val: i8 = bytemuck::cast(v); + i8_val as f64 + } else { + 0.0 + }; + sum[j] += f64_val; + } + } + + // Convert centroid to f32 and compute distances + let centroid_f32: Vec = sum + .iter() + .map(|s| (s / data.nrows() as f64) as f32) + .collect(); + + // Find the row closest to the centroid + let mut min_dist = f32::MAX; + let mut medoid_idx = 0; + for i in 0..data.nrows() { + let row = data.row(i); + let row_f32: Vec = row + .iter() + .map(|&v| { + if std::any::TypeId::of::() == std::any::TypeId::of::() { + bytemuck::cast(v) + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let u8_val: u8 = bytemuck::cast(v); + u8_val as f32 + } else if std::any::TypeId::of::() == std::any::TypeId::of::() { + let i8_val: i8 = bytemuck::cast(v); + i8_val as f32 + } else { + 0.0 + } + }) + .collect(); + let d = SquaredL2::evaluate(centroid_f32.as_slice(), row_f32.as_slice()); + if d < min_dist { + min_dist = d; + medoid_idx = i; + } + } + + medoid_idx +} + +impl<'a> DocumentIndexJob<'a> { + fn run( + self, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, + ) -> Result { + // Print the input description + writeln!(output, "{}", self.input)?; + + let build = &self.input.build; + + // Dispatch based on data type - retain original type without conversion + match build.data_type { + DataType::Float32 => self.run_typed::(output), + DataType::UInt8 => self.run_typed::(output), + DataType::Int8 => self.run_typed::(output), + _ => Err(anyhow::anyhow!( + "Unsupported data type: {:?}. Supported types: float32, uint8, int8.", + build.data_type + )), + } + } + + fn run_typed(self, mut output: &mut dyn Output) -> Result + where + T: bytemuck::Pod + Copy + Send + Sync + 'static + std::fmt::Debug, + T: diskann::graph::SampleableForStart + diskann_utils::future::AsyncFriendly, + T: diskann::utils::VectorRepr + diskann_utils::sampling::WithApproximateNorm, + { + let build = &self.input.build; + + // 1. Load vectors from data file in the original data type + writeln!(output, "Loading vectors ({})...", build.data_type)?; + let timer = std::time::Instant::now(); + let data_path: &Path = build.data.as_ref(); + writeln!(output, "Data path is: {}", data_path.to_string_lossy())?; + let data: Matrix = datafiles::load_dataset(BinFile(data_path))?; + let data_load_time: MicroSeconds = timer.elapsed().into(); + let num_vectors = data.nrows(); + let dim = data.ncols(); + writeln!( + output, + " Loaded {} vectors of dimension {}", + num_vectors, dim + )?; + + // 2. Load and parse labels from the data_labels file + writeln!(output, "Loading labels...")?; + let timer = std::time::Instant::now(); + let label_path: &Path = build.data_labels.as_ref(); + let labels = read_baselabels(label_path)?; + let label_load_time: MicroSeconds = timer.elapsed().into(); + let label_count = labels.len(); + writeln!(output, " Loaded {} label documents", label_count)?; + + if num_vectors != label_count { + return Err(anyhow::anyhow!( + "Mismatch: {} vectors but {} label documents", + num_vectors, + label_count + )); + } + + // Convert labels to attribute vectors + let attributes: Vec> = labels + .into_iter() + .map(|doc| hashmap_to_attributes(doc.flatten_metadata_with_separator(""))) + .collect(); + + // 3. Create the index configuration + let metric = build.distance.into(); + let prune_kind = PruneKind::from_metric(metric); + let mut config_builder = ConfigBuilder::new( + build.max_degree, // pruned_degree + MaxDegree::Same, // max_degree + build.l_build, // l_build + prune_kind, // prune_kind + ); + config_builder.alpha(build.alpha); + let config = config_builder.build()?; + + // 4. Create the data provider directly + writeln!(output, "Creating index...")?; + let params = DefaultProviderParameters { + max_points: num_vectors, + frozen_points: diskann::utils::ONE, + metric, + dim, + prefetch_lookahead: None, + prefetch_cache_line_level: None, + max_degree: build.max_degree as u32, + }; + + // Create the underlying provider + let fp_precursor = CreateFullPrecision::::new(dim, None); + let inner_provider = + DefaultProvider::new_empty(params, fp_precursor, NoStore, TableBasedDeletes)?; + + // Set start points using medoid strategy + let start_points = StartPointStrategy::Medoid + .compute(data.as_view()) + .map_err(|e| anyhow::anyhow!("Failed to compute start points: {}", e))?; + inner_provider.set_start_points(start_points.row_iter())?; + + // 5. Create DocumentProvider wrapping the inner provider + let attribute_store = RoaringAttributeStore::::new(); + + // Store attributes for the start point (medoid) + // Start points are stored at indices num_vectors..num_vectors+frozen_points + let medoid_idx = compute_medoid_index(&data); + let start_point_id = num_vectors as u32; // Start points begin at max_points + let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); + use diskann_label_filter::traits::attribute_store::AttributeStore; + attribute_store.set_element(&start_point_id, &medoid_attrs)?; + + let doc_provider = DocumentProvider::new(inner_provider, attribute_store); + + // Create a new DiskANNIndex with DocumentProvider + let doc_index = Arc::new(DiskANNIndex::new(config, doc_provider, None)); + + // 6. Build index by inserting vectors and attributes (parallel) + writeln!( + output, + "Building index with {} vectors using {} threads...", + num_vectors, build.num_threads + )?; + let timer = std::time::Instant::now(); + + let insert_strategy: DocumentInsertStrategy<_, [T]> = + DocumentInsertStrategy::new(common::FullPrecision); + let rt = utils::tokio::runtime(build.num_threads)?; + + // Create control block for parallel work distribution + let data_arc = Arc::new(data); + let attributes_arc = Arc::new(attributes); + let control_block = DocumentControlBlock::new( + data_arc.clone(), + attributes_arc.clone(), + output.draw_target(), + )?; + + let num_tasks = build.num_threads; + let insert_latencies = rt.block_on(async { + let tasks: Vec<_> = (0..num_tasks) + .map(|_| { + let block = control_block.clone(); + let index = doc_index.clone(); + let strategy = insert_strategy; + tokio::spawn(async move { + let mut latencies = Vec::::new(); + let ctx = DefaultContext; + loop { + match block.next() { + Some((id, vector, attrs)) => { + let doc = Document::new(vector, attrs); + let start = std::time::Instant::now(); + let result = + index.insert(strategy, &ctx, &(id as u32), &doc).await; + latencies.push(MicroSeconds::from(start.elapsed())); + + if let Err(e) = result { + block.cancel(); + return Err(e); + } + } + None => return Ok(latencies), + } + } + }) + }) + .collect(); + + // Collect results from all tasks + let mut all_latencies = Vec::with_capacity(num_vectors); + for task in tasks { + let task_latencies = task.await??; + all_latencies.extend(task_latencies); + } + Ok::<_, anyhow::Error>(all_latencies) + })?; + + let build_time: MicroSeconds = timer.elapsed().into(); + writeln!(output, " Index built in {} s", build_time.as_seconds())?; + + let insert_percentiles = percentiles::compute_percentiles(&mut insert_latencies.clone())?; + // ===================== + // Search Phase + // ===================== + let search_input = &self.input.search; + + // Load query vectors (same type as data for compatible distance computation) + writeln!(output, "\nLoading query vectors...")?; + let query_path: &Path = search_input.queries.as_ref(); + let queries: Matrix = datafiles::load_dataset(BinFile(query_path))?; + let num_queries = queries.nrows(); + writeln!(output, " Loaded {} queries", num_queries)?; + + // Load and parse query predicates + writeln!(output, "Loading query predicates...")?; + let predicate_path: &Path = search_input.query_predicates.as_ref(); + let parsed_predicates = read_and_parse_queries(predicate_path)?; + writeln!(output, " Loaded {} predicates", parsed_predicates.len())?; + + if num_queries != parsed_predicates.len() { + return Err(anyhow::anyhow!( + "Mismatch: {} queries but {} predicates", + num_queries, + parsed_predicates.len() + )); + } + + // Load groundtruth + writeln!(output, "Loading groundtruth...")?; + let gt_path: &Path = search_input.groundtruth.as_ref(); + let groundtruth: Vec> = datafiles::load_range_groundtruth(BinFile(gt_path))?; + writeln!( + output, + " Loaded groundtruth with {} rows", + groundtruth.len() + )?; + + // Run filtered searches + writeln!( + output, + "\nRunning filtered searches (beta={})...", + search_input.beta + )?; + let mut search_results = Vec::new(); + + for num_threads in &search_input.num_threads { + for run in &search_input.runs { + for &search_l in &run.search_l { + writeln!( + output, + " threads={}, search_n={}, search_l={}...", + num_threads, run.search_n, search_l + )?; + + let search_run_result = run_filtered_search( + &doc_index, + &queries, + &parsed_predicates, + &groundtruth, + search_input.beta, + *num_threads, + run.search_n, + search_l, + run.recall_k, + search_input.reps, + )?; + + writeln!( + output, + " recall={:.4}, mean_qps={:.1}", + search_run_result.recall.average, + if search_run_result.qps.is_empty() { + 0.0 + } else { + search_run_result.qps.iter().sum::() + / search_run_result.qps.len() as f64 + } + )?; + + search_results.push(search_run_result); + } + } + } + + let stats = DocumentIndexStats { + num_vectors, + dim, + label_count, + data_load_time, + label_load_time, + build_time, + insert_latencies: insert_percentiles, + build_params: BuildParamsStats { + max_degree: build.max_degree, + l_build: build.l_build, + alpha: build.alpha, + }, + search: search_results, + }; + + writeln!(output, "\n{}", stats)?; + Ok(stats) + } +} +/// Local results from a partition of queries. +struct SearchLocalResults { + ids: Matrix, + distances: Vec>, + latencies: Vec, + comparisons: Vec, + hops: Vec, +} + +impl SearchLocalResults { + fn merge(all: &[SearchLocalResults]) -> anyhow::Result { + let first = all + .first() + .ok_or_else(|| anyhow::anyhow!("empty results"))?; + let num_ids = first.ids.ncols(); + let total_rows: usize = all.iter().map(|r| r.ids.nrows()).sum(); + + let mut ids = Matrix::new(0, total_rows, num_ids); + let mut output_row = 0; + for r in all { + for input_row in r.ids.row_iter() { + ids.row_mut(output_row).copy_from_slice(input_row); + output_row += 1; + } + } + + let mut distances = Vec::new(); + let mut latencies = Vec::new(); + let mut comparisons = Vec::new(); + let mut hops = Vec::new(); + for r in all { + distances.extend_from_slice(&r.distances); + latencies.extend_from_slice(&r.latencies); + comparisons.extend_from_slice(&r.comparisons); + hops.extend_from_slice(&r.hops); + } + + Ok(Self { + ids, + distances, + latencies, + comparisons, + hops, + }) + } +} + +/// Run filtered search with the given parameters. +#[allow(clippy::too_many_arguments)] +fn run_filtered_search( + index: &Arc>, + queries: &Matrix, + predicates: &[(usize, ASTExpr)], + groundtruth: &Vec>, + beta: f32, + num_threads: NonZeroUsize, + search_n: usize, + search_l: usize, + recall_k: usize, + reps: NonZeroUsize, +) -> anyhow::Result +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync + + 'static, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let rt = utils::tokio::runtime(num_threads.get())?; + let num_queries = queries.nrows(); + + let mut all_rep_results = Vec::with_capacity(reps.get()); + let mut rep_latencies = Vec::with_capacity(reps.get()); + + for _ in 0..reps.get() { + let start = std::time::Instant::now(); + let results = rt.block_on(run_search_parallel( + index.clone(), + queries, + predicates, + beta, + num_threads, + search_n, + search_l, + ))?; + rep_latencies.push(MicroSeconds::from(start.elapsed())); + all_rep_results.push(results); + } + + // Merge results from first rep for recall calculation + let merged = SearchLocalResults::merge(&all_rep_results[0])?; + + // Compute recall + let recall_metrics: recall::RecallMetrics = + (&recall::knn(groundtruth, None, &merged.ids, recall_k, search_n, false)?).into(); + + // Compute per-query details (only for queries with recall < 1) + let per_query_details: Vec = (0..num_queries) + .filter_map(|query_idx| { + let result_ids: Vec = merged + .ids + .row(query_idx) + .iter() + .copied() + .filter(|&id| id != u32::MAX) + .collect(); + let result_distances: Vec = merged + .distances + .get(query_idx) + .map(|d| d.iter().copied().filter(|&dist| dist != f32::MAX).collect()) + .unwrap_or_default(); + // Only keep top 20 from ground truth + let gt_ids: Vec = groundtruth + .get(query_idx) + .map(|gt| gt.iter().take(20).copied().collect()) + .unwrap_or_default(); + + // Compute per-query recall: intersection of result_ids with gt_ids / recall_k + let result_set: std::collections::HashSet = result_ids.iter().copied().collect(); + let gt_set: std::collections::HashSet = + gt_ids.iter().take(recall_k).copied().collect(); + let intersection = result_set.intersection(>_set).count(); + let per_query_recall = if gt_set.is_empty() { + 1.0 + } else { + intersection as f64 / gt_set.len() as f64 + }; + + // Only include queries with imperfect recall + if per_query_recall >= 1.0 { + return None; + } + + let (_, ref ast_expr) = predicates[query_idx]; + let filter_str = format!("{:?}", ast_expr); + + Some(PerQueryDetails { + query_id: query_idx, + filter: filter_str, + recall: per_query_recall, + result_ids, + result_distances, + groundtruth_ids: gt_ids, + }) + }) + .collect(); + + // Compute QPS from rep latencies + let qps: Vec = rep_latencies + .iter() + .map(|l| num_queries as f64 / l.as_seconds()) + .collect(); + + // Aggregate per-query latencies across all reps + let (all_latencies, all_cmps, all_hops): (Vec<_>, Vec<_>, Vec<_>) = all_rep_results + .iter() + .map(|results| { + let mut lat = Vec::new(); + let mut cmp = Vec::new(); + let mut hop = Vec::new(); + for r in results { + lat.extend_from_slice(&r.latencies); + cmp.extend_from_slice(&r.comparisons); + hop.extend_from_slice(&r.hops); + } + (lat, cmp, hop) + }) + .fold( + (Vec::new(), Vec::new(), Vec::new()), + |(mut a, mut b, mut c): (Vec, Vec, Vec), (x, y, z)| { + a.extend(x); + b.extend(y); + c.extend(z); + (a, b, c) + }, + ); + + let mut query_latencies = all_latencies; + let percentiles::Percentiles { mean, p90, p99, .. } = + percentiles::compute_percentiles(&mut query_latencies)?; + + let mean_cmps = if all_cmps.is_empty() { + 0.0 + } else { + all_cmps.iter().map(|&x| x as f32).sum::() / all_cmps.len() as f32 + }; + let mean_hops = if all_hops.is_empty() { + 0.0 + } else { + all_hops.iter().map(|&x| x as f32).sum::() / all_hops.len() as f32 + }; + + Ok(SearchRunStats { + num_threads: num_threads.get(), + num_queries, + search_n, + search_l, + recall: recall_metrics, + qps, + wall_clock_time: rep_latencies, + mean_latency: mean, + p90_latency: p90, + p99_latency: p99, + mean_cmps, + mean_hops, + per_query_details: Some(per_query_details), + }) +} +async fn run_search_parallel( + index: Arc>, + queries: &Matrix, + predicates: &[(usize, ASTExpr)], + beta: f32, + num_tasks: NonZeroUsize, + search_n: usize, + search_l: usize, +) -> anyhow::Result> +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync + + 'static, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let num_queries = queries.nrows(); + + // Plan query partitions + let partitions: Result, _> = (0..num_tasks.get()) + .map(|task_id| async_tools::partition(num_queries, num_tasks, task_id)) + .collect(); + let partitions = partitions?; + + // We need to clone data for each task + let queries_arc = Arc::new(queries.clone()); + let predicates_arc = Arc::new(predicates.to_vec()); + + let handles: Vec<_> = partitions + .into_iter() + .map(|range| { + let index = index.clone(); + let queries = queries_arc.clone(); + let predicates = predicates_arc.clone(); + tokio::spawn(async move { + run_search_local(index, queries, predicates, beta, range, search_n, search_l).await + }) + }) + .collect(); + + let mut results = Vec::new(); + for h in handles { + results.push(h.await??); + } + + Ok(results) +} + +async fn run_search_local( + index: Arc>, + queries: Arc>, + predicates: Arc>, + beta: f32, + range: std::ops::Range, + search_n: usize, + search_l: usize, +) -> anyhow::Result +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync, + InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>>, +{ + let mut ids = Matrix::new(0, range.len(), search_n); + let mut all_distances: Vec> = Vec::with_capacity(range.len()); + let mut latencies = Vec::with_capacity(range.len()); + let mut comparisons = Vec::with_capacity(range.len()); + let mut hops = Vec::with_capacity(range.len()); + + let ctx = DefaultContext; + let search_params = SearchParams::new_default(search_n, search_l)?; + + for (output_idx, query_idx) in range.enumerate() { + let query_vec = queries.row(query_idx); + let (_, ref ast_expr) = predicates[query_idx]; + + let strategy = InlineBetaStrategy::new(beta, common::FullPrecision); + let query_vec_owned = query_vec.to_vec(); + let filtered_query: FilteredQuery> = + FilteredQuery::new(query_vec_owned, ast_expr.clone()); + + let start = std::time::Instant::now(); + + let mut distances = vec![0.0f32; search_n]; + let result_ids = ids.row_mut(output_idx); + let mut result_buffer = search_output_buffer::IdDistance::new(result_ids, &mut distances); + + let stats = index + .search( + &strategy, + &ctx, + &filtered_query, + &search_params, + &mut result_buffer, + ) + .await?; + + let result_count = stats.result_count.into_usize(); + result_ids[result_count..].fill(u32::MAX); + distances[result_count..].fill(f32::MAX); + + latencies.push(MicroSeconds::from(start.elapsed())); + comparisons.push(stats.cmps); + hops.push(stats.hops); + all_distances.push(distances); + } + + Ok(SearchLocalResults { + ids, + distances: all_distances, + latencies, + comparisons, + hops, + }) +} +#[derive(Debug, Serialize)] +pub struct BuildParamsStats { + pub max_degree: usize, + pub l_build: usize, + pub alpha: f32, +} + +/// Helper module for serializing arrays as compact single-line JSON strings +mod compact_array { + use serde::Serializer; + + pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } + + pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } +} + +/// Per-query detailed results for debugging/analysis +#[derive(Debug, Serialize)] +pub struct PerQueryDetails { + pub query_id: usize, + pub filter: String, + pub recall: f64, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub result_ids: Vec, + #[serde(serialize_with = "compact_array::serialize_f32_vec")] + pub result_distances: Vec, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub groundtruth_ids: Vec, +} + +/// Results from a single search configuration (one search_l value). +#[derive(Debug, Serialize)] +pub struct SearchRunStats { + pub num_threads: usize, + pub num_queries: usize, + pub search_n: usize, + pub search_l: usize, + pub recall: recall::RecallMetrics, + pub qps: Vec, + pub wall_clock_time: Vec, + pub mean_latency: f64, + pub p90_latency: MicroSeconds, + pub p99_latency: MicroSeconds, + pub mean_cmps: f32, + pub mean_hops: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub per_query_details: Option>, +} + +#[derive(Debug, Serialize)] +pub struct DocumentIndexStats { + pub num_vectors: usize, + pub dim: usize, + pub label_count: usize, + pub data_load_time: MicroSeconds, + pub label_load_time: MicroSeconds, + pub build_time: MicroSeconds, + pub insert_latencies: percentiles::Percentiles, + pub build_params: BuildParamsStats, + pub search: Vec, +} + +impl std::fmt::Display for DocumentIndexStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Document Index Build Stats:")?; + writeln!(f, " Vectors: {} x {}", self.num_vectors, self.dim)?; + writeln!(f, " Label Count: {}", self.label_count)?; + writeln!( + f, + " Data Load Time: {} s", + self.data_load_time.as_seconds() + )?; + writeln!( + f, + " Label Load Time: {} s", + self.label_load_time.as_seconds() + )?; + writeln!(f, " Total Build Time: {} s", self.build_time.as_seconds())?; + writeln!(f, " Insert Latencies:")?; + writeln!(f, " Mean: {} us", self.insert_latencies.mean)?; + writeln!(f, " P50: {} us", self.insert_latencies.median)?; + writeln!(f, " P90: {} us", self.insert_latencies.p90)?; + writeln!(f, " P99: {} us", self.insert_latencies.p99)?; + writeln!(f, " Build Parameters:")?; + writeln!(f, " max_degree (R): {}", self.build_params.max_degree)?; + writeln!(f, " l_build (L): {}", self.build_params.l_build)?; + writeln!(f, " alpha: {}", self.build_params.alpha)?; + + if !self.search.is_empty() { + writeln!(f, "\nFiltered Search Results:")?; + writeln!( + f, + " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", + "L", "KNN", "Avg Cmps", "Avg Hops", "QPS -mean(max)", "Avg Latency", "p99 Latency", "Recall", "Threads", "Queries", "WallClock(s)" + )?; + for s in &self.search { + let mean_qps = if s.qps.is_empty() { + 0.0 + } else { + s.qps.iter().sum::() / s.qps.len() as f64 + }; + let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max); + let mean_wall_clock = if s.wall_clock_time.is_empty() { + 0.0 + } else { + s.wall_clock_time.iter().map(|t| t.as_seconds()).sum::() / s.wall_clock_time.len() as f64 + }; + writeln!( + f, + " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", + s.search_l, + s.search_n, + s.mean_cmps, + s.mean_hops, + mean_qps, + max_qps, + s.mean_latency, + s.p99_latency, + s.recall.average, + s.num_threads, + s.num_queries, + mean_wall_clock + )?; + } + } + Ok(()) + } +} + +// ================================ +// Parallel Build Support +// ================================ + +fn make_progress_bar( + nrows: usize, + draw_target: indicatif::ProgressDrawTarget, +) -> anyhow::Result { + let progress = ProgressBar::with_draw_target(Some(nrows as u64), draw_target); + progress.set_style(ProgressStyle::with_template( + "Building [{elapsed_precise}] {wide_bar} {percent}", + )?); + Ok(progress) +} + +/// Control block for parallel document insertion. +/// Manages work distribution and progress tracking across multiple tasks. +struct DocumentControlBlock { + data: Arc>, + attributes: Arc>>, + position: AtomicUsize, + cancel: AtomicBool, + progress: ProgressBar, +} + +impl DocumentControlBlock { + fn new( + data: Arc>, + attributes: Arc>>, + draw_target: indicatif::ProgressDrawTarget, + ) -> anyhow::Result> { + let nrows = data.nrows(); + Ok(Arc::new(Self { + data, + attributes, + position: AtomicUsize::new(0), + cancel: AtomicBool::new(false), + progress: make_progress_bar(nrows, draw_target)?, + })) + } + + /// Return the next document data to insert: (id, vector_slice, attributes). + fn next(&self) -> Option<(usize, &[T], Vec)> { + let cancel = self.cancel.load(Ordering::Relaxed); + if cancel { + None + } else { + let i = self.position.fetch_add(1, Ordering::Relaxed); + match self.data.get_row(i) { + Some(row) => { + let attrs = self.attributes.get(i).cloned().unwrap_or_default(); + self.progress.inc(1); + Some((i, row, attrs)) + } + None => None, + } + } + } + + /// Tell all users of the control block to cancel and return early. + fn cancel(&self) { + self.cancel.store(true, Ordering::Relaxed); + } +} + +impl Drop for DocumentControlBlock { + fn drop(&mut self) { + self.progress.finish(); + } +} diff --git a/diskann-benchmark/src/backend/document_index/mod.rs b/diskann-benchmark/src/backend/document_index/mod.rs new file mode 100644 index 000000000..9937590cc --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/mod.rs @@ -0,0 +1,13 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Backend benchmark implementation for document index with label filters. +//! +//! This benchmark tests the DocumentInsertStrategy which enables inserting +//! Document objects (vector + attributes) into a DiskANN index. + +mod benchmark; + +pub(crate) use benchmark::register_benchmarks; diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index c7e2ab75c..21d74f915 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -109,6 +109,7 @@ impl std::fmt::Display for AggregatedSearchResults { #[derive(Debug, Serialize)] pub(super) struct SearchResults { pub(super) num_tasks: usize, + pub(super) num_queries: usize, pub(super) search_n: usize, pub(super) search_l: usize, pub(super) qps: Vec, @@ -143,6 +144,7 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), + num_queries: recall.num_queries, search_n: parameters.k_value, search_l: parameters.l_value, qps, @@ -182,6 +184,8 @@ where "p99 Latency", "Recall", "Threads", + "Queries", + "WallClock(s)", ] } else { &[ @@ -194,6 +198,8 @@ where "p99 Latency", "Recall", "Threads", + "Queries", + "WallClock(s)", ] }; @@ -237,6 +243,13 @@ where ); row.insert(format!("{:3}", r.recall.average), col_idx + 7); row.insert(r.num_tasks, col_idx + 8); + row.insert(r.num_queries, col_idx + 9); + let mean_wall_clock = if r.search_latencies.is_empty() { + 0.0 + } else { + r.search_latencies.iter().map(|t| t.as_seconds()).sum::() / r.search_latencies.len() as f64 + }; + row.insert(format!("{:.3}", mean_wall_clock), col_idx + 10); }); write!(f, "{}", table) diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs index 24fe91d7e..5dc1967de 100644 --- a/diskann-benchmark/src/backend/mod.rs +++ b/diskann-benchmark/src/backend/mod.rs @@ -4,6 +4,7 @@ */ mod disk_index; +mod document_index; mod exhaustive; mod filters; mod index; @@ -13,4 +14,5 @@ pub(crate) fn register_benchmarks(registry: &mut diskann_benchmark_runner::regis disk_index::register_benchmarks(registry); index::register_benchmarks(registry); filters::register_benchmarks(registry); + document_index::register_benchmarks(registry); } diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs new file mode 100644 index 000000000..b1a36e48a --- /dev/null +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -0,0 +1,177 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Input types for document index benchmarks using DocumentInsertStrategy. + +use std::num::NonZeroUsize; + +use anyhow::Context; +use diskann_benchmark_runner::{ + files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, +}; +use serde::{Deserialize, Serialize}; + +use super::async_::GraphSearch; +use crate::inputs::{as_input, Example, Input}; + +////////////// +// Registry // +////////////// + +as_input!(DocumentIndexBuild); + +pub(super) fn register_inputs( + registry: &mut diskann_benchmark_runner::registry::Inputs, +) -> anyhow::Result<()> { + registry.register(Input::::new())?; + Ok(()) +} + +/// Build parameters for document index construction. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentBuildParams { + pub(crate) data_type: DataType, + pub(crate) data: InputFile, + pub(crate) data_labels: InputFile, + pub(crate) distance: crate::utils::SimilarityMeasure, + pub(crate) max_degree: usize, + pub(crate) l_build: usize, + pub(crate) alpha: f32, + #[serde(default = "default_num_threads")] + pub(crate) num_threads: usize, +} + +fn default_num_threads() -> usize { + 1 +} + +impl CheckDeserialization for DocumentBuildParams { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.data.check_deserialization(checker)?; + self.data_labels.check_deserialization(checker)?; + if self.max_degree == 0 { + return Err(anyhow::anyhow!("max_degree must be > 0")); + } + if self.l_build == 0 { + return Err(anyhow::anyhow!("l_build must be > 0")); + } + if self.alpha <= 0.0 { + return Err(anyhow::anyhow!("alpha must be > 0")); + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentSearchParams { + pub(crate) queries: InputFile, + pub(crate) query_predicates: InputFile, + pub(crate) groundtruth: InputFile, + pub(crate) beta: f32, + #[serde(default = "default_reps")] + pub(crate) reps: NonZeroUsize, + #[serde(default = "default_thread_counts")] + pub(crate) num_threads: Vec, + pub(crate) runs: Vec, +} + +fn default_reps() -> NonZeroUsize { + NonZeroUsize::new(5).unwrap() +} +fn default_thread_counts() -> Vec { + vec![NonZeroUsize::new(1).unwrap()] +} + +impl CheckDeserialization for DocumentSearchParams { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.check_deserialization(checker)?; + self.query_predicates.check_deserialization(checker)?; + self.groundtruth.check_deserialization(checker)?; + if self.beta <= 0.0 || self.beta > 1.0 { + return Err(anyhow::anyhow!( + "beta must be in range (0, 1], got: {}", + self.beta + )); + } + for (i, run) in self.runs.iter_mut().enumerate() { + run.check_deserialization(checker) + .with_context(|| format!("search run {}", i))?; + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentIndexBuild { + pub(crate) build: DocumentBuildParams, + pub(crate) search: DocumentSearchParams, +} + +impl DocumentIndexBuild { + pub(crate) const fn tag() -> &'static str { + "document-index-build" + } +} + +impl CheckDeserialization for DocumentIndexBuild { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.build.check_deserialization(checker)?; + self.search.check_deserialization(checker)?; + Ok(()) + } +} + +impl Example for DocumentIndexBuild { + fn example() -> Self { + Self { + build: DocumentBuildParams { + data_type: DataType::Float32, + data: InputFile::new("data.fbin"), + data_labels: InputFile::new("data.label.jsonl"), + distance: crate::utils::SimilarityMeasure::SquaredL2, + max_degree: 32, + l_build: 50, + alpha: 1.2, + num_threads: 1, + }, + search: DocumentSearchParams { + queries: InputFile::new("queries.fbin"), + query_predicates: InputFile::new("query.label.jsonl"), + groundtruth: InputFile::new("groundtruth.bin"), + beta: 0.5, + reps: default_reps(), + num_threads: default_thread_counts(), + runs: vec![GraphSearch { + search_n: 10, + search_l: vec![20, 30, 40, 50], + recall_k: 10, + }], + }, + } + } +} + +impl std::fmt::Display for DocumentIndexBuild { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Document Index Build with Label Filters\n")?; + writeln!(f, "tag: \"{}\"", Self::tag())?; + writeln!( + f, + "\nBuild: data={}, labels={}, R={}, L={}, alpha={}", + self.build.data.display(), + self.build.data_labels.display(), + self.build.max_degree, + self.build.l_build, + self.build.alpha + )?; + writeln!( + f, + "Search: queries={}, beta={}", + self.search.queries.display(), + self.search.beta + )?; + Ok(()) + } +} diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index a0ae1a982..65de65a41 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod async_; pub(crate) mod disk; +pub(crate) mod document_index; pub(crate) mod exhaustive; pub(crate) mod filters; pub(crate) mod save_and_load; @@ -16,6 +17,7 @@ pub(crate) fn register_inputs( exhaustive::register_inputs(registry)?; disk::register_inputs(registry)?; filters::register_inputs(registry)?; + document_index::register_inputs(registry)?; Ok(()) } diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index dcbe86d94..50ef7e430 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -3,6 +3,7 @@ * Licensed under the MIT license. */ +pub(crate) use benchmark_core::recall::knn; use diskann_benchmark_core as benchmark_core; use serde::Serialize; diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index 72dbeb918..21c78abb2 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,6 +3,13 @@ * Licensed under the MIT license. */ +/// Create a generic multi-threaded runtime with `num_threads`. +pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { + Ok(tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build()?) +} + /// Create a current-thread runtime and block on the given future. /// Only for functions that don't need multi-threading pub(crate) fn block_on(future: F) -> F::Output { diff --git a/diskann-label-filter/src/attribute.rs b/diskann-label-filter/src/attribute.rs index f0d99bfd9..9eb7ff500 100644 --- a/diskann-label-filter/src/attribute.rs +++ b/diskann-label-filter/src/attribute.rs @@ -5,7 +5,6 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; -use std::io::Write; use serde_json::Value; use thiserror::Error; diff --git a/diskann-label-filter/src/document.rs b/diskann-label-filter/src/document.rs index 31cad4772..5c817525c 100644 --- a/diskann-label-filter/src/document.rs +++ b/diskann-label-filter/src/document.rs @@ -8,12 +8,12 @@ use diskann_utils::reborrow::Reborrow; ///Simple container class that clients can use to /// supply diskann with a vector and its attributes -pub struct Document<'a, V> { +pub struct Document<'a, V: ?Sized> { vector: &'a V, attributes: Vec, } -impl<'a, V> Document<'a, V> { +impl<'a, V: ?Sized> Document<'a, V> { pub fn new(vector: &'a V, attributes: Vec) -> Self { Self { vector, attributes } } diff --git a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs index 0fa21cc02..8b39d8731 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs @@ -31,19 +31,14 @@ impl ASTLabelIdMapper { Self { attribute_map } } - fn _lookup( - encoder: &AttributeEncoder, - attribute: &Attribute, - field: &str, - op: &CompareOp, - ) -> ANNResult> { + fn _lookup(encoder: &AttributeEncoder, attribute: &Attribute) -> ANNResult> { match encoder.get(attribute) { Some(attribute_id) => Ok(ASTIdExpr::Terminal(attribute_id)), None => Err(ANNError::message( ANNErrorKind::Opaque, format!( - "{}+{} present in the query does not exist in the dataset.", - field, op + "{} present in the query does not exist in the dataset.", + attribute ), )), } @@ -120,10 +115,10 @@ impl ASTVisitor for ASTLabelIdMapper { if let Some(attribute) = label_or_none { match self.attribute_map.read() { - Ok(guard) => Self::_lookup(&guard, &attribute, field, op), + Ok(guard) => Self::_lookup(&guard, &attribute), Err(poison_error) => { let attr_map = poison_error.into_inner(); - Self::_lookup(&attr_map, &attribute, field, op) + Self::_lookup(&attr_map, &attribute) } } } else { diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs new file mode 100644 index 000000000..850976a32 --- /dev/null +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -0,0 +1,274 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +//! A strategy wrapper that enables insertion of [Document] objects into a +//! [DiskANNIndex] using a [DocumentProvider]. + +use std::marker::PhantomData; + +use diskann::{ + graph::{ + glue::{ + ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + }, + SearchOutputBuffer, + }, + neighbor::Neighbor, + provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + ANNResult, +}; + +use super::document_provider::DocumentProvider; +use crate::document::Document; +use crate::encoded_attribute_provider::roaring_attribute_store::RoaringAttributeStore; + +/// A strategy wrapper that enables insertion of [Document] objects. +pub struct DocumentInsertStrategy { + inner: Inner, + _phantom: PhantomData VT>, +} + +impl Clone for DocumentInsertStrategy { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _phantom: PhantomData, + } + } +} + +impl Copy for DocumentInsertStrategy {} + +impl DocumentInsertStrategy { + pub fn new(inner: Inner) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } + + pub fn inner(&self) -> &Inner { + &self.inner + } +} + +/// Wrapper accessor for Document queries +pub struct DocumentSearchAccessor { + inner: Inner, + _phantom: PhantomData VT>, +} + +impl DocumentSearchAccessor { + pub fn new(inner: Inner) -> Self { + Self { + inner, + _phantom: PhantomData, + } + } +} + +impl HasId for DocumentSearchAccessor +where + Inner: HasId, + VT: ?Sized, +{ + type Id = Inner::Id; +} + +impl Accessor for DocumentSearchAccessor +where + Inner: Accessor, + VT: ?Sized, +{ + type ElementRef<'a> = Inner::ElementRef<'a>; + type Element<'a> + = Inner::Element<'a> + where + Self: 'a; + type Extended = Inner::Extended; + type GetError = Inner::GetError; + + fn get_element( + &mut self, + id: Self::Id, + ) -> impl std::future::Future, Self::GetError>> + Send { + self.inner.get_element(id) + } + + fn on_elements_unordered( + &mut self, + itr: Itr, + f: F, + ) -> impl std::future::Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + self.inner.on_elements_unordered(itr, f) + } +} + +impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor +where + Inner: BuildQueryComputer, + VT: ?Sized, +{ + type QueryComputerError = Inner::QueryComputerError; + type QueryComputer = Inner::QueryComputer; + + fn build_query_computer( + &self, + from: &Document<'doc, VT>, + ) -> Result { + self.inner.build_query_computer(from.vector()) + } +} + +impl<'this, Inner, VT> DelegateNeighbor<'this> for DocumentSearchAccessor +where + Inner: DelegateNeighbor<'this>, + VT: ?Sized, +{ + type Delegate = Inner::Delegate; + fn delegate_neighbor(&'this mut self) -> Self::Delegate { + self.inner.delegate_neighbor() + } +} + +impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor +where + Inner: ExpandBeam, + VT: ?Sized, +{ +} + +impl SearchExt for DocumentSearchAccessor +where + Inner: SearchExt, + VT: ?Sized, +{ + fn starting_points( + &self, + ) -> impl std::future::Future>> + Send { + self.inner.starting_points() + } + fn terminate_early(&mut self) -> bool { + self.inner.terminate_early() + } +} + +#[derive(Debug, Default, Clone, Copy)] +pub struct CopyIdsForDocument; + +impl<'doc, A, VT> SearchPostProcess> for CopyIdsForDocument +where + A: BuildQueryComputer>, + VT: ?Sized, +{ + type Error = std::convert::Infallible; + + fn post_process( + &self, + _accessor: &mut A, + _query: &Document<'doc, VT>, + _computer: &>>::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl std::future::Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let count = output.extend(candidates.map(|n| (n.id, n.distance))); + std::future::ready(Ok(count)) + } +} + +impl<'doc, Inner, DP, VT> + SearchStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type QueryComputer = Inner::QueryComputer; + type PostProcessor = CopyIdsForDocument; + type SearchAccessorError = Inner::SearchAccessorError; + type SearchAccessor<'a> = DocumentSearchAccessor, VT>; + + fn search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } + + fn post_processor(&self) -> Self::PostProcessor { + CopyIdsForDocument + } +} + +impl<'doc, Inner, DP, VT> + InsertStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type PruneStrategy = DocumentPruneStrategy; + + fn prune_strategy(&self) -> Self::PruneStrategy { + DocumentPruneStrategy::new(self.inner.prune_strategy()) + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .insert_search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } +} + +#[derive(Clone, Copy)] +pub struct DocumentPruneStrategy { + inner: Inner, +} + +impl DocumentPruneStrategy { + pub fn new(inner: Inner) -> Self { + Self { inner } + } +} + +impl PruneStrategy>> + for DocumentPruneStrategy +where + DP: DataProvider, + Inner: PruneStrategy, +{ + type DistanceComputer = Inner::DistanceComputer; + type PruneAccessor<'a> = Inner::PruneAccessor<'a>; + type PruneAccessorError = Inner::PruneAccessorError; + + fn prune_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::PruneAccessorError> { + self.inner + .prune_accessor(provider.inner_provider(), context) + } +} diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs index 6b496271b..1fabf5f54 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs @@ -77,7 +77,7 @@ impl<'a, VT, DP, AS> SetElement> for DocumentProvider where DP: DataProvider + Delete + SetElement, AS: AttributeStore + AsyncFriendly, - VT: Sync + Send, + VT: Sync + Send + ?Sized, { type SetError = ANNError; type Guard = >::Guard; diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index d56cb13c1..370ef25ae 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -5,8 +5,6 @@ use std::sync::{Arc, RwLock}; -use diskann::ANNResult; - use crate::{ encoded_attribute_provider::{ ast_id_expr::ASTIdExpr, ast_label_id_mapper::ASTLabelIdMapper, @@ -16,20 +14,21 @@ use crate::{ }; pub(crate) struct EncodedFilterExpr { - ast_id_expr: ASTIdExpr, + ast_id_expr: Option>, } impl EncodedFilterExpr { - pub fn new( - ast_expr: &ASTExpr, - attribute_map: Arc>, - ) -> ANNResult { + pub fn new(ast_expr: &ASTExpr, attribute_map: Arc>) -> Self { let mut mapper = ASTLabelIdMapper::new(attribute_map); - let ast_id_expr = ast_expr.accept(&mut mapper)?; - Ok(Self { ast_id_expr }) + match ast_expr.accept(&mut mapper) { + Ok(ast_id_expr) => Self { + ast_id_expr: Some(ast_id_expr), + }, + Err(_e) => Self { ast_id_expr: None }, + } } - pub(crate) fn encoded_filter_expr(&self) -> &ASTIdExpr { + pub(crate) fn encoded_filter_expr(&self) -> &Option> { &self.ast_id_expr } } diff --git a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs index 6b82a68b1..c69589ba0 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs @@ -15,7 +15,7 @@ use diskann::{utils::VectorId, ANNError, ANNErrorKind, ANNResult}; use diskann_utils::future::AsyncFriendly; use std::sync::{Arc, RwLock}; -pub(crate) struct RoaringAttributeStore +pub struct RoaringAttributeStore where IT: VectorId + AsyncFriendly, { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 962d361d7..1def9a406 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -28,7 +28,7 @@ use crate::{ type AttrAccessor = EncodedAttributeAccessor::Id>>; -pub(crate) struct EncodedDocumentAccessor +pub struct EncodedDocumentAccessor where IA: HasId, { @@ -136,7 +136,7 @@ where Some(set) => Ok(set.into_owned()), None => Err(ANNError::message( ANNErrorKind::IndexError, - "No labels were found for vector", + format!("No labels were found for vector:{:?}", id), )), } })?; @@ -220,12 +220,20 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone())?; + let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone()); + let is_valid_filter = id_query.encoded_filter_expr().is_some(); + if !is_valid_filter { + tracing::warn!( + "Failed to convert {} into an id expr. This will now be an unfiltered search.", + from.filter_expr() + ); + } Ok(InlineBetaComputer::new( inner_computer, self.beta_value, id_query, + is_valid_filter, )) } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index b25b1746f..f03f36c12 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -28,6 +28,13 @@ pub struct InlineBetaStrategy { inner: Strategy, } +impl InlineBetaStrategy { + /// Create a new InlineBetaStrategy with the given beta value and inner strategy. + pub fn new(beta: f32, inner: Strategy) -> Self { + Self { beta, inner } + } +} + impl SearchStrategy>, FilteredQuery> for InlineBetaStrategy @@ -72,6 +79,7 @@ pub struct InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, + is_valid_filter: bool, //optimization to avoid evaluating empty predicates. } impl InlineBetaComputer { @@ -79,17 +87,23 @@ impl InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, + is_valid_filter: bool, ) -> Self { Self { inner_computer, beta_value, filter_expr, + is_valid_filter, } } pub(crate) fn filter_expr(&self) -> &EncodedFilterExpr { &self.filter_expr } + + pub(crate) fn is_valid_filter(&self) -> bool { + self.is_valid_filter + } } impl PreprocessedDistanceFunction, f32> @@ -101,22 +115,35 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - match self.filter_expr.encoded_filter_expr().accept(&pred_eval) { - Ok(matched) => { - if matched { - sim * self.beta_value - } else { - sim + if self.is_valid_filter { + match self + .filter_expr + .encoded_filter_expr() + .as_ref() + .unwrap() + .accept(&pred_eval) + { + Ok(matched) => { + if matched { + return sim * self.beta_value; + } else { + return sim; + } + } + Err(_) => { + //If predicate evaluation fails for any reason, we simply revert + //to unfiltered search. + tracing::warn!("Predicate evaluation failed"); + return sim; } } - Err(_) => { - //TODO: If predicate evaluation fails, we are taking the approach that we will simply - //return the score returned by the inner computer, as though no predicate was specified. - tracing::warn!( - "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()" - ); - sim - } + } else { + //If predicate evaluation fails, we will return the score returned by the + //inner computer, as though no predicate was specified. + tracing::warn!( + "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()" + ); + sim } } } @@ -155,8 +182,16 @@ where let doc = accessor.get_element(candidate.id).await?; let pe = PredicateEvaluator::new(doc.attributes()); - if computer.filter_expr().encoded_filter_expr().accept(&pe)? { - filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); + if computer.is_valid_filter() { + if computer + .filter_expr() + .encoded_filter_expr() + .as_ref() + .unwrap() + .accept(&pe)? + { + filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); + } } } diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs index 106845f98..273475b15 100644 --- a/diskann-label-filter/src/lib.rs +++ b/diskann-label-filter/src/lib.rs @@ -40,6 +40,7 @@ pub mod encoded_attribute_provider { pub(crate) mod ast_id_expr; pub(crate) mod ast_label_id_mapper; pub(crate) mod attribute_encoder; + pub mod document_insert_strategy; pub mod document_provider; pub mod encoded_attribute_accessor; pub(crate) mod encoded_filter_expr; diff --git a/diskann-label-filter/src/parser/format.rs b/diskann-label-filter/src/parser/format.rs index 5e9e3a9c1..c042d8338 100644 --- a/diskann-label-filter/src/parser/format.rs +++ b/diskann-label-filter/src/parser/format.rs @@ -15,10 +15,8 @@ pub struct Document { /// label in raw json format #[serde(flatten)] pub label: serde_json::Value, - } - /// Represents a query expression as defined in the RFC. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryExpression { diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index e74419a46..9a48488fe 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -1,580 +1,638 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{collections::HashMap, fmt::Debug, future::Future}; - -use diskann::{ - ANNError, ANNResult, - graph::{ - SearchOutputBuffer, - glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, - }, - }, - neighbor::Neighbor, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, - }, - utils::{IntoUsize, VectorRepr}, -}; -use diskann_utils::future::AsyncFriendly; -use diskann_vector::{DistanceFunction, distance::Metric}; - -use crate::model::graph::{ - provider::async_::{ - FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, - common::{ - CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, - PrefetchCacheLineLevel, SetElementHelper, - }, - inmem::DefaultProvider, - postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, - }, - traits::AdHoc, -}; - -/// A type alias for the DefaultProvider with full-precision as the primary vector store. -pub type FullPrecisionProvider = - DefaultProvider, Q, D, Ctx>; - -/// The default full-precision vector store. -pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; - -/// A default full-precision vector store provider. -#[derive(Clone)] -pub struct CreateFullPrecision { - dim: usize, - prefetch_cache_line_level: Option, - _phantom: std::marker::PhantomData, -} - -impl CreateFullPrecision -where - T: VectorRepr, -{ - /// Create a new full-precision vector store provider. - pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { - Self { - dim, - prefetch_cache_line_level, - _phantom: std::marker::PhantomData, - } - } -} - -impl CreateVectorStore for CreateFullPrecision -where - T: VectorRepr, -{ - type Target = FullPrecisionStore; - fn create( - self, - max_points: usize, - metric: Metric, - prefetch_lookahead: Option, - ) -> Self::Target { - FullPrecisionStore::new( - max_points, - self.dim, - metric, - self.prefetch_cache_line_level, - prefetch_lookahead, - ) - } -} - -//////////////// -// SetElement // -//////////////// - -impl SetElementHelper for FullPrecisionStore -where - T: VectorRepr, -{ - /// Set the element at the given index. - fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { - unsafe { self.set_vector_sync(id.into_usize(), element) } - } -} - -////////////////// -// FullAccessor // -////////////////// - -/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. -pub struct FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, -{ - /// The host provider. - provider: &'a FullPrecisionProvider, - - /// A buffer for resolving iterators given during bulk operations. - /// - /// The accessor reuses this allocation to amortize allocation cost over multiple bulk - /// operations. - id_buffer: Vec, -} - -impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FullPrecisionStore { - &self.provider.base_vectors - } -} - -impl HasId for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Id = u32; -} - -impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - pub fn new(provider: &'a FullPrecisionProvider) -> Self { - Self { - provider, - id_buffer: Vec::new(), - } - } -} - -impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// The extended element inherets the lifetime of the Accessor. - type Extended = &'a [T]; - - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - /// - /// NOTE: We intentionally don't use `'b` here since our implementation borrows - /// the inner `Opaque` from the underlying provider. - type Element<'b> - = &'a [T] - where - Self: 'b; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'b> = &'b [T]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - // Invoke the passed closure on the full-precision vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - *id, - ) - } - - std::future::ready(Ok(())) - } -} - -impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = Panics; - type DistanceComputer = T::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(T::distance( - self.provider.metric, - Some(self.provider.base_vectors.dim()), - )) - } -} - -impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} - -impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} - -impl FillSet for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - async fn fill_set( - &mut self, - set: &mut HashMap, - itr: Itr, - ) -> Result<(), Self::GetError> - where - Itr: Iterator + Send + Sync, - { - for i in itr { - set.entry(i).or_insert_with(|| unsafe { - self.provider.base_vectors.get_vector_sync(i.into_usize()) - }); - } - Ok(()) - } -} - -//-------------------// -// In-mem Extensions // -//-------------------// - -impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Checker = D; - fn as_deletion_check(&self) -> &D { - &self.provider.deleted - } -} - -////////////////// -// Post Process // -////////////////// - -pub trait GetFullPrecision { - type Repr: VectorRepr; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; -} - -/// A [`SearchPostProcess`]or that: -/// -/// 1. Filters out deleted ids from being returned. -/// 2. Reranks a candidate stream using full-precision distances. -/// 3. Copies back the results to the output buffer. -#[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; - -impl glue::SearchPostProcess for Rerank -where - T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, -{ - type Error = Panics; - - fn post_process( - &self, - accessor: &mut A, - query: &[T], - _computer: &A::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator>, - B: SearchOutputBuffer + ?Sized, - { - let full = accessor.as_full_precision(); - let checker = accessor.as_deletion_check(); - let f = full.distance(); - - // Filter before computing the full precision distances. - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { - Some(( - n.id, - f.evaluate_similarity(query, unsafe { - full.get_vector_sync(n.id.into_usize()) - }), - )) - } - }) - .collect(); - - // Sort the full precision distances. - reranked - .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - // Store the reranked results. - std::future::ready(Ok(output.extend(reranked))) - } -} - -//////////////// -// Strategies // -//////////////// - -// A layered approach is used for search strategies. The `Internal` version does the heavy -// lifting in terms of establishing accessors and post processing. -// -// However, during post-processing, the `Internal` versions of strategies will not filter -// out the start points. The publicly exposed types *will* filter out the start points. -// -// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust -// the adjacency list for the start point to reuse the `Internal` strategies. - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> - for Internal -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = RemoveDeletedIdsAndCopy; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -// Pruning -impl PruneStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputer = T::Distance; - type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -/// Implementing this trait allows `FullPrecision` to be used for multi-insert. -impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Error = diskann::error::Infallible; - fn as_element( - &mut self, - vector: &'a [T], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(vector)) - } -} - -impl InsertStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type PruneStrategy = Self; - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } -} - -// Inplace Delete // -impl InplaceDeleteStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type DeleteElementError = Panics; - type DeleteElement<'a> = [T]; - type DeleteElementGuard = Box<[T]>; - type PruneStrategy = Self; - type SearchStrategy = Internal; - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) - } - - fn prune_strategy(&self) -> Self::PruneStrategy { - Self - } - - async fn get_delete_element<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - id: u32, - ) -> Result { - Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) - } -} +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{collections::HashMap, fmt::Debug, future::Future}; + +use diskann::{ + ANNError, ANNResult, + graph::{ + SearchOutputBuffer, + glue::{ + self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, + }, + }, + neighbor::Neighbor, + provider::{ + Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, + ExecutionContext, HasId, + }, + utils::{IntoUsize, VectorRepr}, +}; +use diskann_utils::future::AsyncFriendly; +use diskann_vector::{DistanceFunction, distance::Metric}; + +use crate::model::graph::{ + provider::async_::{ + FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, + common::{ + CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, + PrefetchCacheLineLevel, SetElementHelper, + }, + inmem::DefaultProvider, + postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, + }, + traits::AdHoc, +}; + +/// A type alias for the DefaultProvider with full-precision as the primary vector store. +pub type FullPrecisionProvider = + DefaultProvider, Q, D, Ctx>; + +/// The default full-precision vector store. +pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; + +/// A default full-precision vector store provider. +#[derive(Clone)] +pub struct CreateFullPrecision { + dim: usize, + prefetch_cache_line_level: Option, + _phantom: std::marker::PhantomData, +} + +impl CreateFullPrecision +where + T: VectorRepr, +{ + /// Create a new full-precision vector store provider. + pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { + Self { + dim, + prefetch_cache_line_level, + _phantom: std::marker::PhantomData, + } + } +} + +impl CreateVectorStore for CreateFullPrecision +where + T: VectorRepr, +{ + type Target = FullPrecisionStore; + fn create( + self, + max_points: usize, + metric: Metric, + prefetch_lookahead: Option, + ) -> Self::Target { + FullPrecisionStore::new( + max_points, + self.dim, + metric, + self.prefetch_cache_line_level, + prefetch_lookahead, + ) + } +} + +//////////////// +// SetElement // +//////////////// + +impl SetElementHelper for FullPrecisionStore +where + T: VectorRepr, +{ + /// Set the element at the given index. + fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { + unsafe { self.set_vector_sync(id.into_usize(), element) } + } +} + +////////////////// +// FullAccessor // +////////////////// + +/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. +/// +/// This type implements the following traits: +/// +/// * [`Accessor`] for the [`DefaultProvider`]. +/// * [`ComputerAccessor`] for comparing full-precision distances. +/// * [`BuildQueryComputer`]. +pub struct FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, +{ + /// The host provider. + provider: &'a FullPrecisionProvider, + + /// A buffer for resolving iterators given during bulk operations. + /// + /// The accessor reuses this allocation to amortize allocation cost over multiple bulk + /// operations. + id_buffer: Vec, +} + +impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Repr = T; + fn as_full_precision(&self) -> &FullPrecisionStore { + &self.provider.base_vectors + } +} + +impl HasId for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Id = u32; +} + +impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + fn starting_points(&self) -> impl Future>> { + std::future::ready(self.provider.starting_points()) + } +} + +impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + pub fn new(provider: &'a FullPrecisionProvider) -> Self { + Self { + provider, + id_buffer: Vec::new(), + } + } +} + +impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Delegate = &'a SimpleNeighborProviderAsync; + + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.provider.neighbors() + } +} + +impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + /// The extended element inherets the lifetime of the Accessor. + type Extended = &'a [T]; + + /// This accessor returns raw slices. There *is* a chance of racing when the fast + /// providers are used. We just have to live with it. + /// + /// NOTE: We intentionally don't use `'b` here since our implementation borrows + /// the inner `Opaque` from the underlying provider. + type Element<'b> + = &'a [T] + where + Self: 'b; + + /// `ElementRef` has an arbitrarily short lifetime. + type ElementRef<'b> = &'b [T]; + + /// Choose to panic on an out-of-bounds access rather than propagate an error. + type GetError = Panics; + + /// Return the full-precision vector stored at index `i`. + /// + /// This function always completes synchronously. + #[inline(always)] + fn get_element( + &mut self, + id: Self::Id, + ) -> impl Future, Self::GetError>> + Send { + // SAFETY: We've decided to live with UB (undefined behavior) that can result from + // potentially mixing unsynchronized reads and writes on the underlying memory. + std::future::ready(Ok(unsafe { + self.provider.base_vectors.get_vector_sync(id.into_usize()) + })) + } + + /// Perform a bulk operation. + /// + /// This implementation uses prefetching. + fn on_elements_unordered( + &mut self, + itr: Itr, + mut f: F, + ) -> impl Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + // Reuse the internal buffer to collect the results and give us random access + // capabilities. + let id_buffer = &mut self.id_buffer; + id_buffer.clear(); + id_buffer.extend(itr); + + let len = id_buffer.len(); + let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.base_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .base_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + // Invoke the passed closure on the full-precision vector. + // + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + f( + unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, + *id, + ) + } + + std::future::ready(Ok(())) + } +} + +impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputerError = Panics; + type DistanceComputer = T::Distance; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(T::distance( + self.provider.metric, + Some(self.provider.base_vectors.dim()), + )) + } +} + +impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; + + fn build_query_computer( + &self, + from: &[T], + ) -> Result { + Ok(T::query_distance(from, self.provider.metric)) + } +} + +impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + +/// Support for Vec queries that delegates to the [T] impl via deref. +/// This allows InlineBetaStrategy to use Vec queries with FullAccessor. +impl BuildQueryComputer> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; + + fn build_query_computer( + &self, + from: &Vec, + ) -> Result { + // Delegate to [T] impl via deref + Ok(T::query_distance(from.as_slice(), self.provider.metric)) + } +} + +/// Support for Vec queries that delegates to the [T] impl. +impl ExpandBeam> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr + Clone, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + +impl FillSet for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + async fn fill_set( + &mut self, + set: &mut HashMap, + itr: Itr, + ) -> Result<(), Self::GetError> + where + Itr: Iterator + Send + Sync, + { + for i in itr { + set.entry(i).or_insert_with(|| unsafe { + self.provider.base_vectors.get_vector_sync(i.into_usize()) + }); + } + Ok(()) + } +} + +//-------------------// +// In-mem Extensions // +//-------------------// + +impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type Checker = D; + fn as_deletion_check(&self) -> &D { + &self.provider.deleted + } +} + +////////////////// +// Post Process // +////////////////// + +pub trait GetFullPrecision { + type Repr: VectorRepr; + fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; +} + +/// A [`SearchPostProcess`]or that: +/// +/// 1. Filters out deleted ids from being returned. +/// 2. Reranks a candidate stream using full-precision distances. +/// 3. Copies back the results to the output buffer. +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; + +impl glue::SearchPostProcess for Rerank +where + T: VectorRepr, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, +{ + type Error = Panics; + + fn post_process( + &self, + accessor: &mut A, + query: &[T], + _computer: &A::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator>, + B: SearchOutputBuffer + ?Sized, + { + let full = accessor.as_full_precision(); + let checker = accessor.as_deletion_check(); + let f = full.distance(); + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + None + } else { + Some(( + n.id, + f.evaluate_similarity(query, unsafe { + full.get_vector_sync(n.id.into_usize()) + }), + )) + } + }) + .collect(); + + // Sort the full precision distances. + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + // Store the reranked results. + std::future::ready(Ok(output.extend(reranked))) + } +} + +//////////////// +// Strategies // +//////////////// + +// A layered approach is used for search strategies. The `Internal` version does the heavy +// lifting in terms of establishing accessors and post processing. +// +// However, during post-processing, the `Internal` versions of strategies will not filter +// out the start points. The publicly exposed types *will* filter out the start points. +// +// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust +// the adjacency list for the start point to reuse the `Internal` strategies. + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> + for Internal +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = RemoveDeletedIdsAndCopy; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = glue::Pipeline; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +/// Support for Vec queries that delegates to the [T] impl. +/// This allows InlineBetaStrategy to use Vec queries with FullPrecision. +impl SearchStrategy, Vec> for FullPrecision +where + T: VectorRepr + Clone, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = glue::Pipeline; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +// Pruning +impl PruneStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputer = T::Distance; + type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type PruneAccessorError = diskann::error::Infallible; + + fn prune_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::PruneAccessorError> { + Ok(FullAccessor::new(provider)) + } +} + +/// Implementing this trait allows `FullPrecision` to be used for multi-insert. +impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Error = diskann::error::Infallible; + fn as_element( + &mut self, + vector: &'a [T], + _id: Self::Id, + ) -> impl Future, Self::Error>> + Send { + std::future::ready(Ok(vector)) + } +} + +impl InsertStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type PruneStrategy = Self; + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } +} + +// Inplace Delete // +impl InplaceDeleteStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type DeleteElementError = Panics; + type DeleteElement<'a> = [T]; + type DeleteElementGuard = Box<[T]>; + type PruneStrategy = Self; + type SearchStrategy = Internal; + fn search_strategy(&self) -> Self::SearchStrategy { + Internal(Self) + } + + fn prune_strategy(&self) -> Self::PruneStrategy { + Self + } + + async fn get_delete_element<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + id: u32, + ) -> Result { + Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) + } +} diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index 31e69b2b2..8c2fa29f6 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -32,14 +32,14 @@ use crate::utils::{search_index_utils, CMDResult, CMDToolError}; /// Expands a JSON object with array-valued fields into multiple objects with scalar values. /// For example: {"country": ["AU", "NZ"], "year": 2007} /// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] -/// +/// /// If multiple fields have arrays, all combinations are generated. fn expand_array_fields(value: &Value) -> Vec { match value { Value::Object(map) => { // Start with a single empty object let mut results: Vec> = vec![Map::new()]; - + for (key, val) in map.iter() { if let Value::Array(arr) = val { // Expand: for each existing result, create copies for each array element @@ -62,7 +62,7 @@ fn expand_array_fields(value: &Value) -> Vec { } } } - + results.into_iter().map(Value::Object).collect() } // If not an object, return as-is @@ -74,7 +74,9 @@ fn expand_array_fields(value: &Value) -> Vec { /// Returns true if any expanded variant matches the query. fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { let expanded = expand_array_fields(label); - expanded.iter().any(|item| eval_query_expr(query_expr, item)) + expanded + .iter() + .any(|item| eval_query_expr(query_expr, item)) } pub fn read_labels_and_compute_bitmap( @@ -127,11 +129,13 @@ pub fn read_labels_and_compute_bitmap( // Handle case where base_label.label is an array - check if any element matches // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]}) let matches = if let Some(array) = base_label.label.as_array() { - array.iter().any(|item| eval_query_with_array_expansion(query_expr, item)) + array + .iter() + .any(|item| eval_query_with_array_expansion(query_expr, item)) } else { eval_query_with_array_expansion(query_expr, &base_label.label) }; - + if matches { bitmap.insert(base_label.doc_id); } @@ -164,11 +168,17 @@ pub fn read_labels_and_compute_bitmap( // If no matches, print more diagnostic info if total_matches == 0 { tracing::warn!("WARNING: No base vectors matched any query filters!"); - tracing::warn!("This could indicate a format mismatch between base labels and query filters."); - + tracing::warn!( + "This could indicate a format mismatch between base labels and query filters." + ); + // Try to identify what keys exist in base labels vs queries if let Some(first_label) = base_labels.first() { - tracing::warn!("First base label (full): doc_id={}, label={}", first_label.doc_id, first_label.label); + tracing::warn!( + "First base label (full): doc_id={}, label={}", + first_label.doc_id, + first_label.label + ); } } @@ -323,7 +333,7 @@ pub fn compute_ground_truth_from_datafiles< for (query_idx, npq) in ground_truth.iter().enumerate() { let neighbors: Vec<_> = npq.iter().collect(); let neighbor_count = neighbors.len(); - + if query_idx < 10 { // Print top K IDs and distances for first 10 queries let top_ids: Vec = neighbors.iter().take(10).map(|n| n.id).collect(); @@ -336,7 +346,7 @@ pub fn compute_ground_truth_from_datafiles< top_dists ); } - + if neighbor_count == 0 { tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); } @@ -344,7 +354,10 @@ pub fn compute_ground_truth_from_datafiles< // Summary stats let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); - let queries_with_neighbors = ground_truth.iter().filter(|npq| npq.iter().count() > 0).count(); + let queries_with_neighbors = ground_truth + .iter() + .filter(|npq| npq.iter().count() > 0) + .count(); tracing::info!( "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", total_neighbors, diff --git a/test_data/disk_index_search/data.256.label.jsonl b/test_data/disk_index_search/data.256.label.jsonl index 83254af7b..a99cde8e2 100644 --- a/test_data/disk_index_search/data.256.label.jsonl +++ b/test_data/disk_index_search/data.256.label.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7f8b6b99ca32173557689712d3fb5da30c5e4111130fd2accbccf32f5ce3e47e -size 17702 +oid sha256:92576896b10780a2cd80a16030f8384610498b76453f57fadeacb854379e0acf +size 17701 From 98ad4f7d28b3a77e613105ecab3fc3b6d162c872 Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 17 Feb 2026 11:57:11 +0530 Subject: [PATCH 04/39] Removing unnecessary stats --- .../src/backend/document_index/benchmark.rs | 118 ++---------------- diskann-benchmark/src/backend/index/result.rs | 13 -- 2 files changed, 12 insertions(+), 119 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index dffe669ff..ed2915974 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -586,57 +586,6 @@ where let recall_metrics: recall::RecallMetrics = (&recall::knn(groundtruth, None, &merged.ids, recall_k, search_n, false)?).into(); - // Compute per-query details (only for queries with recall < 1) - let per_query_details: Vec = (0..num_queries) - .filter_map(|query_idx| { - let result_ids: Vec = merged - .ids - .row(query_idx) - .iter() - .copied() - .filter(|&id| id != u32::MAX) - .collect(); - let result_distances: Vec = merged - .distances - .get(query_idx) - .map(|d| d.iter().copied().filter(|&dist| dist != f32::MAX).collect()) - .unwrap_or_default(); - // Only keep top 20 from ground truth - let gt_ids: Vec = groundtruth - .get(query_idx) - .map(|gt| gt.iter().take(20).copied().collect()) - .unwrap_or_default(); - - // Compute per-query recall: intersection of result_ids with gt_ids / recall_k - let result_set: std::collections::HashSet = result_ids.iter().copied().collect(); - let gt_set: std::collections::HashSet = - gt_ids.iter().take(recall_k).copied().collect(); - let intersection = result_set.intersection(>_set).count(); - let per_query_recall = if gt_set.is_empty() { - 1.0 - } else { - intersection as f64 / gt_set.len() as f64 - }; - - // Only include queries with imperfect recall - if per_query_recall >= 1.0 { - return None; - } - - let (_, ref ast_expr) = predicates[query_idx]; - let filter_str = format!("{:?}", ast_expr); - - Some(PerQueryDetails { - query_id: query_idx, - filter: filter_str, - recall: per_query_recall, - result_ids, - result_distances, - groundtruth_ids: gt_ids, - }) - }) - .collect(); - // Compute QPS from rep latencies let qps: Vec = rep_latencies .iter() @@ -684,18 +633,15 @@ where Ok(SearchRunStats { num_threads: num_threads.get(), - num_queries, search_n, search_l, recall: recall_metrics, qps, - wall_clock_time: rep_latencies, mean_latency: mean, p90_latency: p90, p99_latency: p99, mean_cmps, mean_hops, - per_query_details: Some(per_query_details), }) } async fn run_search_parallel( @@ -830,60 +776,19 @@ pub struct BuildParamsStats { pub alpha: f32, } -/// Helper module for serializing arrays as compact single-line JSON strings -mod compact_array { - use serde::Serializer; - - pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result - where - S: Serializer, - { - // Serialize as a string containing the compact JSON array - let compact = serde_json::to_string(vec).unwrap_or_default(); - serializer.serialize_str(&compact) - } - - pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result - where - S: Serializer, - { - // Serialize as a string containing the compact JSON array - let compact = serde_json::to_string(vec).unwrap_or_default(); - serializer.serialize_str(&compact) - } -} - -/// Per-query detailed results for debugging/analysis -#[derive(Debug, Serialize)] -pub struct PerQueryDetails { - pub query_id: usize, - pub filter: String, - pub recall: f64, - #[serde(serialize_with = "compact_array::serialize_u32_vec")] - pub result_ids: Vec, - #[serde(serialize_with = "compact_array::serialize_f32_vec")] - pub result_distances: Vec, - #[serde(serialize_with = "compact_array::serialize_u32_vec")] - pub groundtruth_ids: Vec, -} - /// Results from a single search configuration (one search_l value). #[derive(Debug, Serialize)] pub struct SearchRunStats { pub num_threads: usize, - pub num_queries: usize, pub search_n: usize, pub search_l: usize, pub recall: recall::RecallMetrics, pub qps: Vec, - pub wall_clock_time: Vec, pub mean_latency: f64, pub p90_latency: MicroSeconds, pub p99_latency: MicroSeconds, pub mean_cmps: f32, pub mean_hops: f32, - #[serde(skip_serializing_if = "Option::is_none")] - pub per_query_details: Option>, } #[derive(Debug, Serialize)] @@ -929,8 +834,16 @@ impl std::fmt::Display for DocumentIndexStats { writeln!(f, "\nFiltered Search Results:")?; writeln!( f, - " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", - "L", "KNN", "Avg Cmps", "Avg Hops", "QPS -mean(max)", "Avg Latency", "p99 Latency", "Recall", "Threads", "Queries", "WallClock(s)" + " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8}", + "L", + "KNN", + "Avg Cmps", + "Avg Hops", + "QPS -mean(max)", + "Avg Latency", + "p99 Latency", + "Recall", + "Threads" )?; for s in &self.search { let mean_qps = if s.qps.is_empty() { @@ -939,14 +852,9 @@ impl std::fmt::Display for DocumentIndexStats { s.qps.iter().sum::() / s.qps.len() as f64 }; let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max); - let mean_wall_clock = if s.wall_clock_time.is_empty() { - 0.0 - } else { - s.wall_clock_time.iter().map(|t| t.as_seconds()).sum::() / s.wall_clock_time.len() as f64 - }; writeln!( f, - " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", + " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8}", s.search_l, s.search_n, s.mean_cmps, @@ -956,9 +864,7 @@ impl std::fmt::Display for DocumentIndexStats { s.mean_latency, s.p99_latency, s.recall.average, - s.num_threads, - s.num_queries, - mean_wall_clock + s.num_threads )?; } } diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 21d74f915..c7e2ab75c 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -109,7 +109,6 @@ impl std::fmt::Display for AggregatedSearchResults { #[derive(Debug, Serialize)] pub(super) struct SearchResults { pub(super) num_tasks: usize, - pub(super) num_queries: usize, pub(super) search_n: usize, pub(super) search_l: usize, pub(super) qps: Vec, @@ -144,7 +143,6 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), - num_queries: recall.num_queries, search_n: parameters.k_value, search_l: parameters.l_value, qps, @@ -184,8 +182,6 @@ where "p99 Latency", "Recall", "Threads", - "Queries", - "WallClock(s)", ] } else { &[ @@ -198,8 +194,6 @@ where "p99 Latency", "Recall", "Threads", - "Queries", - "WallClock(s)", ] }; @@ -243,13 +237,6 @@ where ); row.insert(format!("{:3}", r.recall.average), col_idx + 7); row.insert(r.num_tasks, col_idx + 8); - row.insert(r.num_queries, col_idx + 9); - let mean_wall_clock = if r.search_latencies.is_empty() { - 0.0 - } else { - r.search_latencies.iter().map(|t| t.as_seconds()).sum::() / r.search_latencies.len() as f64 - }; - row.insert(format!("{:.3}", mean_wall_clock), col_idx + 10); }); write!(f, "{}", table) From edfdee6e7dd3d63f03dc8484390d1f5ebd604c4c Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Mon, 2 Mar 2026 15:54:44 +0530 Subject: [PATCH 05/39] Fix clippy warnings --- .../roaring_attribute_store.rs | 9 +++++++++ .../src/inline_beta_search/inline_beta_filter.rs | 15 +++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs index c69589ba0..cab250c76 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs @@ -24,6 +24,15 @@ where inv_index: Arc>>, } +impl Default for RoaringAttributeStore +where + IT: VectorId, +{ + fn default() -> Self { + Self::new() + } +} + impl RoaringAttributeStore where IT: VectorId, diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index f03f36c12..76a78de67 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -125,16 +125,16 @@ where { Ok(matched) => { if matched { - return sim * self.beta_value; + sim * self.beta_value } else { - return sim; + sim } } Err(_) => { //If predicate evaluation fails for any reason, we simply revert //to unfiltered search. tracing::warn!("Predicate evaluation failed"); - return sim; + sim } } } else { @@ -182,16 +182,15 @@ where let doc = accessor.get_element(candidate.id).await?; let pe = PredicateEvaluator::new(doc.attributes()); - if computer.is_valid_filter() { - if computer + if computer.is_valid_filter() + && computer .filter_expr() .encoded_filter_expr() .as_ref() .unwrap() .accept(&pe)? - { - filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); - } + { + filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); } } From cbb77f64e54a38038f64fd6fba7b9643e626b9e1 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:44:47 +0530 Subject: [PATCH 06/39] use search and build benchmark apis --- .../src/backend/document_index/benchmark.rs | 630 ++++++++---------- 1 file changed, 288 insertions(+), 342 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index dffe669ff..51b3467bf 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -20,7 +20,12 @@ use diskann::{ search_output_buffer, DiskANNIndex, SearchParams, StartPointStrategy, }, provider::DefaultContext, - utils::{async_tools, IntoUsize}, +}; +use diskann_benchmark_core::{ + build::{self, AsProgress, Build, Parallelism, Progress}, + recall, + search as search_api, + tokio, }; use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, @@ -485,228 +490,237 @@ impl<'a> DocumentIndexJob<'a> { Ok(stats) } } -/// Local results from a partition of queries. -struct SearchLocalResults { - ids: Matrix, - distances: Vec>, - latencies: Vec, - comparisons: Vec, - hops: Vec, +/// Per-query output from [`FilteredSearcher::search`]. +struct FilteredSearchOutput { + distances: Vec, + comparisons: u32, + hops: u32, } -impl SearchLocalResults { - fn merge(all: &[SearchLocalResults]) -> anyhow::Result { - let first = all - .first() - .ok_or_else(|| anyhow::anyhow!("empty results"))?; - let num_ids = first.ids.ncols(); - let total_rows: usize = all.iter().map(|r| r.ids.nrows()).sum(); - - let mut ids = Matrix::new(0, total_rows, num_ids); - let mut output_row = 0; - for r in all { - for input_row in r.ids.row_iter() { - ids.row_mut(output_row).copy_from_slice(input_row); - output_row += 1; - } - } - - let mut distances = Vec::new(); - let mut latencies = Vec::new(); - let mut comparisons = Vec::new(); - let mut hops = Vec::new(); - for r in all { - distances.extend_from_slice(&r.distances); - latencies.extend_from_slice(&r.latencies); - comparisons.extend_from_slice(&r.comparisons); - hops.extend_from_slice(&r.hops); - } - - Ok(Self { - ids, - distances, - latencies, - comparisons, - hops, - }) - } +/// Implements [`search_api::Search`] for parallelized inline-beta filtered search. +/// +/// Each query is paired with a predicate at the same index in `predicates`. The +/// [`InlineBetaStrategy`] is used with a [`FilteredQuery`] containing the raw vector +/// and the predicate's [`ASTExpr`]. +struct FilteredSearcher +where + DP: diskann::provider::DataProvider, +{ + index: Arc>, + queries: Arc>, + predicates: Arc>, + beta: f32, } -/// Run filtered search with the given parameters. -#[allow(clippy::too_many_arguments)] -fn run_filtered_search( - index: &Arc>, - queries: &Matrix, - predicates: &[(usize, ASTExpr)], - groundtruth: &Vec>, - beta: f32, - num_threads: NonZeroUsize, - search_n: usize, - search_l: usize, - recall_k: usize, - reps: NonZeroUsize, -) -> anyhow::Result +impl search_api::Search for FilteredSearcher where - T: bytemuck::Pod + Copy + Send + Sync + 'static, - DP: diskann::provider::DataProvider< - Context = DefaultContext, - ExternalId = u32, - InternalId = u32, - > + Send + DP: diskann::provider::DataProvider + + Send + Sync + 'static, - InlineBetaStrategy: - diskann::graph::glue::SearchStrategy>>, + InlineBetaStrategy: diskann::graph::glue::SearchStrategy>, u32>, + T: bytemuck::Pod + Copy + Send + Sync + 'static, { - let rt = utils::tokio::runtime(num_threads.get())?; - let num_queries = queries.nrows(); + type Id = DP::ExternalId; + type Parameters = SearchParams; + type Output = FilteredSearchOutput; - let mut all_rep_results = Vec::with_capacity(reps.get()); - let mut rep_latencies = Vec::with_capacity(reps.get()); + fn num_queries(&self) -> usize { + self.queries.nrows() + } - for _ in 0..reps.get() { - let start = std::time::Instant::now(); - let results = rt.block_on(run_search_parallel( - index.clone(), - queries, - predicates, - beta, - num_threads, - search_n, - search_l, - ))?; - rep_latencies.push(MicroSeconds::from(start.elapsed())); - all_rep_results.push(results); + fn id_count(&self, parameters: &SearchParams) -> search_api::IdCount { + search_api::IdCount::Fixed( + NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE), + ) } - // Merge results from first rep for recall calculation - let merged = SearchLocalResults::merge(&all_rep_results[0])?; - - // Compute recall - let recall_metrics: recall::RecallMetrics = - (&recall::knn(groundtruth, None, &merged.ids, recall_k, search_n, false)?).into(); - - // Compute per-query details (only for queries with recall < 1) - let per_query_details: Vec = (0..num_queries) - .filter_map(|query_idx| { - let result_ids: Vec = merged - .ids - .row(query_idx) - .iter() - .copied() - .filter(|&id| id != u32::MAX) - .collect(); - let result_distances: Vec = merged - .distances - .get(query_idx) - .map(|d| d.iter().copied().filter(|&dist| dist != f32::MAX).collect()) - .unwrap_or_default(); - // Only keep top 20 from ground truth - let gt_ids: Vec = groundtruth - .get(query_idx) - .map(|gt| gt.iter().take(20).copied().collect()) - .unwrap_or_default(); - - // Compute per-query recall: intersection of result_ids with gt_ids / recall_k - let result_set: std::collections::HashSet = result_ids.iter().copied().collect(); - let gt_set: std::collections::HashSet = - gt_ids.iter().take(recall_k).copied().collect(); - let intersection = result_set.intersection(>_set).count(); - let per_query_recall = if gt_set.is_empty() { - 1.0 - } else { - intersection as f64 / gt_set.len() as f64 - }; + async fn search( + &self, + parameters: &SearchParams, + buffer: &mut O, + index: usize, + ) -> diskann::ANNResult + where + O: diskann::graph::SearchOutputBuffer + Send, + { + let ctx = DefaultContext; + let query_vec = self.queries.row(index); + let (_, ref ast_expr) = self.predicates[index]; + let strategy = InlineBetaStrategy::new(self.beta, common::FullPrecision); + let filtered_query = FilteredQuery::new(query_vec, ast_expr.clone()); + + // Use a concrete IdDistance scratch buffer so that both the IDs and distances + // are captured. Afterwards, the valid IDs are forwarded into the framework buffer. + let k = parameters.k_value; + let mut ids = vec![0u32; k]; + let mut distances = vec![0.0f32; k]; + let mut scratch = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + + let stats = self + .index + .search(&strategy, &ctx, &filtered_query, parameters, &mut scratch) + .await?; - // Only include queries with imperfect recall - if per_query_recall >= 1.0 { - return None; + let count = scratch.current_len(); + for (&id, &dist) in std::iter::zip(&ids[..count], &distances[..count]) { + if buffer.push(id, dist).is_full() { + break; } + } + + Ok(FilteredSearchOutput { + distances: distances[..count].to_vec(), + comparisons: stats.cmps, + hops: stats.hops, + }) + } +} + +/// Aggregates per-rep [`search_api::SearchResults`] into a [`SearchRunStats`]. +struct FilteredSearchAggregator<'a> { + groundtruth: &'a Vec>, + predicates: &'a [(usize, ASTExpr)], + recall_k: usize, +} + +impl search_api::Aggregate + for FilteredSearchAggregator<'_> +{ + type Output = SearchRunStats; + + fn aggregate( + &mut self, + run: search_api::Run, + results: Vec>, + ) -> anyhow::Result { + let parameters = run.parameters(); + let search_n = parameters.k_value; + let num_queries = results.first().map(|r| r.len()).unwrap_or(0); + + // Recall from first rep only. + let recall_metrics: SerializableRecallMetrics = match results.first() { + Some(first) => (&recall::knn( + self.groundtruth, + None, + first.ids().as_rows(), + self.recall_k, + search_n, + true, + )?) + .into(), + None => anyhow::bail!("no search results"), + }; - let (_, ref ast_expr) = predicates[query_idx]; - let filter_str = format!("{:?}", ast_expr); + // Per-query details from first rep (only queries with recall < 1). + let first = results.first().unwrap(); + let per_query_details: Vec = (0..num_queries) + .filter_map(|query_idx| { + let result_ids: Vec = first.ids().as_rows().row(query_idx).to_vec(); + let result_distances: Vec = first + .output() + .get(query_idx) + .map(|o| o.distances.clone()) + .unwrap_or_default(); + let gt_ids: Vec = self + .groundtruth + .get(query_idx) + .map(|gt| gt.iter().take(20).copied().collect()) + .unwrap_or_default(); + + let result_set: std::collections::HashSet = + result_ids.iter().copied().collect(); + let gt_set: std::collections::HashSet = + gt_ids.iter().take(self.recall_k).copied().collect(); + let intersection = result_set.intersection(>_set).count(); + let per_query_recall = if gt_set.is_empty() { + 1.0 + } else { + intersection as f64 / gt_set.len() as f64 + }; - Some(PerQueryDetails { - query_id: query_idx, - filter: filter_str, - recall: per_query_recall, - result_ids, - result_distances, - groundtruth_ids: gt_ids, + if per_query_recall >= 1.0 { + return None; + } + + let (_, ref ast_expr) = self.predicates[query_idx]; + Some(PerQueryDetails { + query_id: query_idx, + filter: format!("{:?}", ast_expr), + recall: per_query_recall, + result_ids, + result_distances, + groundtruth_ids: gt_ids, + }) }) - }) - .collect(); + .collect(); - // Compute QPS from rep latencies - let qps: Vec = rep_latencies - .iter() - .map(|l| num_queries as f64 / l.as_seconds()) - .collect(); + // Wall-clock latency and QPS per rep. + let rep_latencies: Vec = + results.iter().map(|r| r.end_to_end_latency()).collect(); + let qps: Vec = rep_latencies + .iter() + .map(|l| num_queries as f64 / l.as_seconds()) + .collect(); - // Aggregate per-query latencies across all reps - let (all_latencies, all_cmps, all_hops): (Vec<_>, Vec<_>, Vec<_>) = all_rep_results - .iter() - .map(|results| { - let mut lat = Vec::new(); - let mut cmp = Vec::new(); - let mut hop = Vec::new(); - for r in results { - lat.extend_from_slice(&r.latencies); - cmp.extend_from_slice(&r.comparisons); - hop.extend_from_slice(&r.hops); + // Per-query latencies, comparisons, and hops aggregated across all reps. + let mut all_latencies: Vec = Vec::new(); + let mut all_cmps: Vec = Vec::new(); + let mut all_hops: Vec = Vec::new(); + for r in &results { + all_latencies.extend_from_slice(r.latencies()); + for o in r.output() { + all_cmps.push(o.comparisons); + all_hops.push(o.hops); } - (lat, cmp, hop) - }) - .fold( - (Vec::new(), Vec::new(), Vec::new()), - |(mut a, mut b, mut c): (Vec, Vec, Vec), (x, y, z)| { - a.extend(x); - b.extend(y); - c.extend(z); - (a, b, c) - }, - ); + } - let mut query_latencies = all_latencies; - let percentiles::Percentiles { mean, p90, p99, .. } = - percentiles::compute_percentiles(&mut query_latencies)?; + let percentiles::Percentiles { mean, p90, p99, .. } = + percentiles::compute_percentiles(&mut all_latencies)?; - let mean_cmps = if all_cmps.is_empty() { - 0.0 - } else { - all_cmps.iter().map(|&x| x as f32).sum::() / all_cmps.len() as f32 - }; - let mean_hops = if all_hops.is_empty() { - 0.0 - } else { - all_hops.iter().map(|&x| x as f32).sum::() / all_hops.len() as f32 - }; + let mean_cmps = if all_cmps.is_empty() { + 0.0 + } else { + all_cmps.iter().map(|&x| x as f32).sum::() / all_cmps.len() as f32 + }; + let mean_hops = if all_hops.is_empty() { + 0.0 + } else { + all_hops.iter().map(|&x| x as f32).sum::() / all_hops.len() as f32 + }; - Ok(SearchRunStats { - num_threads: num_threads.get(), - num_queries, - search_n, - search_l, - recall: recall_metrics, - qps, - wall_clock_time: rep_latencies, - mean_latency: mean, - p90_latency: p90, - p99_latency: p99, - mean_cmps, - mean_hops, - per_query_details: Some(per_query_details), - }) + Ok(SearchRunStats { + num_threads: run.setup().threads.get(), + num_queries, + search_n, + search_l: parameters.l_value, + recall: recall_metrics, + qps, + wall_clock_time: rep_latencies, + mean_latency: mean, + p90_latency: p90, + p99_latency: p99, + mean_cmps, + mean_hops, + per_query_details: Some(per_query_details), + }) + } } -async fn run_search_parallel( - index: Arc>, + +/// Run filtered search with the given parameters. +#[allow(clippy::too_many_arguments)] +fn run_filtered_search( + index: &Arc>, queries: &Matrix, predicates: &[(usize, ASTExpr)], + groundtruth: &Vec>, beta: f32, - num_tasks: NonZeroUsize, + num_threads: NonZeroUsize, search_n: usize, search_l: usize, -) -> anyhow::Result> + recall_k: usize, + reps: NonZeroUsize, +) -> anyhow::Result where T: bytemuck::Pod + Copy + Send + Sync + 'static, DP: diskann::provider::DataProvider< @@ -719,109 +733,31 @@ where InlineBetaStrategy: diskann::graph::glue::SearchStrategy>>, { - let num_queries = queries.nrows(); - - // Plan query partitions - let partitions: Result, _> = (0..num_tasks.get()) - .map(|task_id| async_tools::partition(num_queries, num_tasks, task_id)) - .collect(); - let partitions = partitions?; - - // We need to clone data for each task - let queries_arc = Arc::new(queries.clone()); - let predicates_arc = Arc::new(predicates.to_vec()); - - let handles: Vec<_> = partitions - .into_iter() - .map(|range| { - let index = index.clone(); - let queries = queries_arc.clone(); - let predicates = predicates_arc.clone(); - tokio::spawn(async move { - run_search_local(index, queries, predicates, beta, range, search_n, search_l).await - }) - }) - .collect(); - - let mut results = Vec::new(); - for h in handles { - results.push(h.await??); - } - - Ok(results) -} - -async fn run_search_local( - index: Arc>, - queries: Arc>, - predicates: Arc>, - beta: f32, - range: std::ops::Range, - search_n: usize, - search_l: usize, -) -> anyhow::Result -where - T: bytemuck::Pod + Copy + Send + Sync + 'static, - DP: diskann::provider::DataProvider< - Context = DefaultContext, - ExternalId = u32, - InternalId = u32, - > + Send - + Sync, - InlineBetaStrategy: - diskann::graph::glue::SearchStrategy>>, -{ - let mut ids = Matrix::new(0, range.len(), search_n); - let mut all_distances: Vec> = Vec::with_capacity(range.len()); - let mut latencies = Vec::with_capacity(range.len()); - let mut comparisons = Vec::with_capacity(range.len()); - let mut hops = Vec::with_capacity(range.len()); - - let ctx = DefaultContext; - let search_params = SearchParams::new_default(search_n, search_l)?; - - for (output_idx, query_idx) in range.enumerate() { - let query_vec = queries.row(query_idx); - let (_, ref ast_expr) = predicates[query_idx]; - - let strategy = InlineBetaStrategy::new(beta, common::FullPrecision); - let query_vec_owned = query_vec.to_vec(); - let filtered_query: FilteredQuery> = - FilteredQuery::new(query_vec_owned, ast_expr.clone()); - - let start = std::time::Instant::now(); - - let mut distances = vec![0.0f32; search_n]; - let result_ids = ids.row_mut(output_idx); - let mut result_buffer = search_output_buffer::IdDistance::new(result_ids, &mut distances); - - let stats = index - .search( - &strategy, - &ctx, - &filtered_query, - &search_params, - &mut result_buffer, - ) - .await?; - - let result_count = stats.result_count.into_usize(); - result_ids[result_count..].fill(u32::MAX); - distances[result_count..].fill(f32::MAX); + let searcher = Arc::new(FilteredSearcher { + index: index.clone(), + queries: Arc::new(queries.clone()), + predicates: Arc::new(predicates.to_vec()), + beta, + }); + + let parameters = SearchParams::new_default(search_n, search_l)?; + let setup = search_api::Setup { + threads: num_threads, + tasks: num_threads, + reps, + }; - latencies.push(MicroSeconds::from(start.elapsed())); - comparisons.push(stats.cmps); - hops.push(stats.hops); - all_distances.push(distances); - } + let mut results = search_api::search_all( + searcher, + [search_api::Run::new(parameters, setup)], + FilteredSearchAggregator { + groundtruth, + predicates, + recall_k, + }, + )?; - Ok(SearchLocalResults { - ids, - distances: all_distances, - latencies, - comparisons, - hops, - }) + results.pop().ok_or_else(|| anyhow::anyhow!("no search results")) } #[derive(Debug, Serialize)] pub struct BuildParamsStats { @@ -970,69 +906,79 @@ impl std::fmt::Display for DocumentIndexStats { // Parallel Build Support // ================================ -fn make_progress_bar( - nrows: usize, - draw_target: indicatif::ProgressDrawTarget, -) -> anyhow::Result { - let progress = ProgressBar::with_draw_target(Some(nrows as u64), draw_target); - progress.set_style(ProgressStyle::with_template( - "Building [{elapsed_precise}] {wide_bar} {percent}", - )?); - Ok(progress) -} - -/// Control block for parallel document insertion. -/// Manages work distribution and progress tracking across multiple tasks. -struct DocumentControlBlock { +/// Implements [`Build`] for parallel document insertion into a [`DiskANNIndex`] +/// backed by a [`DocumentProvider`]. Each call to [`Build::build`] inserts a +/// contiguous range of vectors and their associated attributes. +struct DocumentIndexBuilder { + index: Arc>, data: Arc>, attributes: Arc>>, - position: AtomicUsize, - cancel: AtomicBool, - progress: ProgressBar, + strategy: DocumentInsertStrategy, } -impl DocumentControlBlock { +impl DocumentIndexBuilder { fn new( + index: Arc>, data: Arc>, attributes: Arc>>, - draw_target: indicatif::ProgressDrawTarget, - ) -> anyhow::Result> { - let nrows = data.nrows(); - Ok(Arc::new(Self { + strategy: DocumentInsertStrategy, + ) -> Arc { + Arc::new(Self { + index, data, attributes, - position: AtomicUsize::new(0), - cancel: AtomicBool::new(false), - progress: make_progress_bar(nrows, draw_target)?, - })) + strategy, + }) } +} - /// Return the next document data to insert: (id, vector_slice, attributes). - fn next(&self) -> Option<(usize, &[T], Vec)> { - let cancel = self.cancel.load(Ordering::Relaxed); - if cancel { - None - } else { - let i = self.position.fetch_add(1, Ordering::Relaxed); - match self.data.get_row(i) { - Some(row) => { - let attrs = self.attributes.get(i).cloned().unwrap_or_default(); - self.progress.inc(1); - Some((i, row, attrs)) - } - None => None, - } +impl Build for DocumentIndexBuilder +where + DP: diskann::provider::DataProvider + + for<'doc> diskann::provider::SetElement> + + AsyncFriendly, + for<'doc> DocumentInsertStrategy: + diskann::graph::glue::InsertStrategy>, + DocumentInsertStrategy: AsyncFriendly, + T: AsyncFriendly, +{ + type Output = (); + + fn num_data(&self) -> usize { + self.data.nrows() + } + + async fn build(&self, range: std::ops::Range) -> diskann::ANNResult { + let ctx = DefaultContext; + for i in range { + let attrs = self.attributes.get(i).cloned().unwrap_or_default(); + let doc = Document::new(self.data.row(i), attrs); + self.index + .insert(self.strategy, &ctx, &(i as u32), &doc) + .await?; } + Ok(()) + } +} + +/// Adapts an already-constructed [`ProgressBar`] into the [`AsProgress`] / [`Progress`] +/// traits expected by [`build_tracked`]. +struct IndicatifAsProgress(ProgressBar); + +struct IndicatifProgress(ProgressBar); + +impl Progress for IndicatifProgress { + fn progress(&self, handled: usize) { + self.0.inc(handled as u64); } - /// Tell all users of the control block to cancel and return early. - fn cancel(&self) { - self.cancel.store(true, Ordering::Relaxed); + fn finish(&self) { + self.0.finish(); } } -impl Drop for DocumentControlBlock { - fn drop(&mut self) { - self.progress.finish(); +impl AsProgress for IndicatifAsProgress { + fn as_progress(&self, _max: usize) -> Arc { + Arc::new(IndicatifProgress(self.0.clone())) } } From 670782f4bec3079a5ee144fca16a160b56dea011 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:47:17 +0530 Subject: [PATCH 07/39] Rename struct for recall metrics --- diskann-benchmark/src/backend/document_index/benchmark.rs | 4 ++-- diskann-benchmark/src/utils/recall.rs | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 51b3467bf..3a88ba9b7 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -58,7 +58,7 @@ use crate::{ utils::{ self, datafiles::{self, BinFile}, - recall, + recall::SerializableRecallMetrics, }, }; @@ -810,7 +810,7 @@ pub struct SearchRunStats { pub num_queries: usize, pub search_n: usize, pub search_l: usize, - pub recall: recall::RecallMetrics, + pub recall: SerializableRecallMetrics, pub qps: Vec, pub wall_clock_time: Vec, pub mean_latency: f64, diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index 50ef7e430..a7e0e39ab 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -2,15 +2,13 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ - -pub(crate) use benchmark_core::recall::knn; use diskann_benchmark_core as benchmark_core; use serde::Serialize; #[derive(Debug, Clone, Serialize)] #[non_exhaustive] -pub(crate) struct RecallMetrics { +pub(crate) struct SerializableRecallMetrics(benchmark_core::recall::RecallMetrics) { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, /// The `n` value for `k-recall-at-n`. @@ -25,7 +23,7 @@ pub(crate) struct RecallMetrics { pub(crate) maximum: usize, } -impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { +impl From<&benchmark_core::recall::RecallMetrics> for SerializableRecallMetrics { fn from(m: &benchmark_core::recall::RecallMetrics) -> Self { Self { recall_k: m.recall_k, From 6c2c967d6cbd1d37d353054864ccfc2d8133d63f Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:47:51 +0530 Subject: [PATCH 08/39] Use copyIds --- .../document_insert_strategy.rs | 39 ++----------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 850976a32..9a3bad9a0 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -9,13 +9,7 @@ use std::marker::PhantomData; use diskann::{ - graph::{ - glue::{ - ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, - }, - SearchOutputBuffer, - }, - neighbor::Neighbor, + graph::glue::{self, ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, ANNResult, }; @@ -160,33 +154,6 @@ where } } -#[derive(Debug, Default, Clone, Copy)] -pub struct CopyIdsForDocument; - -impl<'doc, A, VT> SearchPostProcess> for CopyIdsForDocument -where - A: BuildQueryComputer>, - VT: ?Sized, -{ - type Error = std::convert::Infallible; - - fn post_process( - &self, - _accessor: &mut A, - _query: &Document<'doc, VT>, - _computer: &>>::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl std::future::Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - let count = output.extend(candidates.map(|n| (n.id, n.distance))); - std::future::ready(Ok(count)) - } -} - impl<'doc, Inner, DP, VT> SearchStrategy>, Document<'doc, VT>> for DocumentInsertStrategy @@ -196,7 +163,7 @@ where VT: Sync + Send + ?Sized + 'static, { type QueryComputer = Inner::QueryComputer; - type PostProcessor = CopyIdsForDocument; + type PostProcessor = glue::CopyIds; type SearchAccessorError = Inner::SearchAccessorError; type SearchAccessor<'a> = DocumentSearchAccessor, VT>; @@ -212,7 +179,7 @@ where } fn post_processor(&self) -> Self::PostProcessor { - CopyIdsForDocument + glue::CopyIds } } From d13dc7f303c08535adfc67a18bf6417a1051d17c Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:48:10 +0530 Subject: [PATCH 09/39] Use renamed struct in SearchResults --- diskann-benchmark/src/backend/index/result.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 21d74f915..7f9514613 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -117,7 +117,7 @@ pub(super) struct SearchResults { pub(super) mean_latencies: Vec, pub(super) p90_latencies: Vec, pub(super) p99_latencies: Vec, - pub(super) recall: utils::recall::RecallMetrics, + pub(super) recall: utils::recall::SerializableRecallMetrics, pub(super) mean_cmps: f32, pub(super) mean_hops: f32, } From bd19bdebece118c70165f53d4fd32af80b76b580 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:49:04 +0530 Subject: [PATCH 10/39] Evaluate query progressively without flattening and cloning the json kv map --- diskann-tools/src/utils/ground_truth.rs | 83 +++++++++++++------------ 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index 8c2fa29f6..678e83dfd 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -29,56 +29,57 @@ use serde_json::{Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; -/// Expands a JSON object with array-valued fields into multiple objects with scalar values. -/// For example: {"country": ["AU", "NZ"], "year": 2007} -/// becomes: [{"country": "AU", "year": 2007}, {"country": "NZ", "year": 2007}] +/// Evaluates a query expression against a label, expanding array-valued fields by recursion. /// -/// If multiple fields have arrays, all combinations are generated. -fn expand_array_fields(value: &Value) -> Vec { - match value { +/// For each key in the JSON object, if the value is an array the expression is evaluated +/// against one element at a time (any-match semantics) without materialising the full +/// Cartesian product. Non-object labels are evaluated directly. +fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { + match label { Value::Object(map) => { - // Start with a single empty object - let mut results: Vec> = vec![Map::new()]; - - for (key, val) in map.iter() { - if let Value::Array(arr) = val { - // Expand: for each existing result, create copies for each array element - let mut new_results: Vec> = Vec::new(); - for existing in results.iter() { - for item in arr.iter() { - let mut new_map: Map = existing.clone(); - new_map.insert(key.clone(), item.clone()); - new_results.push(new_map); - } - } - // If array is empty, keep existing results without this key - if !arr.is_empty() { - results = new_results; - } - } else { - // Non-array field: add to all existing results - for existing in results.iter_mut() { - existing.insert(key.clone(), val.clone()); + let entries: Vec<(&String, &Value)> = map.iter().collect(); + eval_map_recursive(query_expr, &entries, Map::new()) + } + _ => eval_query_expr(query_expr, label), + } +} + +/// Walk `entries` one field at a time, accumulating scalar values into `current`. +/// +/// * Scalar fields are inserted directly and the walk continues with the remaining entries. +/// * Array fields branch once per element; evaluation short-circuits on the first branch +/// that returns `true`. +/// * An empty array is treated as an absent field (preserving the previous behaviour). +/// * When all fields have been consumed, `eval_query_expr` is called on the accumulated object. +fn eval_map_recursive( + query_expr: &ASTExpr, + entries: &[(&String, &Value)], + mut current: Map, +) -> bool { + match entries { + [] => eval_query_expr(query_expr, &Value::Object(current)), + [(key, Value::Array(arr)), rest @ ..] => { + if arr.is_empty() { + // Omit this key, matching the original behaviour for empty arrays. + eval_map_recursive(query_expr, rest, current) + } else { + for item in arr { + let mut branch = current.clone(); + branch.insert((*key).clone(), item.clone()); + if eval_map_recursive(query_expr, rest, branch) { + return true; } } + false } - - results.into_iter().map(Value::Object).collect() } - // If not an object, return as-is - _ => vec![value.clone()], + [(key, val), rest @ ..] => { + current.insert((*key).clone(), (*val).clone()); + eval_map_recursive(query_expr, rest, current) + } } } -/// Evaluates a query expression against a label, expanding array fields first. -/// Returns true if any expanded variant matches the query. -fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { - let expanded = expand_array_fields(label); - expanded - .iter() - .any(|item| eval_query_expr(query_expr, item)) -} - pub fn read_labels_and_compute_bitmap( base_label_filename: &str, query_label_filename: &str, From 8e3a89bdcfb4754107b4e907bfb9a45744ce272b Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:49:39 +0530 Subject: [PATCH 11/39] Use config api to validate values --- .../src/inputs/document_index.rs | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index b1a36e48a..11e36d5e3 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -51,15 +51,17 @@ impl CheckDeserialization for DocumentBuildParams { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.data.check_deserialization(checker)?; self.data_labels.check_deserialization(checker)?; - if self.max_degree == 0 { - return Err(anyhow::anyhow!("max_degree must be > 0")); - } - if self.l_build == 0 { - return Err(anyhow::anyhow!("l_build must be > 0")); - } - if self.alpha <= 0.0 { - return Err(anyhow::anyhow!("alpha must be > 0")); - } + + // checking if the max_degree, l_build and alpha values are valid. + use diskann::graph::config::{Builder, MaxDegree, PruneKind}; + let mut builder = Builder::new( + self.max_degree, + MaxDegree::Value(self.max_degree), + self.l_build, + PruneKind::Occluding, + ); + builder.alpha(self.alpha); + builder.build()?; Ok(()) } } From 40b131485f908d4f9d2be60e88ad04e6104ffc41 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:49:58 +0530 Subject: [PATCH 12/39] Specify number of threads explicitly --- diskann-benchmark/example/document-filter.json | 3 ++- diskann-benchmark/src/inputs/document_index.rs | 5 ----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json index d6e9e13b2..d60cd4806 100644 --- a/diskann-benchmark/example/document-filter.json +++ b/diskann-benchmark/example/document-filter.json @@ -13,7 +13,8 @@ "distance": "squared_l2", "max_degree": 32, "l_build": 50, - "alpha": 1.2 + "alpha": 1.2, + "num_threads": 4 }, "search": { "queries": "disk_index_sample_query_10pts.fbin", diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index 11e36d5e3..f1d2d7c67 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -39,14 +39,9 @@ pub(crate) struct DocumentBuildParams { pub(crate) max_degree: usize, pub(crate) l_build: usize, pub(crate) alpha: f32, - #[serde(default = "default_num_threads")] pub(crate) num_threads: usize, } -fn default_num_threads() -> usize { - 1 -} - impl CheckDeserialization for DocumentBuildParams { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.data.check_deserialization(checker)?; From 3d2397254c882a1299888c26f54ba7a782658daa Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:51:23 +0530 Subject: [PATCH 13/39] Error when the visit of input expression fails when creating encodedFilterExpression --- .../encoded_filter_expr.rs | 18 ++++---- .../encoded_document_accessor.rs | 10 +--- .../inline_beta_search/inline_beta_filter.rs | 46 +++++-------------- 3 files changed, 22 insertions(+), 52 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index 370ef25ae..b621e347c 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -5,6 +5,8 @@ use std::sync::{Arc, RwLock}; +use diskann::ANNResult; + use crate::{ encoded_attribute_provider::{ ast_id_expr::ASTIdExpr, ast_label_id_mapper::ASTLabelIdMapper, @@ -14,21 +16,19 @@ use crate::{ }; pub(crate) struct EncodedFilterExpr { - ast_id_expr: Option>, + ast_id_expr: ASTIdExpr, } impl EncodedFilterExpr { - pub fn new(ast_expr: &ASTExpr, attribute_map: Arc>) -> Self { + pub fn try_create(ast_expr: &ASTExpr, attribute_map: Arc>) -> ANNResult { let mut mapper = ASTLabelIdMapper::new(attribute_map); - match ast_expr.accept(&mut mapper) { - Ok(ast_id_expr) => Self { - ast_id_expr: Some(ast_id_expr), - }, - Err(_e) => Self { ast_id_expr: None }, - } + let ast_id_expr = ast_expr.accept(&mut mapper)?; + Ok(Self { + ast_id_expr, + }) } - pub(crate) fn encoded_filter_expr(&self) -> &Option> { + pub(crate) fn encoded_filter_expr(&self) -> &ASTIdExpr { &self.ast_id_expr } } diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 1def9a406..50e835dc7 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -220,20 +220,12 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone()); - let is_valid_filter = id_query.encoded_filter_expr().is_some(); - if !is_valid_filter { - tracing::warn!( - "Failed to convert {} into an id expr. This will now be an unfiltered search.", - from.filter_expr() - ); - } + let id_query = EncodedFilterExpr::try_create(from.filter_expr(), self.attribute_map.clone())?; Ok(InlineBetaComputer::new( inner_computer, self.beta_value, id_query, - is_valid_filter, )) } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 76a78de67..6093451c2 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -115,35 +115,20 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - if self.is_valid_filter { - match self - .filter_expr - .encoded_filter_expr() - .as_ref() - .unwrap() - .accept(&pred_eval) - { - Ok(matched) => { - if matched { - sim * self.beta_value - } else { - sim - } - } - Err(_) => { - //If predicate evaluation fails for any reason, we simply revert - //to unfiltered search. - tracing::warn!("Predicate evaluation failed"); + match self.filter_expr.encoded_filter_expr().accept(&pred_eval) { + Ok(matched) => { + if matched { + sim * self.beta_value + } else { sim } } - } else { - //If predicate evaluation fails, we will return the score returned by the - //inner computer, as though no predicate was specified. - tracing::warn!( - "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()" - ); - sim + Err(_) => { + //If predicate evaluation fails for any reason, we simply revert + //to unfiltered search. + tracing::warn!("Predicate evaluation failed"); + sim + } } } } @@ -182,14 +167,7 @@ where let doc = accessor.get_element(candidate.id).await?; let pe = PredicateEvaluator::new(doc.attributes()); - if computer.is_valid_filter() - && computer - .filter_expr() - .encoded_filter_expr() - .as_ref() - .unwrap() - .accept(&pe)? - { + if computer.filter_expr().encoded_filter_expr().accept(&pe)? { filtered_candidates.push(Neighbor::new(candidate.id, candidate.distance)); } } From 9e35ccba00fc17eb5d29c6e84973ef5bd705cc1c Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:52:07 +0530 Subject: [PATCH 14/39] remove new runtime method added, use method in benchmark::core --- diskann-benchmark/src/utils/tokio.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index 21c78abb2..72dbeb918 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -3,13 +3,6 @@ * Licensed under the MIT license. */ -/// Create a generic multi-threaded runtime with `num_threads`. -pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { - Ok(tokio::runtime::Builder::new_multi_thread() - .worker_threads(num_threads) - .build()?) -} - /// Create a current-thread runtime and block on the given future. /// Only for functions that don't need multi-threading pub(crate) fn block_on(future: F) -> F::Output { From 5a8c5601720ffd78421f30d331e481e85a4f8100 Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:55:10 +0530 Subject: [PATCH 15/39] Use dispatch rule to validate benchmark type support --- .../src/backend/document_index/benchmark.rs | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 3a88ba9b7..96d9f42e3 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -64,36 +64,48 @@ use crate::{ /// Register the document index benchmarks. pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>( - "document-index-build", - |job, checkpoint, out| { - let stats = job.run(checkpoint, out)?; + benchmarks.register::>( + "document-index-build-f32", + |job, _checkpoint, out| { + let stats = job.run(out)?; Ok(serde_json::to_value(stats)?) }, ); } /// Document index benchmark job. -pub(super) struct DocumentIndexJob<'a> { +pub(super) struct DocumentIndexJob<'a, T> { input: &'a DocumentIndexBuild, + _type: std::marker::PhantomData, } -impl<'a> DocumentIndexJob<'a> { +impl<'a, T> DocumentIndexJob<'a, T> { fn new(input: &'a DocumentIndexBuild) -> Self { - Self { input } + Self { + input, + _type: std::marker::PhantomData, + } } } -impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static> { - type Type<'a> = DocumentIndexJob<'a>; +impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static, T> { + type Type<'a> = DocumentIndexJob<'a, T>; } // Dispatch from the concrete input type -impl<'a> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a> { +impl<'a, T> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a, T> +where + datatype::Type: DispatchRule, +{ type Error = std::convert::Infallible; fn try_match(_from: &&'a DocumentIndexBuild) -> Result { - Ok(MatchScore(1)) + match _from.build.data_type { + datatype::DataType::Float32 => Ok(MatchScore(0)), + datatype::DataType::UInt8 => Ok(MatchScore(0)), + datatype::DataType::Int8 => Ok(MatchScore(0)), + _ => Err(datatype::MATCH_FAIL), + } } fn convert(from: &'a DocumentIndexBuild) -> Result { @@ -109,7 +121,10 @@ impl<'a> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a> { } // Central dispatch mapping from Any -impl<'a> DispatchRule<&'a Any> for DocumentIndexJob<'a> { +impl<'a, T> DispatchRule<&'a Any> for DocumentIndexJob<'a, T> +where + datatype::Type: DispatchRule, +{ type Error = anyhow::Error; fn try_match(from: &&'a Any) -> Result { From de2036535cab301e9a5cd1c3c6ac202c0d49457d Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:56:34 +0530 Subject: [PATCH 16/39] Use compute_medioid helper --- .../src/backend/document_index/benchmark.rs | 127 ++++++------------ 1 file changed, 39 insertions(+), 88 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 96d9f42e3..692cbc0a0 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -50,6 +50,8 @@ use diskann_providers::model::graph::provider::async_::{ inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, }; use diskann_utils::views::Matrix; +use diskann_vector::PureDistanceFunction; +use diskann_vector::distance::SquaredL2; use indicatif::{ProgressBar, ProgressStyle}; use serde::Serialize; @@ -146,105 +148,54 @@ fn hashmap_to_attributes(map: std::collections::HashMap) .collect() } -/// Compute the index of the row closest to the medoid (centroid) of the data. -fn compute_medoid_index(data: &Matrix) -> usize +fn find_medoid_index(x: MatrixView<'_, T>, y: &[T]) -> Option where - T: bytemuck::Pod + Copy + 'static, + for<'a> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'a [T], &'a [T], f32>, { - use diskann_vector::{distance::SquaredL2, PureDistanceFunction}; - - let dim = data.ncols(); - if dim == 0 || data.nrows() == 0 { - return 0; - } - - // Compute the centroid (mean of all rows) as f64 for precision - let mut sum = vec![0.0f64; dim]; - for i in 0..data.nrows() { - let row = data.row(i); - for (j, &v) in row.iter().enumerate() { - // Convert T to f64 for summation using bytemuck - let f64_val: f64 = if std::any::TypeId::of::() == std::any::TypeId::of::() { - let f32_val: f32 = bytemuck::cast(v); - f32_val as f64 - } else if std::any::TypeId::of::() == std::any::TypeId::of::() { - let u8_val: u8 = bytemuck::cast(v); - u8_val as f64 - } else if std::any::TypeId::of::() == std::any::TypeId::of::() { - let i8_val: i8 = bytemuck::cast(v); - i8_val as f64 - } else { - 0.0 - }; - sum[j] += f64_val; + let mut min_dist = f32::INFINITY; + let mut min_ind = x.nrows(); + for (i, row) in x.row_iter().enumerate() { + let dist = SquaredL2::evaluate(row, y); + if dist < min_dist { + min_dist = dist; + min_ind = i; } } - // Convert centroid to f32 and compute distances - let centroid_f32: Vec = sum - .iter() - .map(|s| (s / data.nrows() as f64) as f32) - .collect(); - - // Find the row closest to the centroid - let mut min_dist = f32::MAX; - let mut medoid_idx = 0; - for i in 0..data.nrows() { - let row = data.row(i); - let row_f32: Vec = row - .iter() - .map(|&v| { - if std::any::TypeId::of::() == std::any::TypeId::of::() { - bytemuck::cast(v) - } else if std::any::TypeId::of::() == std::any::TypeId::of::() { - let u8_val: u8 = bytemuck::cast(v); - u8_val as f32 - } else if std::any::TypeId::of::() == std::any::TypeId::of::() { - let i8_val: i8 = bytemuck::cast(v); - i8_val as f32 - } else { - 0.0 - } - }) - .collect(); - let d = SquaredL2::evaluate(centroid_f32.as_slice(), row_f32.as_slice()); - if d < min_dist { - min_dist = d; - medoid_idx = i; - } + // No closest neighbor found. + if min_ind == x.nrows() { + None + } else { + Some(min_ind) } - - medoid_idx } -impl<'a> DocumentIndexJob<'a> { - fn run( - self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - // Print the input description - writeln!(output, "{}", self.input)?; +/// Compute the index of the row closest to the medoid (centroid) of the data. +fn compute_medoid_index(data: &Matrix) -> anyhow::Result +where + T: bytemuck::Pod + Copy + 'static + ComputeMedoid, + for<'a> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'a [T], &'a [T], f32>, +{ + let dim = data.ncols(); + if dim == 0 || data.nrows() == 0 { + return Ok(0); + } - let build = &self.input.build; + // returns row closes to centroid. + let medoid = T::compute_medoid(data.as_view()); - // Dispatch based on data type - retain original type without conversion - match build.data_type { - DataType::Float32 => self.run_typed::(output), - DataType::UInt8 => self.run_typed::(output), - DataType::Int8 => self.run_typed::(output), - _ => Err(anyhow::anyhow!( - "Unsupported data type: {:?}. Supported types: float32, uint8, int8.", - build.data_type - )), - } - } + find_medoid_index(data.as_view(), medoid.as_slice()) + .ok_or_else(|| anyhow::anyhow!("Failed to find medoid index: no closest row found")) +} - fn run_typed(self, mut output: &mut dyn Output) -> Result +impl<'a, T> DocumentIndexJob<'a, T> { + fn run(self, mut output: &mut dyn Output) -> Result where - T: bytemuck::Pod + Copy + Send + Sync + 'static + std::fmt::Debug, - T: diskann::graph::SampleableForStart + diskann_utils::future::AsyncFriendly, - T: diskann::utils::VectorRepr + diskann_utils::sampling::WithApproximateNorm, + T: diskann::utils::VectorRepr + + diskann::graph::SampleableForStart + + diskann_utils::sampling::WithApproximateNorm + + 'static, + for<'b> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'b [T], &'b [T]> { let build = &self.input.build; @@ -326,7 +277,7 @@ impl<'a> DocumentIndexJob<'a> { // Store attributes for the start point (medoid) // Start points are stored at indices num_vectors..num_vectors+frozen_points - let medoid_idx = compute_medoid_index(&data); + let medoid_idx = compute_medoid_index(&data)?; let start_point_id = num_vectors as u32; // Start points begin at max_points let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); use diskann_label_filter::traits::attribute_store::AttributeStore; From 9c477cdefb890e4ab2213e4d162d400d846c214f Mon Sep 17 00:00:00 2001 From: Sampath Rajendra Date: Thu, 12 Mar 2026 13:56:57 +0530 Subject: [PATCH 17/39] Remaining changes from search + build api refactor --- .../src/backend/document_index/benchmark.rs | 99 +++++++++---------- 1 file changed, 44 insertions(+), 55 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 692cbc0a0..1f64ffccb 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -10,14 +10,13 @@ use std::io::Write; use std::num::NonZeroUsize; use std::path::Path; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, - search_output_buffer, DiskANNIndex, SearchParams, StartPointStrategy, + search_output_buffer, DiskANNIndex, SearchOutputBuffer, SearchParams, StartPointStrategy, }, provider::DefaultContext, }; @@ -31,8 +30,8 @@ use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, registry::Benchmarks, - utils::{datatype::DataType, percentiles, MicroSeconds}, - Any, Checkpoint, + utils::{datatype, percentiles, MicroSeconds}, + Any, }; use diskann_label_filter::{ attribute::{Attribute, AttributeValue}, @@ -49,6 +48,8 @@ use diskann_providers::model::graph::provider::async_::{ common::{self, NoStore, TableBasedDeletes}, inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, }; +use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; +use diskann_utils::views::MatrixView; use diskann_utils::views::Matrix; use diskann_vector::PureDistanceFunction; use diskann_vector::distance::SquaredL2; @@ -58,7 +59,6 @@ use serde::Serialize; use crate::{ inputs::document_index::DocumentIndexBuild, utils::{ - self, datafiles::{self, BinFile}, recall::SerializableRecallMetrics, }, @@ -296,58 +296,33 @@ impl<'a, T> DocumentIndexJob<'a, T> { )?; let timer = std::time::Instant::now(); - let insert_strategy: DocumentInsertStrategy<_, [T]> = - DocumentInsertStrategy::new(common::FullPrecision); - let rt = utils::tokio::runtime(build.num_threads)?; - - // Create control block for parallel work distribution + let rt = tokio::runtime(build.num_threads)?; let data_arc = Arc::new(data); let attributes_arc = Arc::new(attributes); - let control_block = DocumentControlBlock::new( + + let builder = DocumentIndexBuilder::new( + doc_index.clone(), data_arc.clone(), attributes_arc.clone(), - output.draw_target(), - )?; - - let num_tasks = build.num_threads; - let insert_latencies = rt.block_on(async { - let tasks: Vec<_> = (0..num_tasks) - .map(|_| { - let block = control_block.clone(); - let index = doc_index.clone(); - let strategy = insert_strategy; - tokio::spawn(async move { - let mut latencies = Vec::::new(); - let ctx = DefaultContext; - loop { - match block.next() { - Some((id, vector, attrs)) => { - let doc = Document::new(vector, attrs); - let start = std::time::Instant::now(); - let result = - index.insert(strategy, &ctx, &(id as u32), &doc).await; - latencies.push(MicroSeconds::from(start.elapsed())); - - if let Err(e) = result { - block.cancel(); - return Err(e); - } - } - None => return Ok(latencies), - } - } - }) - }) - .collect(); - - // Collect results from all tasks - let mut all_latencies = Vec::with_capacity(num_vectors); - for task in tasks { - let task_latencies = task.await??; - all_latencies.extend(task_latencies); - } - Ok::<_, anyhow::Error>(all_latencies) - })?; + DocumentInsertStrategy::new(common::FullPrecision), + ); + let num_tasks = NonZeroUsize::new(build.num_threads).unwrap_or(diskann::utils::ONE); + let parallelism = Parallelism::dynamic(diskann::utils::ONE, num_tasks); + let progress = IndicatifAsProgress({ + let bar = ProgressBar::with_draw_target(Some(num_vectors as u64), output.draw_target()); + bar.set_style( + ProgressStyle::with_template("Building [{elapsed_precise}] {wide_bar} {percent}") + .expect("valid template"), + ); + bar + }); + let build_results = + build::build_tracked(builder, parallelism, &rt, Some(&progress))?; + let insert_latencies: Vec = build_results + .take_output() + .into_iter() + .map(|r| r.latency) + .collect(); let build_time: MicroSeconds = timer.elapsed().into(); writeln!(output, " Index built in {} s", build_time.as_seconds())?; @@ -832,7 +807,17 @@ impl std::fmt::Display for DocumentIndexStats { writeln!( f, " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", - "L", "KNN", "Avg Cmps", "Avg Hops", "QPS -mean(max)", "Avg Latency", "p99 Latency", "Recall", "Threads", "Queries", "WallClock(s)" + "L", + "KNN", + "Avg Cmps", + "Avg Hops", + "QPS -mean(max)", + "Avg Latency", + "p99 Latency", + "Recall", + "Threads", + "Queries", + "WallClock(s)" )?; for s in &self.search { let mean_qps = if s.qps.is_empty() { @@ -844,7 +829,11 @@ impl std::fmt::Display for DocumentIndexStats { let mean_wall_clock = if s.wall_clock_time.is_empty() { 0.0 } else { - s.wall_clock_time.iter().map(|t| t.as_seconds()).sum::() / s.wall_clock_time.len() as f64 + s.wall_clock_time + .iter() + .map(|t| t.as_seconds()) + .sum::() + / s.wall_clock_time.len() as f64 }; writeln!( f, From 7f244328209c65d13f67c8d8d9a2a78f0733c9b4 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 12 Mar 2026 15:05:33 +0530 Subject: [PATCH 18/39] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- diskann-tools/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index 1b4b3408e..0803dae32 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -5,6 +5,7 @@ version.workspace = true authors.workspace = true description.workspace = true documentation.workspace = true +license.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html From 9c1870aff4667865d6f081511b54fe3180f22bfc Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 12 Mar 2026 15:17:46 +0530 Subject: [PATCH 19/39] fix merge errors --- .../src/backend/document_index/benchmark.rs | 49 +++++++++++++++++-- .../inline_beta_search/inline_beta_filter.rs | 7 --- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index ccde2efbf..b12445b16 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -486,7 +486,7 @@ where O: diskann::graph::SearchOutputBuffer + Send, { let ctx = DefaultContext; - let query_vec = self.queries.row(index); + let query_vec = self.queries.row(index).to_vec(); let (_, ref ast_expr) = self.predicates[index]; let strategy = InlineBetaStrategy::new(self.beta, common::FullPrecision); let filtered_query = FilteredQuery::new(query_vec, ast_expr.clone()); @@ -707,19 +707,60 @@ pub struct BuildParamsStats { pub alpha: f32, } +/// Helper module for serializing arrays as compact single-line JSON strings +mod compact_array { + use serde::Serializer; + + pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } + + pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a string containing the compact JSON array + let compact = serde_json::to_string(vec).unwrap_or_default(); + serializer.serialize_str(&compact) + } +} + +/// Per-query detailed results for debugging/analysis +#[derive(Debug, Serialize)] +pub struct PerQueryDetails { + pub query_id: usize, + pub filter: String, + pub recall: f64, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub result_ids: Vec, + #[serde(serialize_with = "compact_array::serialize_f32_vec")] + pub result_distances: Vec, + #[serde(serialize_with = "compact_array::serialize_u32_vec")] + pub groundtruth_ids: Vec, +} + /// Results from a single search configuration (one search_l value). #[derive(Debug, Serialize)] pub struct SearchRunStats { pub num_threads: usize, + pub num_queries: usize, pub search_n: usize, pub search_l: usize, pub recall: SerializableRecallMetrics, pub qps: Vec, + pub wall_clock_time: Vec, pub mean_latency: f64, pub p90_latency: MicroSeconds, pub p99_latency: MicroSeconds, pub mean_cmps: f32, pub mean_hops: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub per_query_details: Option>, } #[derive(Debug, Serialize)] @@ -796,7 +837,7 @@ impl std::fmt::Display for DocumentIndexStats { }; writeln!( f, - " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8}", + " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", s.search_l, s.search_n, s.mean_cmps, @@ -806,7 +847,9 @@ impl std::fmt::Display for DocumentIndexStats { s.mean_latency, s.p99_latency, s.recall.average, - s.num_threads + s.num_threads, + s.num_queries, + mean_wall_clock )?; } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 6093451c2..58087b24e 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -79,7 +79,6 @@ pub struct InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, - is_valid_filter: bool, //optimization to avoid evaluating empty predicates. } impl InlineBetaComputer { @@ -87,23 +86,17 @@ impl InlineBetaComputer { inner_computer: Inner, beta_value: f32, filter_expr: EncodedFilterExpr, - is_valid_filter: bool, ) -> Self { Self { inner_computer, beta_value, filter_expr, - is_valid_filter, } } pub(crate) fn filter_expr(&self) -> &EncodedFilterExpr { &self.filter_expr } - - pub(crate) fn is_valid_filter(&self) -> bool { - self.is_valid_filter - } } impl PreprocessedDistanceFunction, f32> From 1591956df376ac2ef533be349ca6cd42f7a533ca Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 12 Mar 2026 15:37:05 +0530 Subject: [PATCH 20/39] Fix merge errors white recall metrics --- diskann-benchmark/src/utils/recall.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index a7e0e39ab..9628a6205 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -8,7 +8,7 @@ use serde::Serialize; #[derive(Debug, Clone, Serialize)] #[non_exhaustive] -pub(crate) struct SerializableRecallMetrics(benchmark_core::recall::RecallMetrics) { +pub(crate) struct SerializableRecallMetrics { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, /// The `n` value for `k-recall-at-n`. From f13e8498727c24b92daf820bb7a374dc56cc3c01 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 15:40:04 +0530 Subject: [PATCH 21/39] Remove the need for Vec variants --- .../src/backend/document_index/benchmark.rs | 37 ++++++------ .../encoded_document_accessor.rs | 10 ++-- .../inline_beta_search/inline_beta_filter.rs | 11 ++-- diskann-label-filter/src/query.rs | 10 ++-- .../provider/async_/inmem/full_precision.rs | 56 ------------------- 5 files changed, 35 insertions(+), 89 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index b12445b16..38da6bfc7 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -22,9 +22,7 @@ use diskann::{ }; use diskann_benchmark_core::{ build::{self, AsProgress, Build, Parallelism, Progress}, - recall, - search as search_api, - tokio, + recall, search as search_api, tokio, }; use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, @@ -48,11 +46,11 @@ use diskann_providers::model::graph::provider::async_::{ common::{self, NoStore, TableBasedDeletes}, inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, }; -use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; -use diskann_utils::views::MatrixView; use diskann_utils::views::Matrix; -use diskann_vector::PureDistanceFunction; +use diskann_utils::views::MatrixView; +use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; use diskann_vector::distance::SquaredL2; +use diskann_vector::PureDistanceFunction; use indicatif::{ProgressBar, ProgressStyle}; use serde::Serialize; @@ -148,7 +146,7 @@ fn hashmap_to_attributes(map: std::collections::HashMap) .collect() } -fn find_medoid_index(x: MatrixView<'_, T>, y: &[T]) -> Option +fn find_medoid_index(x: MatrixView<'_, T>, y: &[T]) -> Option where for<'a> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'a [T], &'a [T], f32>, { @@ -195,7 +193,7 @@ impl<'a, T> DocumentIndexJob<'a, T> { + diskann::graph::SampleableForStart + diskann_utils::sampling::WithApproximateNorm + 'static, - for<'b> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'b [T], &'b [T]> + for<'b> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'b [T], &'b [T]>, { let build = &self.input.build; @@ -316,8 +314,7 @@ impl<'a, T> DocumentIndexJob<'a, T> { ); bar }); - let build_results = - build::build_tracked(builder, parallelism, &rt, Some(&progress))?; + let build_results = build::build_tracked(builder, parallelism, &rt, Some(&progress))?; let insert_latencies: Vec = build_results .take_output() .into_iter() @@ -455,11 +452,15 @@ where impl search_api::Search for FilteredSearcher where - DP: diskann::provider::DataProvider - + Send + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + Sync + 'static, - InlineBetaStrategy: diskann::graph::glue::SearchStrategy>, u32>, + for<'a> InlineBetaStrategy: + diskann::graph::glue::SearchStrategy, u32>, T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; @@ -486,7 +487,7 @@ where O: diskann::graph::SearchOutputBuffer + Send, { let ctx = DefaultContext; - let query_vec = self.queries.row(index).to_vec(); + let query_vec = self.queries.row(index); let (_, ref ast_expr) = self.predicates[index]; let strategy = InlineBetaStrategy::new(self.beta, common::FullPrecision); let filtered_query = FilteredQuery::new(query_vec, ast_expr.clone()); @@ -671,8 +672,8 @@ where > + Send + Sync + 'static, - InlineBetaStrategy: - diskann::graph::glue::SearchStrategy>>, + for<'a> InlineBetaStrategy: + diskann::graph::glue::SearchStrategy>, { let searcher = Arc::new(FilteredSearcher { index: index.clone(), @@ -698,7 +699,9 @@ where }, )?; - results.pop().ok_or_else(|| anyhow::anyhow!("no search results")) + results + .pop() + .ok_or_else(|| anyhow::anyhow!("no search results")) } #[derive(Debug, Serialize)] pub struct BuildParamsStats { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 50e835dc7..0b658fd4d 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -11,7 +11,7 @@ use diskann::{ provider::{Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, HasId}, ANNError, ANNErrorKind, }; -use diskann_utils::{future::AsyncFriendly, Reborrow}; +use diskann_utils::Reborrow; use roaring::RoaringTreemap; use crate::traits::attribute_accessor::AttributeAccessor; @@ -204,17 +204,17 @@ where } } -impl BuildQueryComputer> for EncodedDocumentAccessor +impl<'a, IA, Q> BuildQueryComputer> for EncodedDocumentAccessor where IA: BuildQueryComputer, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type QueryComputerError = ANNError; type QueryComputer = InlineBetaComputer; fn build_query_computer( &self, - from: &FilteredQuery, + from: &FilteredQuery<'a, Q>, ) -> Result { let inner_computer = self .inner_accessor @@ -234,7 +234,7 @@ impl ExpandBeam for EncodedDocumentAccessor where IA: Accessor, EncodedDocumentAccessor: BuildQueryComputer + AsNeighbor, - Q: Clone + AsyncFriendly, + Q: Send + Sync + ?Sized, { } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 58087b24e..8d1784029 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -9,7 +9,6 @@ use diskann::neighbor::Neighbor; use diskann::provider::{Accessor, BuildQueryComputer, DataProvider}; use diskann::ANNError; -use diskann_utils::future::AsyncFriendly; use diskann_vector::PreprocessedDistanceFunction; use roaring::RoaringTreemap; @@ -36,12 +35,12 @@ impl InlineBetaStrategy { } impl - SearchStrategy>, FilteredQuery> + SearchStrategy>, FilteredQuery<'_, Q>> for InlineBetaStrategy where DP: DataProvider, Strategy: SearchStrategy, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type QueryComputer = InlineBetaComputer; type PostProcessor = FilterResults; @@ -130,19 +129,19 @@ pub struct FilterResults { inner_post_processor: IPP, } -impl SearchPostProcess, FilteredQuery> +impl<'a, Q, IA, IPP> SearchPostProcess, FilteredQuery<'a, Q>> for FilterResults where IA: BuildQueryComputer, - Q: Clone + AsyncFriendly, IPP: SearchPostProcess + Send + Sync, + Q: Send + Sync + ?Sized, { type Error = ANNError; async fn post_process( &self, accessor: &mut EncodedDocumentAccessor, - query: &FilteredQuery, + query: &FilteredQuery<'a, Q>, computer: &InlineBetaComputer<>::QueryComputer>, candidates: I, output: &mut B, diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index 15c42501f..d85406b5d 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -9,17 +9,17 @@ use crate::ASTExpr; /// The Readme.md file in the label-filter folder describes the format /// of the query expression. #[derive(Clone)] -pub struct FilteredQuery { - query: V, +pub struct FilteredQuery<'a, V : ?Sized> { + query: &'a V, filter_expr: ASTExpr, } -impl FilteredQuery { - pub fn new(query: V, filter_expr: ASTExpr) -> Self { +impl<'a, V: ?Sized> FilteredQuery<'a, V> { + pub fn new(query: &'a V, filter_expr: ASTExpr) -> Self { Self { query, filter_expr } } - pub(crate) fn query(&self) -> &V { + pub(crate) fn query(&self) -> &'a V { &self.query } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 9a48488fe..f83b2ae25 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -321,36 +321,6 @@ where { } -/// Support for Vec queries that delegates to the [T] impl via deref. -/// This allows InlineBetaStrategy to use Vec queries with FullAccessor. -impl BuildQueryComputer> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &Vec, - ) -> Result { - // Delegate to [T] impl via deref - Ok(T::query_distance(from.as_slice(), self.provider.metric)) - } -} - -/// Support for Vec queries that delegates to the [T] impl. -impl ExpandBeam> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr + Clone, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} impl FillSet for FullAccessor<'_, T, Q, D, Ctx> where @@ -527,32 +497,6 @@ where } } -/// Support for Vec queries that delegates to the [T] impl. -/// This allows InlineBetaStrategy to use Vec queries with FullPrecision. -impl SearchStrategy, Vec> for FullPrecision -where - T: VectorRepr + Clone, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} // Pruning impl PruneStrategy> for FullPrecision From 7a1244a18d826d650e3ed184c206cf4296b6b04b Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 15:43:50 +0530 Subject: [PATCH 22/39] Undo unecessary change --- test_data/disk_index_search/data.256.label.jsonl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_data/disk_index_search/data.256.label.jsonl b/test_data/disk_index_search/data.256.label.jsonl index a99cde8e2..83254af7b 100644 --- a/test_data/disk_index_search/data.256.label.jsonl +++ b/test_data/disk_index_search/data.256.label.jsonl @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:92576896b10780a2cd80a16030f8384610498b76453f57fadeacb854379e0acf -size 17701 +oid sha256:7f8b6b99ca32173557689712d3fb5da30c5e4111130fd2accbccf32f5ce3e47e +size 17702 From 86208c89da5af9b4677e6aa688cd96d03686c2b2 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 15:58:29 +0530 Subject: [PATCH 23/39] Remove whitespaces from file --- .../provider/async_/inmem/full_precision.rs | 1162 ++++++++--------- 1 file changed, 580 insertions(+), 582 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index f83b2ae25..e74419a46 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -1,582 +1,580 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{collections::HashMap, fmt::Debug, future::Future}; - -use diskann::{ - ANNError, ANNResult, - graph::{ - SearchOutputBuffer, - glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, - }, - }, - neighbor::Neighbor, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, - }, - utils::{IntoUsize, VectorRepr}, -}; -use diskann_utils::future::AsyncFriendly; -use diskann_vector::{DistanceFunction, distance::Metric}; - -use crate::model::graph::{ - provider::async_::{ - FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, - common::{ - CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, - PrefetchCacheLineLevel, SetElementHelper, - }, - inmem::DefaultProvider, - postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, - }, - traits::AdHoc, -}; - -/// A type alias for the DefaultProvider with full-precision as the primary vector store. -pub type FullPrecisionProvider = - DefaultProvider, Q, D, Ctx>; - -/// The default full-precision vector store. -pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; - -/// A default full-precision vector store provider. -#[derive(Clone)] -pub struct CreateFullPrecision { - dim: usize, - prefetch_cache_line_level: Option, - _phantom: std::marker::PhantomData, -} - -impl CreateFullPrecision -where - T: VectorRepr, -{ - /// Create a new full-precision vector store provider. - pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { - Self { - dim, - prefetch_cache_line_level, - _phantom: std::marker::PhantomData, - } - } -} - -impl CreateVectorStore for CreateFullPrecision -where - T: VectorRepr, -{ - type Target = FullPrecisionStore; - fn create( - self, - max_points: usize, - metric: Metric, - prefetch_lookahead: Option, - ) -> Self::Target { - FullPrecisionStore::new( - max_points, - self.dim, - metric, - self.prefetch_cache_line_level, - prefetch_lookahead, - ) - } -} - -//////////////// -// SetElement // -//////////////// - -impl SetElementHelper for FullPrecisionStore -where - T: VectorRepr, -{ - /// Set the element at the given index. - fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { - unsafe { self.set_vector_sync(id.into_usize(), element) } - } -} - -////////////////// -// FullAccessor // -////////////////// - -/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. -pub struct FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, -{ - /// The host provider. - provider: &'a FullPrecisionProvider, - - /// A buffer for resolving iterators given during bulk operations. - /// - /// The accessor reuses this allocation to amortize allocation cost over multiple bulk - /// operations. - id_buffer: Vec, -} - -impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FullPrecisionStore { - &self.provider.base_vectors - } -} - -impl HasId for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, -{ - type Id = u32; -} - -impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - pub fn new(provider: &'a FullPrecisionProvider) -> Self { - Self { - provider, - id_buffer: Vec::new(), - } - } -} - -impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// The extended element inherets the lifetime of the Accessor. - type Extended = &'a [T]; - - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - /// - /// NOTE: We intentionally don't use `'b` here since our implementation borrows - /// the inner `Opaque` from the underlying provider. - type Element<'b> - = &'a [T] - where - Self: 'b; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'b> = &'b [T]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - // Invoke the passed closure on the full-precision vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - *id, - ) - } - - std::future::ready(Ok(())) - } -} - -impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = Panics; - type DistanceComputer = T::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(T::distance( - self.provider.metric, - Some(self.provider.base_vectors.dim()), - )) - } -} - -impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} - -impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} - - -impl FillSet for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - async fn fill_set( - &mut self, - set: &mut HashMap, - itr: Itr, - ) -> Result<(), Self::GetError> - where - Itr: Iterator + Send + Sync, - { - for i in itr { - set.entry(i).or_insert_with(|| unsafe { - self.provider.base_vectors.get_vector_sync(i.into_usize()) - }); - } - Ok(()) - } -} - -//-------------------// -// In-mem Extensions // -//-------------------// - -impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Checker = D; - fn as_deletion_check(&self) -> &D { - &self.provider.deleted - } -} - -////////////////// -// Post Process // -////////////////// - -pub trait GetFullPrecision { - type Repr: VectorRepr; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; -} - -/// A [`SearchPostProcess`]or that: -/// -/// 1. Filters out deleted ids from being returned. -/// 2. Reranks a candidate stream using full-precision distances. -/// 3. Copies back the results to the output buffer. -#[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; - -impl glue::SearchPostProcess for Rerank -where - T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, -{ - type Error = Panics; - - fn post_process( - &self, - accessor: &mut A, - query: &[T], - _computer: &A::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator>, - B: SearchOutputBuffer + ?Sized, - { - let full = accessor.as_full_precision(); - let checker = accessor.as_deletion_check(); - let f = full.distance(); - - // Filter before computing the full precision distances. - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { - Some(( - n.id, - f.evaluate_similarity(query, unsafe { - full.get_vector_sync(n.id.into_usize()) - }), - )) - } - }) - .collect(); - - // Sort the full precision distances. - reranked - .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - // Store the reranked results. - std::future::ready(Ok(output.extend(reranked))) - } -} - -//////////////// -// Strategies // -//////////////// - -// A layered approach is used for search strategies. The `Internal` version does the heavy -// lifting in terms of establishing accessors and post processing. -// -// However, during post-processing, the `Internal` versions of strategies will not filter -// out the start points. The publicly exposed types *will* filter out the start points. -// -// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust -// the adjacency list for the start point to reuse the `Internal` strategies. - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> - for Internal -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = RemoveDeletedIdsAndCopy; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - - -// Pruning -impl PruneStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputer = T::Distance; - type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -/// Implementing this trait allows `FullPrecision` to be used for multi-insert. -impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Error = diskann::error::Infallible; - fn as_element( - &mut self, - vector: &'a [T], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(vector)) - } -} - -impl InsertStrategy, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type PruneStrategy = Self; - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } -} - -// Inplace Delete // -impl InplaceDeleteStrategy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type DeleteElementError = Panics; - type DeleteElement<'a> = [T]; - type DeleteElementGuard = Box<[T]>; - type PruneStrategy = Self; - type SearchStrategy = Internal; - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) - } - - fn prune_strategy(&self) -> Self::PruneStrategy { - Self - } - - async fn get_delete_element<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - id: u32, - ) -> Result { - Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) - } -} +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{collections::HashMap, fmt::Debug, future::Future}; + +use diskann::{ + ANNError, ANNResult, + graph::{ + SearchOutputBuffer, + glue::{ + self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, + }, + }, + neighbor::Neighbor, + provider::{ + Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, + ExecutionContext, HasId, + }, + utils::{IntoUsize, VectorRepr}, +}; +use diskann_utils::future::AsyncFriendly; +use diskann_vector::{DistanceFunction, distance::Metric}; + +use crate::model::graph::{ + provider::async_::{ + FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, + common::{ + CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, + PrefetchCacheLineLevel, SetElementHelper, + }, + inmem::DefaultProvider, + postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, + }, + traits::AdHoc, +}; + +/// A type alias for the DefaultProvider with full-precision as the primary vector store. +pub type FullPrecisionProvider = + DefaultProvider, Q, D, Ctx>; + +/// The default full-precision vector store. +pub type FullPrecisionStore = FastMemoryVectorProviderAsync>; + +/// A default full-precision vector store provider. +#[derive(Clone)] +pub struct CreateFullPrecision { + dim: usize, + prefetch_cache_line_level: Option, + _phantom: std::marker::PhantomData, +} + +impl CreateFullPrecision +where + T: VectorRepr, +{ + /// Create a new full-precision vector store provider. + pub fn new(dim: usize, prefetch_cache_line_level: Option) -> Self { + Self { + dim, + prefetch_cache_line_level, + _phantom: std::marker::PhantomData, + } + } +} + +impl CreateVectorStore for CreateFullPrecision +where + T: VectorRepr, +{ + type Target = FullPrecisionStore; + fn create( + self, + max_points: usize, + metric: Metric, + prefetch_lookahead: Option, + ) -> Self::Target { + FullPrecisionStore::new( + max_points, + self.dim, + metric, + self.prefetch_cache_line_level, + prefetch_lookahead, + ) + } +} + +//////////////// +// SetElement // +//////////////// + +impl SetElementHelper for FullPrecisionStore +where + T: VectorRepr, +{ + /// Set the element at the given index. + fn set_element(&self, id: &u32, element: &[T]) -> Result<(), ANNError> { + unsafe { self.set_vector_sync(id.into_usize(), element) } + } +} + +////////////////// +// FullAccessor // +////////////////// + +/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. +/// +/// This type implements the following traits: +/// +/// * [`Accessor`] for the [`DefaultProvider`]. +/// * [`ComputerAccessor`] for comparing full-precision distances. +/// * [`BuildQueryComputer`]. +pub struct FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, +{ + /// The host provider. + provider: &'a FullPrecisionProvider, + + /// A buffer for resolving iterators given during bulk operations. + /// + /// The accessor reuses this allocation to amortize allocation cost over multiple bulk + /// operations. + id_buffer: Vec, +} + +impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Repr = T; + fn as_full_precision(&self) -> &FullPrecisionStore { + &self.provider.base_vectors + } +} + +impl HasId for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Id = u32; +} + +impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + fn starting_points(&self) -> impl Future>> { + std::future::ready(self.provider.starting_points()) + } +} + +impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + pub fn new(provider: &'a FullPrecisionProvider) -> Self { + Self { + provider, + id_buffer: Vec::new(), + } + } +} + +impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Delegate = &'a SimpleNeighborProviderAsync; + + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.provider.neighbors() + } +} + +impl<'a, T, Q, D, Ctx> Accessor for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + /// The extended element inherets the lifetime of the Accessor. + type Extended = &'a [T]; + + /// This accessor returns raw slices. There *is* a chance of racing when the fast + /// providers are used. We just have to live with it. + /// + /// NOTE: We intentionally don't use `'b` here since our implementation borrows + /// the inner `Opaque` from the underlying provider. + type Element<'b> + = &'a [T] + where + Self: 'b; + + /// `ElementRef` has an arbitrarily short lifetime. + type ElementRef<'b> = &'b [T]; + + /// Choose to panic on an out-of-bounds access rather than propagate an error. + type GetError = Panics; + + /// Return the full-precision vector stored at index `i`. + /// + /// This function always completes synchronously. + #[inline(always)] + fn get_element( + &mut self, + id: Self::Id, + ) -> impl Future, Self::GetError>> + Send { + // SAFETY: We've decided to live with UB (undefined behavior) that can result from + // potentially mixing unsynchronized reads and writes on the underlying memory. + std::future::ready(Ok(unsafe { + self.provider.base_vectors.get_vector_sync(id.into_usize()) + })) + } + + /// Perform a bulk operation. + /// + /// This implementation uses prefetching. + fn on_elements_unordered( + &mut self, + itr: Itr, + mut f: F, + ) -> impl Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + // Reuse the internal buffer to collect the results and give us random access + // capabilities. + let id_buffer = &mut self.id_buffer; + id_buffer.clear(); + id_buffer.extend(itr); + + let len = id_buffer.len(); + let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.base_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .base_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + // Invoke the passed closure on the full-precision vector. + // + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + f( + unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, + *id, + ) + } + + std::future::ready(Ok(())) + } +} + +impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputerError = Panics; + type DistanceComputer = T::Distance; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(T::distance( + self.provider.metric, + Some(self.provider.base_vectors.dim()), + )) + } +} + +impl BuildQueryComputer<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; + + fn build_query_computer( + &self, + from: &[T], + ) -> Result { + Ok(T::query_distance(from, self.provider.metric)) + } +} + +impl ExpandBeam<[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + +impl FillSet for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + async fn fill_set( + &mut self, + set: &mut HashMap, + itr: Itr, + ) -> Result<(), Self::GetError> + where + Itr: Iterator + Send + Sync, + { + for i in itr { + set.entry(i).or_insert_with(|| unsafe { + self.provider.base_vectors.get_vector_sync(i.into_usize()) + }); + } + Ok(()) + } +} + +//-------------------// +// In-mem Extensions // +//-------------------// + +impl<'a, T, Q, D, Ctx> AsDeletionCheck for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type Checker = D; + fn as_deletion_check(&self) -> &D { + &self.provider.deleted + } +} + +////////////////// +// Post Process // +////////////////// + +pub trait GetFullPrecision { + type Repr: VectorRepr; + fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync>; +} + +/// A [`SearchPostProcess`]or that: +/// +/// 1. Filters out deleted ids from being returned. +/// 2. Reranks a candidate stream using full-precision distances. +/// 3. Copies back the results to the output buffer. +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; + +impl glue::SearchPostProcess for Rerank +where + T: VectorRepr, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, +{ + type Error = Panics; + + fn post_process( + &self, + accessor: &mut A, + query: &[T], + _computer: &A::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator>, + B: SearchOutputBuffer + ?Sized, + { + let full = accessor.as_full_precision(); + let checker = accessor.as_deletion_check(); + let f = full.distance(); + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + None + } else { + Some(( + n.id, + f.evaluate_similarity(query, unsafe { + full.get_vector_sync(n.id.into_usize()) + }), + )) + } + }) + .collect(); + + // Sort the full precision distances. + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + // Store the reranked results. + std::future::ready(Ok(output.extend(reranked))) + } +} + +//////////////// +// Strategies // +//////////////// + +// A layered approach is used for search strategies. The `Internal` version does the heavy +// lifting in terms of establishing accessors and post processing. +// +// However, during post-processing, the `Internal` versions of strategies will not filter +// out the start points. The publicly exposed types *will* filter out the start points. +// +// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust +// the adjacency list for the start point to reuse the `Internal` strategies. + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> + for Internal +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = RemoveDeletedIdsAndCopy; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +/// Perform a search entirely in the full-precision space. +/// +/// Starting points are not filtered out of the final results. +impl SearchStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; + type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type SearchAccessorError = Panics; + type PostProcessor = glue::Pipeline; + + fn search_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::SearchAccessorError> { + Ok(FullAccessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + Default::default() + } +} + +// Pruning +impl PruneStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type DistanceComputer = T::Distance; + type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type PruneAccessorError = diskann::error::Infallible; + + fn prune_accessor<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + ) -> Result, Self::PruneAccessorError> { + Ok(FullAccessor::new(provider)) + } +} + +/// Implementing this trait allows `FullPrecision` to be used for multi-insert. +impl<'a, T, Q, D, Ctx> glue::AsElement<&'a [T]> for FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type Error = diskann::error::Infallible; + fn as_element( + &mut self, + vector: &'a [T], + _id: Self::Id, + ) -> impl Future, Self::Error>> + Send { + std::future::ready(Ok(vector)) + } +} + +impl InsertStrategy, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type PruneStrategy = Self; + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } +} + +// Inplace Delete // +impl InplaceDeleteStrategy> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type DeleteElementError = Panics; + type DeleteElement<'a> = [T]; + type DeleteElementGuard = Box<[T]>; + type PruneStrategy = Self; + type SearchStrategy = Internal; + fn search_strategy(&self) -> Self::SearchStrategy { + Internal(Self) + } + + fn prune_strategy(&self) -> Self::PruneStrategy { + Self + } + + async fn get_delete_element<'a>( + &'a self, + provider: &'a FullPrecisionProvider, + _context: &'a Ctx, + id: u32, + ) -> Result { + Ok(unsafe { provider.base_vectors.get_vector_sync(id.into_usize()) }.into()) + } +} From 4441dc735266292b0e7dcc35feee165cffdd9f6c Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 16:08:56 +0530 Subject: [PATCH 24/39] Fix formatting errors --- .../encoded_attribute_provider/encoded_filter_expr.rs | 9 +++++---- .../src/inline_beta_search/encoded_document_accessor.rs | 3 ++- .../src/inline_beta_search/inline_beta_filter.rs | 6 ++++-- diskann-label-filter/src/query.rs | 2 +- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index b621e347c..0ebdaf72f 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -20,12 +20,13 @@ pub(crate) struct EncodedFilterExpr { } impl EncodedFilterExpr { - pub fn try_create(ast_expr: &ASTExpr, attribute_map: Arc>) -> ANNResult { + pub fn try_create( + ast_expr: &ASTExpr, + attribute_map: Arc>, + ) -> ANNResult { let mut mapper = ASTLabelIdMapper::new(attribute_map); let ast_id_expr = ast_expr.accept(&mut mapper)?; - Ok(Self { - ast_id_expr, - }) + Ok(Self { ast_id_expr }) } pub(crate) fn encoded_filter_expr(&self) -> &ASTIdExpr { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 0b658fd4d..418526af7 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -220,7 +220,8 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = EncodedFilterExpr::try_create(from.filter_expr(), self.attribute_map.clone())?; + let id_query = + EncodedFilterExpr::try_create(from.filter_expr(), self.attribute_map.clone())?; Ok(InlineBetaComputer::new( inner_computer, diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 8d1784029..ae9a045d6 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -35,8 +35,10 @@ impl InlineBetaStrategy { } impl - SearchStrategy>, FilteredQuery<'_, Q>> - for InlineBetaStrategy + SearchStrategy< + DocumentProvider>, + FilteredQuery<'_, Q>, + > for InlineBetaStrategy where DP: DataProvider, Strategy: SearchStrategy, diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index d85406b5d..142867360 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -9,7 +9,7 @@ use crate::ASTExpr; /// The Readme.md file in the label-filter folder describes the format /// of the query expression. #[derive(Clone)] -pub struct FilteredQuery<'a, V : ?Sized> { +pub struct FilteredQuery<'a, V: ?Sized> { query: &'a V, filter_expr: ASTExpr, } From 5ed93b9ee4e50c5b50a6d07598660f36f8a2f5a5 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 16:12:56 +0530 Subject: [PATCH 25/39] Fix clippy warning --- diskann-label-filter/src/query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index 142867360..b1a7a6409 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -20,7 +20,7 @@ impl<'a, V: ?Sized> FilteredQuery<'a, V> { } pub(crate) fn query(&self) -> &'a V { - &self.query + self.query } pub(crate) fn filter_expr(&self) -> &ASTExpr { From 072edb5e48d85adae0757de4c6a77e0b5d9f038f Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 16:56:16 +0530 Subject: [PATCH 26/39] Fix build errors after merge with main - Use Knn instead of the old SearchParams --- .../src/backend/document_index/benchmark.rs | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 38da6bfc7..2d5b6a7ef 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -15,8 +15,8 @@ use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ - config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, - search_output_buffer, DiskANNIndex, SearchOutputBuffer, SearchParams, StartPointStrategy, + config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, search::Knn, + search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, }, provider::DefaultContext, }; @@ -428,6 +428,7 @@ impl<'a, T> DocumentIndexJob<'a, T> { Ok(stats) } } + /// Per-query output from [`FilteredSearcher::search`]. struct FilteredSearchOutput { distances: Vec, @@ -464,22 +465,20 @@ where T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; - type Parameters = SearchParams; + type Parameters = Knn; type Output = FilteredSearchOutput; fn num_queries(&self) -> usize { self.queries.nrows() } - fn id_count(&self, parameters: &SearchParams) -> search_api::IdCount { - search_api::IdCount::Fixed( - NonZeroUsize::new(parameters.k_value).unwrap_or(diskann::utils::ONE), - ) + fn id_count(&self, parameters: &Knn) -> search_api::IdCount { + search_api::IdCount::Fixed(parameters.k_value()) } async fn search( &self, - parameters: &SearchParams, + parameters: &Knn, buffer: &mut O, index: usize, ) -> diskann::ANNResult @@ -494,14 +493,14 @@ where // Use a concrete IdDistance scratch buffer so that both the IDs and distances // are captured. Afterwards, the valid IDs are forwarded into the framework buffer. - let k = parameters.k_value; + let k = parameters.k_value().get(); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; let mut scratch = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let stats = self + let stats = &self .index - .search(&strategy, &ctx, &filtered_query, parameters, &mut scratch) + .search(*parameters, &strategy, &ctx, &filtered_query, &mut scratch) .await?; let count = scratch.current_len(); @@ -526,18 +525,16 @@ struct FilteredSearchAggregator<'a> { recall_k: usize, } -impl search_api::Aggregate - for FilteredSearchAggregator<'_> -{ +impl search_api::Aggregate for FilteredSearchAggregator<'_> { type Output = SearchRunStats; fn aggregate( &mut self, - run: search_api::Run, + run: search_api::Run, results: Vec>, ) -> anyhow::Result { let parameters = run.parameters(); - let search_n = parameters.k_value; + let search_n = parameters.k_value().get(); let num_queries = results.first().map(|r| r.len()).unwrap_or(0); // Recall from first rep only. @@ -635,7 +632,7 @@ impl search_api::Aggregate num_threads: run.setup().threads.get(), num_queries, search_n, - search_l: parameters.l_value, + search_l: parameters.l_value().get(), recall: recall_metrics, qps, wall_clock_time: rep_latencies, @@ -682,7 +679,7 @@ where beta, }); - let parameters = SearchParams::new_default(search_n, search_l)?; + let parameters = Knn::new_default(search_n, search_l)?; let setup = search_api::Setup { threads: num_threads, tasks: num_threads, From 6a31782bdb86e97b2e68637d422bbc9b864aa697 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 16 Mar 2026 18:52:08 +0530 Subject: [PATCH 27/39] Undo rename of RecallMetrics --- diskann-benchmark/src/backend/document_index/benchmark.rs | 6 +++--- diskann-benchmark/src/backend/index/result.rs | 2 +- diskann-benchmark/src/utils/recall.rs | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 2d5b6a7ef..d954f1f47 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -58,7 +58,7 @@ use crate::{ inputs::document_index::DocumentIndexBuild, utils::{ datafiles::{self, BinFile}, - recall::SerializableRecallMetrics, + recall::RecallMetrics, }, }; @@ -538,7 +538,7 @@ impl search_api::Aggregate for FilteredSearchAgg let num_queries = results.first().map(|r| r.len()).unwrap_or(0); // Recall from first rep only. - let recall_metrics: SerializableRecallMetrics = match results.first() { + let recall_metrics: RecallMetrics = match results.first() { Some(first) => (&recall::knn( self.groundtruth, None, @@ -751,7 +751,7 @@ pub struct SearchRunStats { pub num_queries: usize, pub search_n: usize, pub search_l: usize, - pub recall: SerializableRecallMetrics, + pub recall: RecallMetrics, pub qps: Vec, pub wall_clock_time: Vec, pub mean_latency: f64, diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 429d4f060..1d6102f9b 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -116,7 +116,7 @@ pub(super) struct SearchResults { pub(super) mean_latencies: Vec, pub(super) p90_latencies: Vec, pub(super) p99_latencies: Vec, - pub(super) recall: utils::recall::SerializableRecallMetrics, + pub(super) recall: utils::recall::RecallMetrics, pub(super) mean_cmps: f32, pub(super) mean_hops: f32, } diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index 9628a6205..c0ed813cd 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -8,7 +8,7 @@ use serde::Serialize; #[derive(Debug, Clone, Serialize)] #[non_exhaustive] -pub(crate) struct SerializableRecallMetrics { +pub(crate) struct RecallMetrics { /// The `k` value for `k-recall-at-n`. pub(crate) recall_k: usize, /// The `n` value for `k-recall-at-n`. @@ -23,7 +23,7 @@ pub(crate) struct SerializableRecallMetrics { pub(crate) maximum: usize, } -impl From<&benchmark_core::recall::RecallMetrics> for SerializableRecallMetrics { +impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { fn from(m: &benchmark_core::recall::RecallMetrics) -> Self { Self { recall_k: m.recall_k, From a70ee53049601238d7ee61c19d3955e2deaadd59 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Wed, 18 Mar 2026 13:23:28 +0530 Subject: [PATCH 28/39] Remove fallback to unfiltered search. --- .../inline_beta_search/inline_beta_filter.rs | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index ae9a045d6..8d6392381 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -109,20 +109,10 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - match self.filter_expr.encoded_filter_expr().accept(&pred_eval) { - Ok(matched) => { - if matched { - sim * self.beta_value - } else { - sim - } - } - Err(_) => { - //If predicate evaluation fails for any reason, we simply revert - //to unfiltered search. - tracing::warn!("Predicate evaluation failed"); - sim - } + if self.filter_expr.encoded_filter_expr().accept(&pred_eval).expect("Expected predicate evaluation to not error out!") { + sim * self.beta_value + } else { + sim } } } From a313013cdeb50cdb9806563b9321bf8ff5f2581d Mon Sep 17 00:00:00 2001 From: sampathrg Date: Wed, 18 Mar 2026 13:48:02 +0530 Subject: [PATCH 29/39] Fix formatting error --- .../src/inline_beta_search/inline_beta_filter.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 8d6392381..68dcad630 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -109,7 +109,12 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - if self.filter_expr.encoded_filter_expr().accept(&pred_eval).expect("Expected predicate evaluation to not error out!") { + if self + .filter_expr + .encoded_filter_expr() + .accept(&pred_eval) + .expect("Expected predicate evaluation to not error out!") + { sim * self.beta_value } else { sim From 7ea6e4ef08d699a75062290348f5aa4a475e8101 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 12:53:36 +0530 Subject: [PATCH 30/39] Address review comments --- .../example/document-filter.json | 4 + .../src/backend/document_index/benchmark.rs | 162 +++++------------- diskann-benchmark/src/backend/index/build.rs | 4 +- diskann-benchmark/src/backend/index/mod.rs | 2 +- .../src/inputs/document_index.rs | 31 ++-- .../document_insert_strategy.rs | 2 +- 6 files changed, 73 insertions(+), 132 deletions(-) diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json index d60cd4806..0bce1572d 100644 --- a/diskann-benchmark/example/document-filter.json +++ b/diskann-benchmark/example/document-filter.json @@ -21,6 +21,10 @@ "query_predicates": "query.10.label.jsonl", "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", "beta": 0.5, + "reps": 5, + "num_threads": [ + 1 + ], "runs": [ { "search_n": 20, diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index d954f1f47..6cd1ec1fa 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -15,20 +15,21 @@ use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ - config::Builder as ConfigBuilder, config::MaxDegree, config::PruneKind, search::Knn, + search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, }, provider::DefaultContext, + ANNError, ANNErrorKind, }; use diskann_benchmark_core::{ - build::{self, AsProgress, Build, Parallelism, Progress}, + build::{self, Build, Parallelism}, recall, search as search_api, tokio, }; use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, registry::Benchmarks, - utils::{datatype, percentiles, MicroSeconds}, + utils::{datatype, fmt, percentiles, MicroSeconds}, Any, }; use diskann_label_filter::{ @@ -40,8 +41,11 @@ use diskann_label_filter::{ }, inline_beta_search::inline_beta_filter::InlineBetaStrategy, query::FilteredQuery, - read_and_parse_queries, read_baselabels, ASTExpr, + read_and_parse_queries, read_baselabels, + traits::attribute_store::AttributeStore, + ASTExpr, }; + use diskann_providers::model::graph::provider::async_::{ common::{self, NoStore, TableBasedDeletes}, inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, @@ -51,10 +55,10 @@ use diskann_utils::views::MatrixView; use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; use diskann_vector::distance::SquaredL2; use diskann_vector::PureDistanceFunction; -use indicatif::{ProgressBar, ProgressStyle}; use serde::Serialize; use crate::{ + backend::index::build::ProgressMeter, inputs::document_index::DocumentIndexBuild, utils::{ datafiles::{self, BinFile}, @@ -100,24 +104,12 @@ where type Error = std::convert::Infallible; fn try_match(_from: &&'a DocumentIndexBuild) -> Result { - match _from.build.data_type { - datatype::DataType::Float32 => Ok(MatchScore(0)), - datatype::DataType::UInt8 => Ok(MatchScore(0)), - datatype::DataType::Int8 => Ok(MatchScore(0)), - _ => Err(datatype::MATCH_FAIL), - } + datatype::Type::::try_match(&_from.build.data_type) } fn convert(from: &'a DocumentIndexBuild) -> Result { Ok(DocumentIndexJob::new(from)) } - - fn description( - f: &mut std::fmt::Formatter<'_>, - _from: Option<&&'a DocumentIndexBuild>, - ) -> std::fmt::Result { - writeln!(f, "tag: \"{}\"", DocumentIndexBuild::tag()) - } } // Central dispatch mapping from Any @@ -236,23 +228,14 @@ impl<'a, T> DocumentIndexJob<'a, T> { .collect(); // 3. Create the index configuration - let metric = build.distance.into(); - let prune_kind = PruneKind::from_metric(metric); - let mut config_builder = ConfigBuilder::new( - build.max_degree, // pruned_degree - MaxDegree::Same, // max_degree - build.l_build, // l_build - prune_kind, // prune_kind - ); - config_builder.alpha(build.alpha); - let config = config_builder.build()?; + let config = build.build_config()?; // 4. Create the data provider directly writeln!(output, "Creating index...")?; let params = DefaultProviderParameters { max_points: num_vectors, frozen_points: diskann::utils::ONE, - metric, + metric: build.distance.into(), dim, prefetch_lookahead: None, prefetch_cache_line_level: None, @@ -278,7 +261,6 @@ impl<'a, T> DocumentIndexJob<'a, T> { let medoid_idx = compute_medoid_index(&data)?; let start_point_id = num_vectors as u32; // Start points begin at max_points let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); - use diskann_label_filter::traits::attribute_store::AttributeStore; attribute_store.set_element(&start_point_id, &medoid_attrs)?; let doc_provider = DocumentProvider::new(inner_provider, attribute_store); @@ -306,15 +288,8 @@ impl<'a, T> DocumentIndexJob<'a, T> { ); let num_tasks = NonZeroUsize::new(build.num_threads).unwrap_or(diskann::utils::ONE); let parallelism = Parallelism::dynamic(diskann::utils::ONE, num_tasks); - let progress = IndicatifAsProgress({ - let bar = ProgressBar::with_draw_target(Some(num_vectors as u64), output.draw_target()); - bar.set_style( - ProgressStyle::with_template("Building [{elapsed_precise}] {wide_bar} {percent}") - .expect("valid template"), - ); - bar - }); - let build_results = build::build_tracked(builder, parallelism, &rt, Some(&progress))?; + let build_results = + build::build_tracked(builder, parallelism, &rt, Some(&ProgressMeter::new(output)))?; let insert_latencies: Vec = build_results .take_output() .into_iter() @@ -416,11 +391,6 @@ impl<'a, T> DocumentIndexJob<'a, T> { label_load_time, build_time, insert_latencies: insert_percentiles, - build_params: BuildParamsStats { - max_degree: build.max_degree, - l_build: build.l_build, - alpha: build.alpha, - }, search: search_results, }; @@ -700,12 +670,6 @@ where .pop() .ok_or_else(|| anyhow::anyhow!("no search results")) } -#[derive(Debug, Serialize)] -pub struct BuildParamsStats { - pub max_degree: usize, - pub l_build: usize, - pub alpha: f32, -} /// Helper module for serializing arrays as compact single-line JSON strings mod compact_array { @@ -772,7 +736,6 @@ pub struct DocumentIndexStats { pub label_load_time: MicroSeconds, pub build_time: MicroSeconds, pub insert_latencies: percentiles::Percentiles, - pub build_params: BuildParamsStats, pub search: Vec, } @@ -797,16 +760,9 @@ impl std::fmt::Display for DocumentIndexStats { writeln!(f, " P50: {} us", self.insert_latencies.median)?; writeln!(f, " P90: {} us", self.insert_latencies.p90)?; writeln!(f, " P99: {} us", self.insert_latencies.p99)?; - writeln!(f, " Build Parameters:")?; - writeln!(f, " max_degree (R): {}", self.build_params.max_degree)?; - writeln!(f, " l_build (L): {}", self.build_params.l_build)?; - writeln!(f, " alpha: {}", self.build_params.alpha)?; if !self.search.is_empty() { - writeln!(f, "\nFiltered Search Results:")?; - writeln!( - f, - " {:>8} {:>8} {:>10} {:>10} {:>15} {:>12} {:>12} {:>10} {:>8} {:>10} {:>12}", + let header = [ "L", "KNN", "Avg Cmps", @@ -817,41 +773,34 @@ impl std::fmt::Display for DocumentIndexStats { "Recall", "Threads", "Queries", - "WallClock(s)" - )?; - for s in &self.search { - let mean_qps = if s.qps.is_empty() { - 0.0 - } else { - s.qps.iter().sum::() / s.qps.len() as f64 - }; + "WallClock(s)", + ]; + writeln!(f, "\nFiltered Search Results:")?; + let mut table = fmt::Table::new(header, self.search.len()); + self.search.iter().enumerate().for_each(|(row_idx, s)| { + let mut row = table.row(row_idx); + let mean_qps = percentiles::mean(&s.qps).unwrap_or(0.0); let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max); - let mean_wall_clock = if s.wall_clock_time.is_empty() { - 0.0 - } else { - s.wall_clock_time + let mean_wall_clock = percentiles::mean( + &s.wall_clock_time .iter() - .map(|t| t.as_seconds()) - .sum::() - / s.wall_clock_time.len() as f64 - }; - writeln!( - f, - " {:>8} {:>8} {:>10.1} {:>10.1} {:>7.1}({:>5.1}) {:>12.1} {:>12} {:>10.4} {:>8} {:>10} {:>12.3}", - s.search_l, - s.search_n, - s.mean_cmps, - s.mean_hops, - mean_qps, - max_qps, - s.mean_latency, - s.p99_latency, - s.recall.average, - s.num_threads, - s.num_queries, - mean_wall_clock - )?; - } + .map(|l| l.as_seconds()) + .collect::>(), + ) + .unwrap_or(0.0); + row.insert(s.search_l, 0); + row.insert(s.search_n, 1); + row.insert(format!("{:.1}", s.mean_cmps), 2); + row.insert(format!("{:.1}", s.mean_hops), 3); + row.insert(format!("{:.1}({:.1})", mean_qps, max_qps), 4); + row.insert(format!("{:.1} s", s.mean_latency), 5); + row.insert(format!("{:.1} s", s.p99_latency), 6); + row.insert(format!("{:.4}", s.recall.average), 7); + row.insert(s.num_threads, 8); + row.insert(s.num_queries, 9); + row.insert(format!("{:.3} s", mean_wall_clock), 10); + }); + write!(f, "{}", table)?; } Ok(()) } @@ -906,7 +855,12 @@ where async fn build(&self, range: std::ops::Range) -> diskann::ANNResult { let ctx = DefaultContext; for i in range { - let attrs = self.attributes.get(i).cloned().unwrap_or_default(); + let attrs = self.attributes.get(i).cloned().ok_or_else(|| { + ANNError::message( + ANNErrorKind::Opaque, + format!("Failed to get attributes at index {}", i), + ) + })?; let doc = Document::new(self.data.row(i), attrs); self.index .insert(self.strategy, &ctx, &(i as u32), &doc) @@ -915,25 +869,3 @@ where Ok(()) } } - -/// Adapts an already-constructed [`ProgressBar`] into the [`AsProgress`] / [`Progress`] -/// traits expected by [`build_tracked`]. -struct IndicatifAsProgress(ProgressBar); - -struct IndicatifProgress(ProgressBar); - -impl Progress for IndicatifProgress { - fn progress(&self, handled: usize) { - self.0.inc(handled as u64); - } - - fn finish(&self) { - self.0.finish(); - } -} - -impl AsProgress for IndicatifAsProgress { - fn as_progress(&self, _max: usize) -> Arc { - Arc::new(IndicatifProgress(self.0.clone())) - } -} diff --git a/diskann-benchmark/src/backend/index/build.rs b/diskann-benchmark/src/backend/index/build.rs index ef6284d2a..b674bf5e8 100644 --- a/diskann-benchmark/src/backend/index/build.rs +++ b/diskann-benchmark/src/backend/index/build.rs @@ -213,12 +213,12 @@ impl std::fmt::Display for BuildStats { } } -pub struct ProgressMeter<'a> { +pub(crate) struct ProgressMeter<'a> { output: &'a mut dyn Output, } impl<'a> ProgressMeter<'a> { - pub fn new(output: &'a mut dyn Output) -> Self { + pub(crate) fn new(output: &'a mut dyn Output) -> Self { Self { output } } } diff --git a/diskann-benchmark/src/backend/index/mod.rs b/diskann-benchmark/src/backend/index/mod.rs index 269887c6d..07ed0ccb8 100644 --- a/diskann-benchmark/src/backend/index/mod.rs +++ b/diskann-benchmark/src/backend/index/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -mod build; +pub(crate) mod build; mod search; mod streaming; diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index f1d2d7c67..4d3e72235 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -8,6 +8,7 @@ use std::num::NonZeroUsize; use anyhow::Context; +use diskann::graph::{Config, config::{Builder, MaxDegree, PruneKind, ConfigError}}; use diskann_benchmark_runner::{ files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, }; @@ -42,21 +43,27 @@ pub(crate) struct DocumentBuildParams { pub(crate) num_threads: usize, } +impl DocumentBuildParams { + pub(crate) fn build_config(&self) -> Result { + let metric = self.distance.into(); + let prune_kind = PruneKind::from_metric(metric); + let mut config_builder = Builder::new( + self.max_degree, // pruned_degree + MaxDegree::default_slack(), // max_degree + self.l_build, + prune_kind, + ); + config_builder.alpha(self.alpha); + let config = config_builder.build()?; + Ok(config) + } +} + impl CheckDeserialization for DocumentBuildParams { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.data.check_deserialization(checker)?; self.data_labels.check_deserialization(checker)?; - - // checking if the max_degree, l_build and alpha values are valid. - use diskann::graph::config::{Builder, MaxDegree, PruneKind}; - let mut builder = Builder::new( - self.max_degree, - MaxDegree::Value(self.max_degree), - self.l_build, - PruneKind::Occluding, - ); - builder.alpha(self.alpha); - builder.build()?; + self.build_config()?; Ok(()) } } @@ -67,9 +74,7 @@ pub(crate) struct DocumentSearchParams { pub(crate) query_predicates: InputFile, pub(crate) groundtruth: InputFile, pub(crate) beta: f32, - #[serde(default = "default_reps")] pub(crate) reps: NonZeroUsize, - #[serde(default = "default_thread_counts")] pub(crate) num_threads: Vec, pub(crate) runs: Vec, } diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 9a3bad9a0..6270af72e 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -233,7 +233,7 @@ where fn prune_accessor<'a>( &'a self, provider: &'a DocumentProvider>, - context: &'a > as DataProvider>::Context, + context: &'a DP::Context, ) -> Result, Self::PruneAccessorError> { self.inner .prune_accessor(provider.inner_provider(), context) From 8010d374ae211136a9aa0b7ab1da913502694054 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:27:33 +0530 Subject: [PATCH 31/39] Formatting + revert to old names for some functions --- diskann-benchmark/src/inputs/document_index.rs | 9 ++++++--- diskann-benchmark/src/utils/recall.rs | 1 + .../encoded_attribute_provider/encoded_filter_expr.rs | 2 +- .../src/inline_beta_search/encoded_document_accessor.rs | 3 +-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs index 4d3e72235..f1ed3c063 100644 --- a/diskann-benchmark/src/inputs/document_index.rs +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -8,7 +8,10 @@ use std::num::NonZeroUsize; use anyhow::Context; -use diskann::graph::{Config, config::{Builder, MaxDegree, PruneKind, ConfigError}}; +use diskann::graph::{ + config::{Builder, ConfigError, MaxDegree, PruneKind}, + Config, +}; use diskann_benchmark_runner::{ files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, }; @@ -48,8 +51,8 @@ impl DocumentBuildParams { let metric = self.distance.into(); let prune_kind = PruneKind::from_metric(metric); let mut config_builder = Builder::new( - self.max_degree, // pruned_degree - MaxDegree::default_slack(), // max_degree + self.max_degree, // pruned_degree + MaxDegree::default_slack(), // max_degree self.l_build, prune_kind, ); diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index c0ed813cd..dcbe86d94 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -2,6 +2,7 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ + use diskann_benchmark_core as benchmark_core; use serde::Serialize; diff --git a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs index 0ebdaf72f..d56cb13c1 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/encoded_filter_expr.rs @@ -20,7 +20,7 @@ pub(crate) struct EncodedFilterExpr { } impl EncodedFilterExpr { - pub fn try_create( + pub fn new( ast_expr: &ASTExpr, attribute_map: Arc>, ) -> ANNResult { diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 418526af7..ab82dad56 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -220,8 +220,7 @@ where .inner_accessor .build_query_computer(from.query()) .into_ann_result()?; - let id_query = - EncodedFilterExpr::try_create(from.filter_expr(), self.attribute_map.clone())?; + let id_query = EncodedFilterExpr::new(from.filter_expr(), self.attribute_map.clone())?; Ok(InlineBetaComputer::new( inner_computer, From 8a86e0d1e5507802d3f672f1d1bfe2c724c1b3ea Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:31:37 +0530 Subject: [PATCH 32/39] Add some unit tests + smoke test for benchmark --- diskann-benchmark/src/main.rs | 33 +++ diskann-label-filter/Cargo.toml | 2 + .../document_insert_strategy.rs | 171 ++++++++++--- .../inline_beta_search/inline_beta_filter.rs | 227 ++++++++++++++++++ 4 files changed, 404 insertions(+), 29 deletions(-) diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index b3de5901e..7bc382782 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -779,4 +779,37 @@ mod tests { let mut output = Memory::new(); cli.check_target(&mut output).unwrap(); } + + #[test] + fn document_filter_integration() { + let input_path = example_directory().join("document-filter.json"); + + let tempdir = tempfile::tempdir().unwrap(); + let output_path = tempdir.path().join("output.json"); + assert!(!output_path.exists()); + + let modified_input_path = tempdir.path().join("input.json"); + + let mut raw = value_from_file(&input_path); + prefix_search_directories(&mut raw, &root_directory()); + save_to_file(&modified_input_path, &raw); + + let command = Commands::Run { + input_file: modified_input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + + cli.run(&mut output).unwrap(); + + let output = String::from_utf8(output.into_inner()).unwrap(); + println!("output = {}", output); + // Check that the results file is generated. + assert!(output_path.exists()); + + let results: Vec = load_from_file(&output_path); + assert_eq!(results.len(), num_jobs(&raw)); + } } diff --git a/diskann-label-filter/Cargo.toml b/diskann-label-filter/Cargo.toml index 98fe3879c..b204dae23 100644 --- a/diskann-label-filter/Cargo.toml +++ b/diskann-label-filter/Cargo.toml @@ -33,6 +33,8 @@ tempfile.workspace = true anyhow.workspace = true futures-util.workspace = true tracing.workspace = true +diskann = { workspace = true, features = ["testing"] } +tokio = { workspace = true, features = ["rt"] } diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 6270af72e..15f40e5fe 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -6,8 +6,6 @@ //! A strategy wrapper that enables insertion of [Document] objects into a //! [DiskANNIndex] using a [DocumentProvider]. -use std::marker::PhantomData; - use diskann::{ graph::glue::{self, ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, @@ -19,28 +17,23 @@ use crate::document::Document; use crate::encoded_attribute_provider::roaring_attribute_store::RoaringAttributeStore; /// A strategy wrapper that enables insertion of [Document] objects. -pub struct DocumentInsertStrategy { +pub struct DocumentInsertStrategy { inner: Inner, - _phantom: PhantomData VT>, } -impl Clone for DocumentInsertStrategy { +impl Clone for DocumentInsertStrategy { fn clone(&self) -> Self { Self { inner: self.inner.clone(), - _phantom: PhantomData, } } } -impl Copy for DocumentInsertStrategy {} +impl Copy for DocumentInsertStrategy {} -impl DocumentInsertStrategy { +impl DocumentInsertStrategy { pub fn new(inner: Inner) -> Self { - Self { - inner, - _phantom: PhantomData, - } + Self { inner } } pub fn inner(&self) -> &Inner { @@ -49,32 +42,30 @@ impl DocumentInsertStrategy { } /// Wrapper accessor for Document queries -pub struct DocumentSearchAccessor { +pub struct DocumentSearchAccessor { inner: Inner, - _phantom: PhantomData VT>, + // _phantom: PhantomData VT>, } -impl DocumentSearchAccessor { +impl DocumentSearchAccessor { pub fn new(inner: Inner) -> Self { Self { inner, - _phantom: PhantomData, + // _phantom: PhantomData, } } } -impl HasId for DocumentSearchAccessor +impl HasId for DocumentSearchAccessor where Inner: HasId, - VT: ?Sized, { type Id = Inner::Id; } -impl Accessor for DocumentSearchAccessor +impl Accessor for DocumentSearchAccessor where Inner: Accessor, - VT: ?Sized, { type ElementRef<'a> = Inner::ElementRef<'a>; type Element<'a> @@ -105,7 +96,7 @@ where } } -impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor +impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor where Inner: BuildQueryComputer, VT: ?Sized, @@ -121,10 +112,9 @@ where } } -impl<'this, Inner, VT> DelegateNeighbor<'this> for DocumentSearchAccessor +impl<'this, Inner> DelegateNeighbor<'this> for DocumentSearchAccessor where Inner: DelegateNeighbor<'this>, - VT: ?Sized, { type Delegate = Inner::Delegate; fn delegate_neighbor(&'this mut self) -> Self::Delegate { @@ -132,17 +122,16 @@ where } } -impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor +impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor where Inner: ExpandBeam, VT: ?Sized, { } -impl SearchExt for DocumentSearchAccessor +impl SearchExt for DocumentSearchAccessor where Inner: SearchExt, - VT: ?Sized, { fn starting_points( &self, @@ -156,7 +145,7 @@ where impl<'doc, Inner, DP, VT> SearchStrategy>, Document<'doc, VT>> - for DocumentInsertStrategy + for DocumentInsertStrategy where Inner: InsertStrategy, DP: DataProvider, @@ -165,7 +154,7 @@ where type QueryComputer = Inner::QueryComputer; type PostProcessor = glue::CopyIds; type SearchAccessorError = Inner::SearchAccessorError; - type SearchAccessor<'a> = DocumentSearchAccessor, VT>; + type SearchAccessor<'a> = DocumentSearchAccessor>; fn search_accessor<'a>( &'a self, @@ -185,7 +174,7 @@ where impl<'doc, Inner, DP, VT> InsertStrategy>, Document<'doc, VT>> - for DocumentInsertStrategy + for DocumentInsertStrategy where Inner: InsertStrategy, DP: DataProvider, @@ -239,3 +228,127 @@ where .prune_accessor(provider.inner_provider(), context) } } + +#[cfg(test)] +mod tests { + use super::{DocumentInsertStrategy, DocumentPruneStrategy, DocumentSearchAccessor}; + use diskann::{ + graph::{ + glue::{InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + provider::BuildQueryComputer, + }; + use diskann_vector::distance::Metric; + + use crate::{ + document::Document, + encoded_attribute_provider::{ + document_provider::DocumentProvider, roaring_attribute_store::RoaringAttributeStore, + }, + }; + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /// Build a minimal test provider with a single start point and three dimensions. + fn make_test_provider() -> Provider { + let config = Config::new( + Metric::L2, + 10, + StartPoint::new(u32::MAX, vec![1.0f32, 2.0, 0.0]), + ) + .expect("test provider config should be valid"); + Provider::new(config) + } + + fn make_doc_provider( + provider: Provider, + ) -> DocumentProvider> { + DocumentProvider::new(provider, RoaringAttributeStore::new()) + } + + /// `search_accessor` successfully creates a `DocumentSearchAccessor` wrapping the + /// inner accessor. + #[test] + fn test_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as SearchStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); + } + + #[test] + fn test_insert_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as InsertStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::insert_search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); + } + + #[test] + fn test_prune_accessor_delegates_to_inner_provider() { + let doc_prune_strategy = DocumentPruneStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as PruneStrategy< + DocumentProvider>, + >>::prune_accessor(&doc_prune_strategy, &provider, &context); + + assert!(result.is_ok()); + } + + #[test] + fn test_build_query_computer_extracts_vector_from_document() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let doc_accessor = DocumentSearchAccessor::new(inner_accessor); + + let vector = vec![1.0f32, 2.0, 0.0]; + let doc = Document::new(vector.as_slice(), vec![]); + + let result = as BuildQueryComputer< + Document<'_, [f32]>, + >>::build_query_computer(&doc_accessor, &doc); + + assert!( + result.is_ok(), + "build_query_computer should succeed for a valid vector" + ); + } + + #[test] + fn test_terminate_early_delegates_to_inner() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let mut inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let inner_terminate_early = inner_accessor.terminate_early(); + let mut doc_accessor = DocumentSearchAccessor::new(inner_accessor); + assert_eq!( + inner_terminate_early, + doc_accessor.terminate_early(), + "terminate_early should have same value as inner accessor" + ); + } +} diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 68dcad630..b1cfbca4b 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -175,3 +175,230 @@ where .map_err(|e| e.into()) } } + +#[cfg(test)] +mod tests { + use std::sync::{Arc, RwLock}; + + use diskann::{ + graph::{ + glue::{self, SearchPostProcess, SearchStrategy}, + search_output_buffer::IdDistance, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, SetElement}, + }; + use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; + use roaring::RoaringTreemap; + use serde_json::Value; + + use crate::{ + attribute::{Attribute, AttributeValue}, + document::EncodedDocument, + encoded_attribute_provider::{ + attribute_encoder::AttributeEncoder, encoded_filter_expr::EncodedFilterExpr, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::encoded_document_accessor::EncodedDocumentAccessor, + query::FilteredQuery, + traits::attribute_store::AttributeStore, + ASTExpr, CompareOp, + }; + + use super::{FilterResults, InlineBetaComputer}; + + // ----------------------------------------------------------------------- + // Stub inner distance computer + // ----------------------------------------------------------------------- + + /// Always returns a fixed constant distance, regardless of the vector value. + struct ConstComputer(f32); + + impl PreprocessedDistanceFunction<&[f32], f32> for ConstComputer { + fn evaluate_similarity(&self, _: &[f32]) -> f32 { + self.0 + } + } + + // ----------------------------------------------------------------------- + // Helper: build an AttributeEncoder + ASTExpr for `field == value`, + // returning (attr_map, ast_expr, encoded_id_of_that_attribute). + // ----------------------------------------------------------------------- + + fn setup_encoder_and_filter( + field: &str, + value: &str, + ) -> (Arc>, ASTExpr, u64) { + let mut encoder = AttributeEncoder::new(); + let attr = Attribute::from_value(field, AttributeValue::String(value.to_owned())); + let encoded_id = encoder.insert(&attr); + let attr_map = Arc::new(RwLock::new(encoder)); + let ast_expr = ASTExpr::Compare { + field: field.to_string(), + op: CompareOp::Eq(Value::String(value.to_string())), + }; + (attr_map, ast_expr, encoded_id) + } + + // ----------------------------------------------------------------------- + // Test 1: when the filter matches, evaluate_similarity returns inner * beta + // ----------------------------------------------------------------------- + + #[test] + fn test_evaluate_similarity_filter_match_scales_by_beta() { + let (attr_map, ast_expr, color_red_id) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Bitmap contains the encoded ID for "color=red" → predicate matches + let mut matching_map = RoaringTreemap::new(); + matching_map.insert(color_red_id); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &matching_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist * beta, + "a matched filter should multiply the inner similarity by beta" + ); + } + + // ----------------------------------------------------------------------- + // Test 2: when the filter does not match, evaluate_similarity is unchanged + // ----------------------------------------------------------------------- + + #[test] + fn test_evaluate_similarity_no_filter_match_preserves_score() { + let (attr_map, ast_expr, _) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Empty bitmap → no attribute matches the predicate + let empty_map = RoaringTreemap::new(); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &empty_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist, + "an unmatched filter should leave the inner similarity unchanged" + ); + } + + // ----------------------------------------------------------------------- + // Test 3: post_process forwards only filter-matching candidates to the + // inner post processor (and therefore to the output buffer). + // ----------------------------------------------------------------------- + + #[test] + fn test_post_process_only_passes_matching_candidates_to_inner() { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .expect("test tokio runtime"); + + // IDs 0 and 1 carry color=red (should pass the filter) + // IDs 2 and 3 carry color=blue (should be dropped by the filter) + let attr_store = RoaringAttributeStore::::new(); + let red = Attribute::from_value("color", AttributeValue::String("red".to_owned())); + let blue = Attribute::from_value("color", AttributeValue::String("blue".to_owned())); + for id in 0u32..2 { + attr_store + .set_element(&id, std::slice::from_ref(&red)) + .expect("set red attr"); + } + for id in 2u32..4 { + attr_store + .set_element(&id, std::slice::from_ref(&blue)) + .expect("set blue attr"); + } + + // The attribute_map is shared so EncodedFilterExpr sees the same encodings + // as those stored by the attribute store. + let attr_map = attr_store.attribute_map(); + + let ast_expr = ASTExpr::Compare { + field: "color".to_string(), + op: CompareOp::Eq(Value::String("red".to_string())), + }; + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map.clone()).expect("filter expr"); + + // Build the inner vector provider: start point at u32::MAX + 2-D zero vectors for 0..3 + let config = Config::new(Metric::L2, 10, StartPoint::new(u32::MAX, vec![1.0f32, 0.0])) + .expect("provider config"); + let inner_provider = Provider::new(config); + let ctx = Context::new(); + rt.block_on(async { + for id in 0u32..4 { + inner_provider + .set_element(&ctx, &id, &[0.0f32, 0.0] as &[f32]) + .await + .expect("add vector to inner provider"); + } + }); + + // Obtain the inner search accessor and derive an inner computer from it + let strategy = Strategy::new(); + let inner_accessor = strategy + .search_accessor(&inner_provider, &ctx) + .expect("inner accessor"); + let inner_computer = inner_accessor + .build_query_computer(&[0.0f32, 0.0][..]) + .expect("inner computer"); + + // Wrap accessor + attribute store into an EncodedDocumentAccessor + let attribute_accessor = attr_store.attribute_accessor().expect("attribute accessor"); + let mut doc_accessor = + EncodedDocumentAccessor::new(inner_accessor, attribute_accessor, attr_map, 2.0); + + let computer = InlineBetaComputer::new(inner_computer, 2.0, filter_expr); + + // Four candidates: 0 and 1 match (red); 2 and 3 do not (blue) + let candidates = [ + Neighbor::new(0u32, 1.0_f32), + Neighbor::new(1u32, 2.0_f32), + Neighbor::new(2u32, 3.0_f32), + Neighbor::new(3u32, 4.0_f32), + ]; + + let mut ids = [u32::MAX; 4]; + let mut distances = [f32::MAX; 4]; + let mut output = IdDistance::new(&mut ids, &mut distances); + + let query_vec = [0.0f32, 0.0]; + let filter_query = FilteredQuery::new(&query_vec[..], ast_expr); + + // CopyIds simply copies whatever it receives into the output buffer, + // so the output reflects exactly what FilterResults lets through. + let count = rt + .block_on( + FilterResults { + inner_post_processor: glue::CopyIds, + } + .post_process( + &mut doc_accessor, + &filter_query, + &computer, + candidates.into_iter(), + &mut output, + ), + ) + .expect("post_process"); + + // Only the two red-labeled candidates should have been forwarded + assert_eq!(count, 2, "exactly 2 of 4 candidates should pass the filter"); + let passed = &ids[..count]; + assert!( + passed.contains(&0), + "ID 0 (color=red) should pass the filter" + ); + assert!( + passed.contains(&1), + "ID 1 (color=red) should pass the filter" + ); + } +} From 5664dcaced77cdda5a3e619941496808886ec857 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:32:26 +0530 Subject: [PATCH 33/39] Update output serializer + remove unnecessary type parameter --- .../src/backend/document_index/benchmark.rs | 37 +++---------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index 6cd1ec1fa..a284e93db 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -15,8 +15,7 @@ use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ - search::Knn, - search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, + search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, }, provider::DefaultContext, ANNError, ANNErrorKind, @@ -671,40 +670,14 @@ where .ok_or_else(|| anyhow::anyhow!("no search results")) } -/// Helper module for serializing arrays as compact single-line JSON strings -mod compact_array { - use serde::Serializer; - - pub fn serialize_u32_vec(vec: &Vec, serializer: S) -> Result - where - S: Serializer, - { - // Serialize as a string containing the compact JSON array - let compact = serde_json::to_string(vec).unwrap_or_default(); - serializer.serialize_str(&compact) - } - - pub fn serialize_f32_vec(vec: &Vec, serializer: S) -> Result - where - S: Serializer, - { - // Serialize as a string containing the compact JSON array - let compact = serde_json::to_string(vec).unwrap_or_default(); - serializer.serialize_str(&compact) - } -} - /// Per-query detailed results for debugging/analysis #[derive(Debug, Serialize)] pub struct PerQueryDetails { pub query_id: usize, pub filter: String, pub recall: f64, - #[serde(serialize_with = "compact_array::serialize_u32_vec")] pub result_ids: Vec, - #[serde(serialize_with = "compact_array::serialize_f32_vec")] pub result_distances: Vec, - #[serde(serialize_with = "compact_array::serialize_u32_vec")] pub groundtruth_ids: Vec, } @@ -817,7 +790,7 @@ struct DocumentIndexBuilder { index: Arc>, data: Arc>, attributes: Arc>>, - strategy: DocumentInsertStrategy, + strategy: DocumentInsertStrategy, } impl DocumentIndexBuilder { @@ -825,7 +798,7 @@ impl DocumentIndexBuilder { index: Arc>, data: Arc>, attributes: Arc>>, - strategy: DocumentInsertStrategy, + strategy: DocumentInsertStrategy, ) -> Arc { Arc::new(Self { index, @@ -841,9 +814,9 @@ where DP: diskann::provider::DataProvider + for<'doc> diskann::provider::SetElement> + AsyncFriendly, - for<'doc> DocumentInsertStrategy: + for<'doc> DocumentInsertStrategy: diskann::graph::glue::InsertStrategy>, - DocumentInsertStrategy: AsyncFriendly, + DocumentInsertStrategy: AsyncFriendly, T: AsyncFriendly, { type Output = (); From b03bef40f315cff77d137413dec31d5e9604787c Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:51:18 +0530 Subject: [PATCH 34/39] Put the benchmarks and the smoke test behind a feature --- .github/workflows/ci.yml | 2 +- diskann-benchmark/Cargo.toml | 3 ++ .../src/backend/document_index/mod.rs | 23 ++++++++++- diskann-benchmark/src/main.rs | 40 +++++++++++++++++-- 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0a31f86cc..fb8da8552 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -293,7 +293,7 @@ jobs: --cargo-profile ci \ --config "$RUST_CONFIG" \ --features \ - virtual_storage,bf_tree,spherical-quantization,product-quantization,tracing,experimental_diversity_search + virtual_storage,bf_tree,spherical-quantization,product-quantization,tracing,experimental_diversity_search,document-index cargo test --locked --doc --workspace --profile ci --config "$RUST_CONFIG" diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index bebaf4b8e..39d64b1bb 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -63,6 +63,9 @@ scalar-quantization = [] # Enable minmax-quantization based algorithms minmax-quantization = [] +# Enable Document Index benchmarks +document-index = [] + # Enable Disk Index benchmarks disk-index = [ "diskann-disk/perf_test", diff --git a/diskann-benchmark/src/backend/document_index/mod.rs b/diskann-benchmark/src/backend/document_index/mod.rs index 9937590cc..022470578 100644 --- a/diskann-benchmark/src/backend/document_index/mod.rs +++ b/diskann-benchmark/src/backend/document_index/mod.rs @@ -8,6 +8,25 @@ //! This benchmark tests the DocumentInsertStrategy which enables inserting //! Document objects (vector + attributes) into a DiskANN index. -mod benchmark; +use diskann_benchmark_runner::registry::Benchmarks; -pub(crate) use benchmark::register_benchmarks; +cfg_if::cfg_if! { + if #[cfg(feature = "document-index")] { + mod benchmark; + + /// Register document index benchmarks when the `document-index` feature is enabled. + pub(crate) fn register_benchmarks(registry: &mut Benchmarks) { + benchmark::register_benchmarks(registry); + } + } else { + crate::utils::stub_impl!( + "document-index", + inputs::document_index::DocumentIndexBuild + ); + + /// Register a stub that guides users to enable the `document-index` feature. + pub(crate) fn register_benchmarks(registry: &mut Benchmarks) { + imp::register("document-index", registry); + } + } +} diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index 7bc382782..cdca116e3 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -794,8 +794,17 @@ mod tests { prefix_search_directories(&mut raw, &root_directory()); save_to_file(&modified_input_path, &raw); + run_document_filter_integration(&modified_input_path, &output_path, &raw); + } + + #[cfg(feature = "document-index")] + fn run_document_filter_integration( + input_path: &std::path::Path, + output_path: &std::path::Path, + raw: &serde_json::Value, + ) { let command = Commands::Run { - input_file: modified_input_path.to_owned(), + input_file: input_path.to_owned(), output_file: output_path.to_owned(), dry_run: false, }; @@ -809,7 +818,32 @@ mod tests { // Check that the results file is generated. assert!(output_path.exists()); - let results: Vec = load_from_file(&output_path); - assert_eq!(results.len(), num_jobs(&raw)); + let results: Vec = load_from_file(output_path); + assert_eq!(results.len(), num_jobs(raw)); + } + + #[cfg(not(feature = "document-index"))] + fn run_document_filter_integration( + input_path: &std::path::Path, + output_path: &std::path::Path, + _raw: &serde_json::Value, + ) { + let command = Commands::Run { + input_file: input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + + let err = cli.run(&mut output).unwrap_err(); + println!("err = {:?}", err); + + let output = String::from_utf8(output.into_inner()).unwrap(); + assert!(output.contains("\"document-index\" feature")); + println!("output = {}", output); + + // The output file should not have been created because we failed. + assert!(!output_path.exists()); } } From c1b7a3c8aa1af0188d389d9ce7a93acdf498be60 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Mon, 23 Mar 2026 18:57:10 +0530 Subject: [PATCH 35/39] changes to Cargo.lock --- Cargo.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.lock b/Cargo.lock index 6a06f4dfc..67a10f71b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -799,6 +799,7 @@ dependencies = [ "serde_json", "tempfile", "thiserror 2.0.17", + "tokio", "tracing", ] From 5ebe4770c4a202c65a464fa5f237c8e31769460b Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 26 Mar 2026 19:05:07 +0530 Subject: [PATCH 36/39] Move the tests to a separate folder as is the convention --- .../document_insert_strategy.rs | 123 --------- .../inline_beta_search/inline_beta_filter.rs | 233 +----------------- diskann-label-filter/src/lib.rs | 4 + .../tests/document_insert_strategy_test.rs | 126 ++++++++++ .../src/tests/inline_beta_filter_test.rs | 226 +++++++++++++++++ 5 files changed, 362 insertions(+), 350 deletions(-) create mode 100644 diskann-label-filter/src/tests/document_insert_strategy_test.rs create mode 100644 diskann-label-filter/src/tests/inline_beta_filter_test.rs diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 15f40e5fe..f26596147 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -229,126 +229,3 @@ where } } -#[cfg(test)] -mod tests { - use super::{DocumentInsertStrategy, DocumentPruneStrategy, DocumentSearchAccessor}; - use diskann::{ - graph::{ - glue::{InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, - test::provider::{Config, Context, Provider, StartPoint, Strategy}, - }, - provider::BuildQueryComputer, - }; - use diskann_vector::distance::Metric; - - use crate::{ - document::Document, - encoded_attribute_provider::{ - document_provider::DocumentProvider, roaring_attribute_store::RoaringAttributeStore, - }, - }; - - // --------------------------------------------------------------------------- - // Helpers - // --------------------------------------------------------------------------- - - /// Build a minimal test provider with a single start point and three dimensions. - fn make_test_provider() -> Provider { - let config = Config::new( - Metric::L2, - 10, - StartPoint::new(u32::MAX, vec![1.0f32, 2.0, 0.0]), - ) - .expect("test provider config should be valid"); - Provider::new(config) - } - - fn make_doc_provider( - provider: Provider, - ) -> DocumentProvider> { - DocumentProvider::new(provider, RoaringAttributeStore::new()) - } - - /// `search_accessor` successfully creates a `DocumentSearchAccessor` wrapping the - /// inner accessor. - #[test] - fn test_search_accessor_creates_wrapped_accessor() { - let strategy = DocumentInsertStrategy::new(Strategy::new()); - let provider = make_doc_provider(make_test_provider()); - let context = Context::new(); - - let result = as SearchStrategy< - DocumentProvider>, - Document<'_, [f32]>, - >>::search_accessor(&strategy, &provider, &context); - - assert!(result.is_ok()); - } - - #[test] - fn test_insert_search_accessor_creates_wrapped_accessor() { - let strategy = DocumentInsertStrategy::new(Strategy::new()); - let provider = make_doc_provider(make_test_provider()); - let context = Context::new(); - - let result = as InsertStrategy< - DocumentProvider>, - Document<'_, [f32]>, - >>::insert_search_accessor(&strategy, &provider, &context); - - assert!(result.is_ok()); - } - - #[test] - fn test_prune_accessor_delegates_to_inner_provider() { - let doc_prune_strategy = DocumentPruneStrategy::new(Strategy::new()); - let provider = make_doc_provider(make_test_provider()); - let context = Context::new(); - - let result = as PruneStrategy< - DocumentProvider>, - >>::prune_accessor(&doc_prune_strategy, &provider, &context); - - assert!(result.is_ok()); - } - - #[test] - fn test_build_query_computer_extracts_vector_from_document() { - let provider = make_test_provider(); - let context = Context::new(); - let strategy_inner = Strategy::new(); - let inner_accessor = strategy_inner - .search_accessor(&provider, &context) - .expect("creating search accessor should succeed"); - let doc_accessor = DocumentSearchAccessor::new(inner_accessor); - - let vector = vec![1.0f32, 2.0, 0.0]; - let doc = Document::new(vector.as_slice(), vec![]); - - let result = as BuildQueryComputer< - Document<'_, [f32]>, - >>::build_query_computer(&doc_accessor, &doc); - - assert!( - result.is_ok(), - "build_query_computer should succeed for a valid vector" - ); - } - - #[test] - fn test_terminate_early_delegates_to_inner() { - let provider = make_test_provider(); - let context = Context::new(); - let strategy_inner = Strategy::new(); - let mut inner_accessor = strategy_inner - .search_accessor(&provider, &context) - .expect("creating search accessor should succeed"); - let inner_terminate_early = inner_accessor.terminate_early(); - let mut doc_accessor = DocumentSearchAccessor::new(inner_accessor); - assert_eq!( - inner_terminate_early, - doc_accessor.terminate_early(), - "terminate_early should have same value as inner accessor" - ); - } -} diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index b1cfbca4b..33b7d7fc7 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -126,6 +126,12 @@ pub struct FilterResults { inner_post_processor: IPP, } +impl FilterResults { + pub(crate) fn new(inner_post_processor: IPP) -> Self { + Self { inner_post_processor } + } +} + impl<'a, Q, IA, IPP> SearchPostProcess, FilteredQuery<'a, Q>> for FilterResults where @@ -175,230 +181,3 @@ where .map_err(|e| e.into()) } } - -#[cfg(test)] -mod tests { - use std::sync::{Arc, RwLock}; - - use diskann::{ - graph::{ - glue::{self, SearchPostProcess, SearchStrategy}, - search_output_buffer::IdDistance, - test::provider::{Config, Context, Provider, StartPoint, Strategy}, - }, - neighbor::Neighbor, - provider::{BuildQueryComputer, SetElement}, - }; - use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; - use roaring::RoaringTreemap; - use serde_json::Value; - - use crate::{ - attribute::{Attribute, AttributeValue}, - document::EncodedDocument, - encoded_attribute_provider::{ - attribute_encoder::AttributeEncoder, encoded_filter_expr::EncodedFilterExpr, - roaring_attribute_store::RoaringAttributeStore, - }, - inline_beta_search::encoded_document_accessor::EncodedDocumentAccessor, - query::FilteredQuery, - traits::attribute_store::AttributeStore, - ASTExpr, CompareOp, - }; - - use super::{FilterResults, InlineBetaComputer}; - - // ----------------------------------------------------------------------- - // Stub inner distance computer - // ----------------------------------------------------------------------- - - /// Always returns a fixed constant distance, regardless of the vector value. - struct ConstComputer(f32); - - impl PreprocessedDistanceFunction<&[f32], f32> for ConstComputer { - fn evaluate_similarity(&self, _: &[f32]) -> f32 { - self.0 - } - } - - // ----------------------------------------------------------------------- - // Helper: build an AttributeEncoder + ASTExpr for `field == value`, - // returning (attr_map, ast_expr, encoded_id_of_that_attribute). - // ----------------------------------------------------------------------- - - fn setup_encoder_and_filter( - field: &str, - value: &str, - ) -> (Arc>, ASTExpr, u64) { - let mut encoder = AttributeEncoder::new(); - let attr = Attribute::from_value(field, AttributeValue::String(value.to_owned())); - let encoded_id = encoder.insert(&attr); - let attr_map = Arc::new(RwLock::new(encoder)); - let ast_expr = ASTExpr::Compare { - field: field.to_string(), - op: CompareOp::Eq(Value::String(value.to_string())), - }; - (attr_map, ast_expr, encoded_id) - } - - // ----------------------------------------------------------------------- - // Test 1: when the filter matches, evaluate_similarity returns inner * beta - // ----------------------------------------------------------------------- - - #[test] - fn test_evaluate_similarity_filter_match_scales_by_beta() { - let (attr_map, ast_expr, color_red_id) = setup_encoder_and_filter("color", "red"); - let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); - - let beta = 2.5_f32; - let inner_dist = 4.0_f32; - let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); - - // Bitmap contains the encoded ID for "color=red" → predicate matches - let mut matching_map = RoaringTreemap::new(); - matching_map.insert(color_red_id); - let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &matching_map); - - assert_eq!( - computer.evaluate_similarity(doc), - inner_dist * beta, - "a matched filter should multiply the inner similarity by beta" - ); - } - - // ----------------------------------------------------------------------- - // Test 2: when the filter does not match, evaluate_similarity is unchanged - // ----------------------------------------------------------------------- - - #[test] - fn test_evaluate_similarity_no_filter_match_preserves_score() { - let (attr_map, ast_expr, _) = setup_encoder_and_filter("color", "red"); - let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); - - let beta = 2.5_f32; - let inner_dist = 4.0_f32; - let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); - - // Empty bitmap → no attribute matches the predicate - let empty_map = RoaringTreemap::new(); - let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &empty_map); - - assert_eq!( - computer.evaluate_similarity(doc), - inner_dist, - "an unmatched filter should leave the inner similarity unchanged" - ); - } - - // ----------------------------------------------------------------------- - // Test 3: post_process forwards only filter-matching candidates to the - // inner post processor (and therefore to the output buffer). - // ----------------------------------------------------------------------- - - #[test] - fn test_post_process_only_passes_matching_candidates_to_inner() { - let rt = tokio::runtime::Builder::new_current_thread() - .build() - .expect("test tokio runtime"); - - // IDs 0 and 1 carry color=red (should pass the filter) - // IDs 2 and 3 carry color=blue (should be dropped by the filter) - let attr_store = RoaringAttributeStore::::new(); - let red = Attribute::from_value("color", AttributeValue::String("red".to_owned())); - let blue = Attribute::from_value("color", AttributeValue::String("blue".to_owned())); - for id in 0u32..2 { - attr_store - .set_element(&id, std::slice::from_ref(&red)) - .expect("set red attr"); - } - for id in 2u32..4 { - attr_store - .set_element(&id, std::slice::from_ref(&blue)) - .expect("set blue attr"); - } - - // The attribute_map is shared so EncodedFilterExpr sees the same encodings - // as those stored by the attribute store. - let attr_map = attr_store.attribute_map(); - - let ast_expr = ASTExpr::Compare { - field: "color".to_string(), - op: CompareOp::Eq(Value::String("red".to_string())), - }; - let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map.clone()).expect("filter expr"); - - // Build the inner vector provider: start point at u32::MAX + 2-D zero vectors for 0..3 - let config = Config::new(Metric::L2, 10, StartPoint::new(u32::MAX, vec![1.0f32, 0.0])) - .expect("provider config"); - let inner_provider = Provider::new(config); - let ctx = Context::new(); - rt.block_on(async { - for id in 0u32..4 { - inner_provider - .set_element(&ctx, &id, &[0.0f32, 0.0] as &[f32]) - .await - .expect("add vector to inner provider"); - } - }); - - // Obtain the inner search accessor and derive an inner computer from it - let strategy = Strategy::new(); - let inner_accessor = strategy - .search_accessor(&inner_provider, &ctx) - .expect("inner accessor"); - let inner_computer = inner_accessor - .build_query_computer(&[0.0f32, 0.0][..]) - .expect("inner computer"); - - // Wrap accessor + attribute store into an EncodedDocumentAccessor - let attribute_accessor = attr_store.attribute_accessor().expect("attribute accessor"); - let mut doc_accessor = - EncodedDocumentAccessor::new(inner_accessor, attribute_accessor, attr_map, 2.0); - - let computer = InlineBetaComputer::new(inner_computer, 2.0, filter_expr); - - // Four candidates: 0 and 1 match (red); 2 and 3 do not (blue) - let candidates = [ - Neighbor::new(0u32, 1.0_f32), - Neighbor::new(1u32, 2.0_f32), - Neighbor::new(2u32, 3.0_f32), - Neighbor::new(3u32, 4.0_f32), - ]; - - let mut ids = [u32::MAX; 4]; - let mut distances = [f32::MAX; 4]; - let mut output = IdDistance::new(&mut ids, &mut distances); - - let query_vec = [0.0f32, 0.0]; - let filter_query = FilteredQuery::new(&query_vec[..], ast_expr); - - // CopyIds simply copies whatever it receives into the output buffer, - // so the output reflects exactly what FilterResults lets through. - let count = rt - .block_on( - FilterResults { - inner_post_processor: glue::CopyIds, - } - .post_process( - &mut doc_accessor, - &filter_query, - &computer, - candidates.into_iter(), - &mut output, - ), - ) - .expect("post_process"); - - // Only the two red-labeled candidates should have been forwarded - assert_eq!(count, 2, "exactly 2 of 4 candidates should pass the filter"); - let passed = &ids[..count]; - assert!( - passed.contains(&0), - "ID 0 (color=red) should pass the filter" - ); - assert!( - passed.contains(&1), - "ID 1 (color=red) should pass the filter" - ); - } -} diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs index 273475b15..414b868d4 100644 --- a/diskann-label-filter/src/lib.rs +++ b/diskann-label-filter/src/lib.rs @@ -53,6 +53,10 @@ pub mod tests { #[cfg(test)] pub mod common; #[cfg(test)] + pub mod document_insert_strategy_test; + #[cfg(test)] + pub mod inline_beta_filter_test; + #[cfg(test)] pub mod roaring_attribute_store_test; } diff --git a/diskann-label-filter/src/tests/document_insert_strategy_test.rs b/diskann-label-filter/src/tests/document_insert_strategy_test.rs new file mode 100644 index 000000000..2d7f32b16 --- /dev/null +++ b/diskann-label-filter/src/tests/document_insert_strategy_test.rs @@ -0,0 +1,126 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use diskann::{ + graph::{ + glue::{InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + provider::BuildQueryComputer, +}; +use diskann_vector::distance::Metric; + +use crate::{ + document::Document, + encoded_attribute_provider::{ + document_insert_strategy::{ + DocumentInsertStrategy, DocumentPruneStrategy, DocumentSearchAccessor, + }, + document_provider::DocumentProvider, + roaring_attribute_store::RoaringAttributeStore, + }, +}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Build a minimal test provider with a single start point and three dimensions. +fn make_test_provider() -> Provider { + let config = Config::new( + Metric::L2, + 10, + StartPoint::new(u32::MAX, vec![1.0f32, 2.0, 0.0]), + ) + .expect("test provider config should be valid"); + Provider::new(config) +} + +fn make_doc_provider( + provider: Provider, +) -> DocumentProvider> { + DocumentProvider::new(provider, RoaringAttributeStore::new()) +} + +/// `search_accessor` successfully creates a `DocumentSearchAccessor` wrapping the +/// inner accessor. +#[test] +fn test_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as SearchStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_insert_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as InsertStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::insert_search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_prune_accessor_delegates_to_inner_provider() { + let doc_prune_strategy = DocumentPruneStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as PruneStrategy< + DocumentProvider>, + >>::prune_accessor(&doc_prune_strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_build_query_computer_extracts_vector_from_document() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let doc_accessor = DocumentSearchAccessor::new(inner_accessor); + + let vector = vec![1.0f32, 2.0, 0.0]; + let doc = Document::new(vector.as_slice(), vec![]); + + let result = as BuildQueryComputer>>::build_query_computer(&doc_accessor, &doc); + + assert!( + result.is_ok(), + "build_query_computer should succeed for a valid vector" + ); +} + +#[test] +fn test_terminate_early_delegates_to_inner() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let mut inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let inner_terminate_early = inner_accessor.terminate_early(); + let mut doc_accessor = DocumentSearchAccessor::new(inner_accessor); + assert_eq!( + inner_terminate_early, + doc_accessor.terminate_early(), + "terminate_early should have same value as inner accessor" + ); +} diff --git a/diskann-label-filter/src/tests/inline_beta_filter_test.rs b/diskann-label-filter/src/tests/inline_beta_filter_test.rs new file mode 100644 index 000000000..0ca16512d --- /dev/null +++ b/diskann-label-filter/src/tests/inline_beta_filter_test.rs @@ -0,0 +1,226 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use std::sync::{Arc, RwLock}; + +use diskann::{ + graph::{ + glue::{self, SearchPostProcess, SearchStrategy}, + search_output_buffer::IdDistance, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, SetElement}, +}; +use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; +use roaring::RoaringTreemap; +use serde_json::Value; + +use crate::{ + attribute::{Attribute, AttributeValue}, + document::EncodedDocument, + encoded_attribute_provider::{ + attribute_encoder::AttributeEncoder, encoded_filter_expr::EncodedFilterExpr, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::{ + encoded_document_accessor::EncodedDocumentAccessor, + inline_beta_filter::{FilterResults, InlineBetaComputer}, + }, + query::FilteredQuery, + traits::attribute_store::AttributeStore, + ASTExpr, CompareOp, +}; + +// ----------------------------------------------------------------------- +// Stub inner distance computer +// ----------------------------------------------------------------------- + +/// Always returns a fixed constant distance, regardless of the vector value. +struct ConstComputer(f32); + +impl PreprocessedDistanceFunction<&[f32], f32> for ConstComputer { + fn evaluate_similarity(&self, _: &[f32]) -> f32 { + self.0 + } +} + +// ----------------------------------------------------------------------- +// Helper: build an AttributeEncoder + ASTExpr for `field == value`, +// returning (attr_map, ast_expr, encoded_id_of_that_attribute). +// ----------------------------------------------------------------------- + +fn setup_encoder_and_filter( + field: &str, + value: &str, +) -> (Arc>, ASTExpr, u64) { + let mut encoder = AttributeEncoder::new(); + let attr = Attribute::from_value(field, AttributeValue::String(value.to_owned())); + let encoded_id = encoder.insert(&attr); + let attr_map = Arc::new(RwLock::new(encoder)); + let ast_expr = ASTExpr::Compare { + field: field.to_string(), + op: CompareOp::Eq(Value::String(value.to_string())), + }; + (attr_map, ast_expr, encoded_id) +} + +// ----------------------------------------------------------------------- +// Test 1: when the filter matches, evaluate_similarity returns inner * beta +// ----------------------------------------------------------------------- + +#[test] +fn test_evaluate_similarity_filter_match_scales_by_beta() { + let (attr_map, ast_expr, color_red_id) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Bitmap contains the encoded ID for "color=red" → predicate matches + let mut matching_map = RoaringTreemap::new(); + matching_map.insert(color_red_id); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &matching_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist * beta, + "a matched filter should multiply the inner similarity by beta" + ); +} + +// ----------------------------------------------------------------------- +// Test 2: when the filter does not match, evaluate_similarity is unchanged +// ----------------------------------------------------------------------- + +#[test] +fn test_evaluate_similarity_no_filter_match_preserves_score() { + let (attr_map, ast_expr, _) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Empty bitmap → no attribute matches the predicate + let empty_map = RoaringTreemap::new(); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &empty_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist, + "an unmatched filter should leave the inner similarity unchanged" + ); +} + +// ----------------------------------------------------------------------- +// Test 3: post_process forwards only filter-matching candidates to the +// inner post processor (and therefore to the output buffer). +// ----------------------------------------------------------------------- + +#[test] +fn test_post_process_only_passes_matching_candidates_to_inner() { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .expect("test tokio runtime"); + + // IDs 0 and 1 carry color=red (should pass the filter) + // IDs 2 and 3 carry color=blue (should be dropped by the filter) + let attr_store = RoaringAttributeStore::::new(); + let red = Attribute::from_value("color", AttributeValue::String("red".to_owned())); + let blue = Attribute::from_value("color", AttributeValue::String("blue".to_owned())); + for id in 0u32..2 { + attr_store + .set_element(&id, std::slice::from_ref(&red)) + .expect("set red attr"); + } + for id in 2u32..4 { + attr_store + .set_element(&id, std::slice::from_ref(&blue)) + .expect("set blue attr"); + } + + // The attribute_map is shared so EncodedFilterExpr sees the same encodings + // as those stored by the attribute store. + let attr_map = attr_store.attribute_map(); + + let ast_expr = ASTExpr::Compare { + field: "color".to_string(), + op: CompareOp::Eq(Value::String("red".to_string())), + }; + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map.clone()).expect("filter expr"); + + // Build the inner vector provider: start point at u32::MAX + 2-D zero vectors for 0..3 + let config = Config::new(Metric::L2, 10, StartPoint::new(u32::MAX, vec![1.0f32, 0.0])) + .expect("provider config"); + let inner_provider = Provider::new(config); + let ctx = Context::new(); + rt.block_on(async { + for id in 0u32..4 { + inner_provider + .set_element(&ctx, &id, &[0.0f32, 0.0] as &[f32]) + .await + .expect("add vector to inner provider"); + } + }); + + // Obtain the inner search accessor and derive an inner computer from it + let strategy = Strategy::new(); + let inner_accessor = strategy + .search_accessor(&inner_provider, &ctx) + .expect("inner accessor"); + let inner_computer = inner_accessor + .build_query_computer(&[0.0f32, 0.0][..]) + .expect("inner computer"); + + // Wrap accessor + attribute store into an EncodedDocumentAccessor + let attribute_accessor = attr_store.attribute_accessor().expect("attribute accessor"); + let mut doc_accessor = + EncodedDocumentAccessor::new(inner_accessor, attribute_accessor, attr_map, 2.0); + + let computer = InlineBetaComputer::new(inner_computer, 2.0, filter_expr); + + // Four candidates: 0 and 1 match (red); 2 and 3 do not (blue) + let candidates = [ + Neighbor::new(0u32, 1.0_f32), + Neighbor::new(1u32, 2.0_f32), + Neighbor::new(2u32, 3.0_f32), + Neighbor::new(3u32, 4.0_f32), + ]; + + let mut ids = [u32::MAX; 4]; + let mut distances = [f32::MAX; 4]; + let mut output = IdDistance::new(&mut ids, &mut distances); + + let query_vec = [0.0f32, 0.0]; + let filter_query = FilteredQuery::new(&query_vec[..], ast_expr); + + // CopyIds simply copies whatever it receives into the output buffer, + // so the output reflects exactly what FilterResults lets through. + let count = rt + .block_on( + FilterResults::new(glue::CopyIds).post_process( + &mut doc_accessor, + &filter_query, + &computer, + candidates.into_iter(), + &mut output, + ), + ) + .expect("post_process"); + + // Only the two red-labeled candidates should have been forwarded + assert_eq!(count, 2, "exactly 2 of 4 candidates should pass the filter"); + let passed = &ids[..count]; + assert!( + passed.contains(&0), + "ID 0 (color=red) should pass the filter" + ); + assert!( + passed.contains(&1), + "ID 1 (color=red) should pass the filter" + ); +} From c9a97bdcf8232fa9984c275107f8f979b3bf8323 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 26 Mar 2026 19:46:32 +0530 Subject: [PATCH 37/39] Fix formatting --- .../document_insert_strategy.rs | 1 - .../src/inline_beta_search/inline_beta_filter.rs | 5 ++++- .../src/tests/document_insert_strategy_test.rs | 4 +--- .../src/tests/inline_beta_filter_test.rs | 16 +++++++--------- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index f26596147..aca755833 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -228,4 +228,3 @@ where .prune_accessor(provider.inner_provider(), context) } } - diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 33b7d7fc7..ebed7d95f 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -127,8 +127,11 @@ pub struct FilterResults { } impl FilterResults { + #[cfg(test)] pub(crate) fn new(inner_post_processor: IPP) -> Self { - Self { inner_post_processor } + Self { + inner_post_processor, + } } } diff --git a/diskann-label-filter/src/tests/document_insert_strategy_test.rs b/diskann-label-filter/src/tests/document_insert_strategy_test.rs index 2d7f32b16..5fb78dd75 100644 --- a/diskann-label-filter/src/tests/document_insert_strategy_test.rs +++ b/diskann-label-filter/src/tests/document_insert_strategy_test.rs @@ -38,9 +38,7 @@ fn make_test_provider() -> Provider { Provider::new(config) } -fn make_doc_provider( - provider: Provider, -) -> DocumentProvider> { +fn make_doc_provider(provider: Provider) -> DocumentProvider> { DocumentProvider::new(provider, RoaringAttributeStore::new()) } diff --git a/diskann-label-filter/src/tests/inline_beta_filter_test.rs b/diskann-label-filter/src/tests/inline_beta_filter_test.rs index 0ca16512d..079f91eff 100644 --- a/diskann-label-filter/src/tests/inline_beta_filter_test.rs +++ b/diskann-label-filter/src/tests/inline_beta_filter_test.rs @@ -201,15 +201,13 @@ fn test_post_process_only_passes_matching_candidates_to_inner() { // CopyIds simply copies whatever it receives into the output buffer, // so the output reflects exactly what FilterResults lets through. let count = rt - .block_on( - FilterResults::new(glue::CopyIds).post_process( - &mut doc_accessor, - &filter_query, - &computer, - candidates.into_iter(), - &mut output, - ), - ) + .block_on(FilterResults::new(glue::CopyIds).post_process( + &mut doc_accessor, + &filter_query, + &computer, + candidates.into_iter(), + &mut output, + )) .expect("post_process"); // Only the two red-labeled candidates should have been forwarded From cc6a9b64131610ab1b105e0c6149e54e721837bf Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 26 Mar 2026 21:20:07 +0530 Subject: [PATCH 38/39] Fix formatting --- .../encoded_attribute_provider/document_insert_strategy.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index aca755833..2c487bd80 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -44,15 +44,11 @@ impl DocumentInsertStrategy { /// Wrapper accessor for Document queries pub struct DocumentSearchAccessor { inner: Inner, - // _phantom: PhantomData VT>, } impl DocumentSearchAccessor { pub fn new(inner: Inner) -> Self { - Self { - inner, - // _phantom: PhantomData, - } + Self { inner } } } From 47a2e6946a397f048b88da872f58d09d0ee996f5 Mon Sep 17 00:00:00 2001 From: sampathrg Date: Thu, 26 Mar 2026 21:44:25 +0530 Subject: [PATCH 39/39] Fix build errors after merging with main --- .../src/backend/document_index/benchmark.rs | 11 ++++++----- .../document_insert_strategy.rs | 7 +------ .../src/inline_beta_search/inline_beta_filter.rs | 4 ++-- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs index a284e93db..8f3d56175 100644 --- a/diskann-benchmark/src/backend/document_index/benchmark.rs +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -15,7 +15,8 @@ use std::sync::Arc; use anyhow::Result; use diskann::{ graph::{ - search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, StartPointStrategy, + glue, search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, + StartPointStrategy, }, provider::DefaultContext, ANNError, ANNErrorKind, @@ -429,8 +430,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: - diskann::graph::glue::SearchStrategy, u32>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, T: bytemuck::Pod + Copy + Send + Sync + 'static, { type Id = DP::ExternalId; @@ -638,8 +639,8 @@ where > + Send + Sync + 'static, - for<'a> InlineBetaStrategy: - diskann::graph::glue::SearchStrategy>, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, { let searcher = Arc::new(FilteredSearcher { index: index.clone(), diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs index 2c487bd80..6e6886eda 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -7,7 +7,7 @@ //! [DiskANNIndex] using a [DocumentProvider]. use diskann::{ - graph::glue::{self, ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, + graph::glue::{ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, ANNResult, }; @@ -148,7 +148,6 @@ where VT: Sync + Send + ?Sized + 'static, { type QueryComputer = Inner::QueryComputer; - type PostProcessor = glue::CopyIds; type SearchAccessorError = Inner::SearchAccessorError; type SearchAccessor<'a> = DocumentSearchAccessor>; @@ -162,10 +161,6 @@ where .search_accessor(provider.inner_provider(), context)?; Ok(DocumentSearchAccessor::new(inner_accessor)) } - - fn post_processor(&self) -> Self::PostProcessor { - glue::CopyIds - } } impl<'doc, Inner, DP, VT> diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index bac127b37..f7a4f505d 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -74,12 +74,12 @@ where impl diskann::graph::glue::DefaultPostProcessor< DocumentProvider>, - FilteredQuery, + FilteredQuery<'_, Q>, > for InlineBetaStrategy where DP: DataProvider, Strategy: diskann::graph::glue::DefaultPostProcessor, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type Processor = FilterResults;