From dcdc37743c890f5efa94ae54a8f45e7c8a0ae763 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 28 Apr 2026 14:47:37 -0700 Subject: [PATCH 01/38] Make benchmarks stateful. --- diskann-benchmark-runner/src/benchmark.rs | 65 ++-- diskann-benchmark-runner/src/registry.rs | 18 +- diskann-benchmark-runner/src/test/dim.rs | 19 +- diskann-benchmark-runner/src/test/mod.rs | 13 +- diskann-benchmark-runner/src/test/typed.rs | 32 +- diskann-benchmark-simd/src/lib.rs | 322 ++++++++-------- .../src/backend/disk_index/benchmarks.rs | 92 +++-- .../src/backend/exhaustive/minmax.rs | 45 +-- .../src/backend/exhaustive/product.rs | 34 +- .../src/backend/exhaustive/spherical.rs | 38 +- .../src/backend/filters/benchmark.rs | 167 ++++---- .../src/backend/index/benchmarks.rs | 355 ++++++++---------- .../src/backend/index/product.rs | 71 ++-- diskann-benchmark/src/backend/index/scalar.rs | 73 ++-- .../src/backend/index/spherical.rs | 76 ++-- diskann-benchmark/src/utils/mod.rs | 6 +- 16 files changed, 674 insertions(+), 752 deletions(-) diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs index 6c196d8af..27cb910a9 100644 --- a/diskann-benchmark-runner/src/benchmark.rs +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -15,7 +15,7 @@ use crate::{ /// Benchmarks consist of an [`Input`] and a corresponding serialized `Output`. Inputs will /// first be validated with the benchmark using [`try_match`](Self::try_match). Only /// successful matches will be passed to [`run`](Self::run). -pub trait Benchmark { +pub trait Benchmark: 'static { /// The [`Input`] type this benchmark matches against. type Input: Input + 'static; @@ -32,7 +32,7 @@ pub trait Benchmark { /// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`] /// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations /// are encouraged to generate ranked [`FailureScore`]s to assist in user level debugging. - fn try_match(input: &Self::Input) -> Result; + fn try_match(&self, input: &Self::Input) -> Result; /// Return descriptive information about the benchmark. /// @@ -40,6 +40,7 @@ pub trait Benchmark { /// If `input` is `Some`, and is an unsuccessful match, diagnostic information about what /// was expected should be generated to help users. fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&Self::Input>, ) -> std::fmt::Result; @@ -52,6 +53,7 @@ pub trait Benchmark { /// /// Implementors may assume that [`Self::try_match`] returned `Ok` on `input`. fn run( + &self, input: &Self::Input, checkpoint: Checkpoint<'_>, output: &mut dyn Output, @@ -88,6 +90,7 @@ pub trait Regression: Benchmark Deserialize<'a>> { /// stream. Instead, all diagnostics should be encoded in the returned [`PassFail`] type /// for reporting upstream. fn check( + &self, tolerances: &Self::Tolerances, input: &Self::Input, before: &Self::Output, @@ -109,8 +112,6 @@ pub enum PassFail { pub(crate) mod internal { use super::*; - use std::marker::PhantomData; - use anyhow::Context; use thiserror::Error; @@ -176,38 +177,32 @@ pub(crate) mod internal { } } - pub(crate) trait AsRegression { - fn as_regression(&self) -> Option<&dyn Regression>; + pub(crate) trait AsRegression { + fn as_regression(benchmark: &T) -> Option<&dyn Regression>; } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, Copy)] pub(crate) struct NoRegression; - impl AsRegression for NoRegression { - fn as_regression(&self) -> Option<&dyn Regression> { + impl AsRegression for NoRegression { + fn as_regression(_benchmark: &T) -> Option<&dyn Regression> { None } } #[derive(Debug, Clone, Copy)] - pub(crate) struct WithRegression(PhantomData); + pub(crate) struct WithRegression; - impl WithRegression { - pub(crate) const fn new() -> Self { - Self(PhantomData) - } - } - - impl AsRegression for WithRegression + impl AsRegression for WithRegression where T: super::Regression, { - fn as_regression(&self) -> Option<&dyn Regression> { - Some(self) + fn as_regression(benchmark: &T) -> Option<&dyn Regression> { + Some(benchmark) } } - impl Regression for WithRegression + impl Regression for T where T: super::Regression, { @@ -242,7 +237,7 @@ pub(crate) mod internal { let after = T::Output::deserialize(after) .map_err(|err| DeserializationError::new(Kind::After, err))?; - let passfail = match T::check(tolerance, input, &before, &after)? { + let passfail = match self.check(tolerance, input, &before, &after)? { PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?), PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?), }; @@ -253,21 +248,15 @@ pub(crate) mod internal { #[derive(Debug, Clone, Copy)] pub(crate) struct Wrapper { - regression: R, - _type: PhantomData, - } - - impl Wrapper { - pub(crate) const fn new() -> Self { - Self::new_with(NoRegression) - } + benchmark: T, + _regression: R, } impl Wrapper { - pub(crate) const fn new_with(regression: R) -> Self { + pub(crate) const fn new(benchmark: T, regression: R) -> Self { Self { - regression, - _type: PhantomData, + benchmark, + _regression: regression, } } } @@ -278,11 +267,11 @@ pub(crate) mod internal { impl Benchmark for Wrapper where T: super::Benchmark, - R: AsRegression, + R: AsRegression, { fn try_match(&self, input: &Any) -> Result { if let Some(cast) = input.downcast_ref::() { - T::try_match(cast) + self.benchmark.try_match(cast) } else { Err(MATCH_FAIL) } @@ -295,7 +284,7 @@ pub(crate) mod internal { ) -> std::fmt::Result { match input { Some(input) => match input.downcast_ref::() { - Some(cast) => T::description(f, Some(cast)), + Some(cast) => self.benchmark.description(f, Some(cast)), None => write!( f, "expected tag \"{}\" - instead got \"{}\"", @@ -305,7 +294,7 @@ pub(crate) mod internal { }, None => { writeln!(f, "tag \"{}\"", ::tag())?; - T::description(f, None) + self.benchmark.description(f, None) } } } @@ -318,7 +307,7 @@ pub(crate) mod internal { ) -> anyhow::Result { match input.downcast_ref::() { Some(input) => { - let result = T::run(input, checkpoint, output)?; + let result = self.benchmark.run(input, checkpoint, output)?; Ok(serde_json::to_value(result)?) } None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()), @@ -327,7 +316,7 @@ pub(crate) mod internal { // Extensions fn as_regression(&self) -> Option<&dyn Regression> { - self.regression.as_regression() + R::as_regression(&self.benchmark) } } diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index 73e10c605..5d8c7366c 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -108,13 +108,16 @@ impl Benchmarks { } /// Register a new benchmark with the given name. - pub fn register(&mut self, name: impl Into) + pub fn register(&mut self, name: impl Into, benchmark: T) where - T: Benchmark + 'static, + T: Benchmark, { self.benchmarks.push(RegisteredBenchmark { name: name.into(), - benchmark: Box::new(benchmark::internal::Wrapper::::new()), + benchmark: Box::new(benchmark::internal::Wrapper::::new( + benchmark, + benchmark::internal::NoRegression, + )), }); } @@ -212,12 +215,13 @@ impl Benchmarks { /// /// Upon registration, the associated [`Regression::Tolerances`] input and the benchmark /// itself will be reachable via [`Check`](crate::app::Check). - pub fn register_regression(&mut self, name: impl Into) + pub fn register_regression(&mut self, name: impl Into, benchmark: T) where - T: Regression + 'static, + T: Regression, { - let registered = benchmark::internal::Wrapper::::new_with( - benchmark::internal::WithRegression::::new(), + let registered = benchmark::internal::Wrapper::::new( + benchmark, + benchmark::internal::WithRegression, ); self.benchmarks.push(RegisteredBenchmark { name: name.into(), diff --git a/diskann-benchmark-runner/src/test/dim.rs b/diskann-benchmark-runner/src/test/dim.rs index 07e73e2a6..f0eae36a3 100644 --- a/diskann-benchmark-runner/src/test/dim.rs +++ b/diskann-benchmark-runner/src/test/dim.rs @@ -99,7 +99,7 @@ impl Benchmark for SimpleBench { type Input = DimInput; type Output = usize; - fn try_match(input: &DimInput) -> Result { + fn try_match(&self, input: &DimInput) -> Result { if input.dim.is_none() { Ok(MatchScore(0)) } else { @@ -107,7 +107,11 @@ impl Benchmark for SimpleBench { } } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&DimInput>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&DimInput>, + ) -> std::fmt::Result { match input { Some(input) if input.dim.is_none() => write!(f, "successful match"), Some(_) => write!(f, "expected dim=None"), @@ -116,6 +120,7 @@ impl Benchmark for SimpleBench { } fn run( + &self, input: &DimInput, _checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, @@ -133,11 +138,15 @@ impl Benchmark for DimBench { type Input = DimInput; type Output = usize; - fn try_match(_input: &DimInput) -> Result { + fn try_match(&self, _input: &DimInput) -> Result { Ok(MatchScore(0)) } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&DimInput>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&DimInput>, + ) -> std::fmt::Result { if input.is_some() { write!(f, "perfect match") } else { @@ -146,6 +155,7 @@ impl Benchmark for DimBench { } fn run( + &self, input: &DimInput, _checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, @@ -161,6 +171,7 @@ impl Regression for DimBench { type Fail = &'static str; fn check( + &self, tolerance: &Tolerance, input: &DimInput, before: &usize, diff --git a/diskann-benchmark-runner/src/test/mod.rs b/diskann-benchmark-runner/src/test/mod.rs index 540842d4f..ea9853e5e 100644 --- a/diskann-benchmark-runner/src/test/mod.rs +++ b/diskann-benchmark-runner/src/test/mod.rs @@ -22,10 +22,13 @@ pub fn register_inputs(inputs: &mut registry::Inputs) -> anyhow::Result<()> { } pub fn register_benchmarks(benchmarks: &mut registry::Benchmarks) { - benchmarks.register_regression::>("type-bench-f32"); - benchmarks.register_regression::>("type-bench-i8"); - benchmarks.register_regression::>("exact-type-bench-f32-1000"); + benchmarks.register_regression("type-bench-f32", typed::TypeBench::::new()); + benchmarks.register_regression("type-bench-i8", typed::TypeBench::::new()); + benchmarks.register_regression( + "exact-type-bench-f32-1000", + typed::ExactTypeBench::::new(), + ); - benchmarks.register::("simple-bench"); - benchmarks.register_regression::("dim-bench"); + benchmarks.register("simple-bench", dim::SimpleBench); + benchmarks.register_regression("dim-bench", dim::DimBench); } diff --git a/diskann-benchmark-runner/src/test/typed.rs b/diskann-benchmark-runner/src/test/typed.rs index ed49b8b22..cae95f66d 100644 --- a/diskann-benchmark-runner/src/test/typed.rs +++ b/diskann-benchmark-runner/src/test/typed.rs @@ -129,6 +129,12 @@ impl CheckDeserialization for Tolerance { #[derive(Debug)] pub(super) struct TypeBench(std::marker::PhantomData); +impl TypeBench { + pub(super) fn new() -> Self { + Self(std::marker::PhantomData) + } +} + impl Benchmark for TypeBench where T: 'static, @@ -137,17 +143,22 @@ where type Input = TypeInput; type Output = String; - fn try_match(input: &TypeInput) -> Result { + fn try_match(&self, input: &TypeInput) -> Result { // Try to match based on data type. // Add a small penalty so `ExactTypeBench` can be more specific if it hits. Type::::try_match(&input.data_type).map(|m| MatchScore(m.0 + 10)) } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&TypeInput>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&TypeInput>, + ) -> std::fmt::Result { Type::::description(f, input.map(|i| &i.data_type)) } fn run( + &self, input: &TypeInput, checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, @@ -169,6 +180,7 @@ where type Fail = DataType; fn check( + &self, _tolerance: &Tolerance, input: &TypeInput, before: &String, @@ -189,6 +201,12 @@ where #[derive(Debug)] pub(super) struct ExactTypeBench(std::marker::PhantomData); +impl ExactTypeBench { + pub(super) fn new() -> Self { + Self(std::marker::PhantomData) + } +} + impl Benchmark for ExactTypeBench where T: 'static, @@ -197,7 +215,7 @@ where type Input = TypeInput; type Output = String; - fn try_match(input: &TypeInput) -> Result { + fn try_match(&self, input: &TypeInput) -> Result { if input.dim == N { Type::::try_match(&input.data_type) } else { @@ -205,7 +223,11 @@ where } } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&TypeInput>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&TypeInput>, + ) -> std::fmt::Result { match input { None => { write!(f, "{}, dim={}", Description::>::new(), N) @@ -232,6 +254,7 @@ where } fn run( + &self, input: &TypeInput, checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, @@ -253,6 +276,7 @@ where type Fail = String; fn check( + &self, _tolerance: &Tolerance, input: &TypeInput, before: &String, diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 4fb921590..8d72efb91 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -303,137 +303,104 @@ impl std::fmt::Display for CheckResult { // Benchmark Registration // //////////////////////////// -macro_rules! register { - ($arch:literal, $dispatcher:ident, $name:literal, $($kernel:tt)*) => { - #[cfg(target_arch = $arch)] - $dispatcher.register_regression::<$($kernel)*>($name) - }; - ($dispatcher:ident, $name:literal, $($kernel:tt)*) => { - $dispatcher.register_regression::<$($kernel)*>($name) - }; -} - fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) { // x86-64-v4 - register!( - "x86_64", - dispatcher, - "simd-op-f32xf32-x86_64_V4", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-f16xf16-x86_64_V4", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-u8xu8-x86_64_V4", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-i8xi8-x86_64_V4", - Kernel - ); + #[cfg(target_arch = "x86_64")] + { + dispatcher.register_regression( + "simd-op-f32xf32-x86_64_V4", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-f16xf16-x86_64_V4", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-u8xu8-x86_64_V4", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-i8xi8-x86_64_V4", + Kernel::::new(), + ); + } // x86-64-v3 - register!( - "x86_64", - dispatcher, - "simd-op-f32xf32-x86_64_V3", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-f16xf16-x86_64_V3", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-u8xu8-x86_64_V3", - Kernel - ); - register!( - "x86_64", - dispatcher, - "simd-op-i8xi8-x86_64_V3", - Kernel - ); + #[cfg(target_arch = "x86_64")] + { + dispatcher.register_regression( + "simd-op-f32xf32-x86_64_V3", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-f16xf16-x86_64_V3", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-u8xu8-x86_64_V3", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-i8xi8-x86_64_V3", + Kernel::::new(), + ); + } // aarch64-neon - register!( - "aarch64", - dispatcher, - "simd-op-f32xf32-aarch64_neon", - Kernel - ); - register!( - "aarch64", - dispatcher, - "simd-op-f16xf16-aarch64_neon", - Kernel - ); - register!( - "aarch64", - dispatcher, - "simd-op-u8xu8-aarch64_neon", - Kernel - ); - register!( - "aarch64", - dispatcher, - "simd-op-i8xi8-aarch64_neon", - Kernel - ); + #[cfg(target_arch = "aarch64")] + { + dispatcher.register_regression( + "simd-op-f32xf32-aarch64_neon", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-f16xf16-aarch64_neon", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-u8xu8-aarch64_neon", + Kernel::::new(), + ); + dispatcher.register_regression( + "simd-op-i8xi8-aarch64_neon", + Kernel::::new(), + ); + } // scalar - register!( - dispatcher, + dispatcher.register_regression( "simd-op-f32xf32-scalar", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-f16xf16-scalar", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-u8xu8-scalar", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-i8xi8-scalar", - Kernel + Kernel::::new(), ); // reference - register!( - dispatcher, + dispatcher.register_regression( "simd-op-f32xf32-reference", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-f16xf16-reference", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-u8xu8-reference", - Kernel + Kernel::::new(), ); - register!( - dispatcher, + dispatcher.register_regression( "simd-op-i8xi8-reference", - Kernel + Kernel::::new(), ); } @@ -449,14 +416,12 @@ struct Reference; struct Identity(T); struct Kernel { - arch: A, _type: std::marker::PhantomData<(A, Q, D)>, } impl Kernel { - fn new(arch: A) -> Self { + fn new() -> Self { Self { - arch, _type: std::marker::PhantomData, } } @@ -582,13 +547,16 @@ where datatype::Type: DispatchRule, datatype::Type: DispatchRule, Identity: DispatchRule, - Kernel: RunBenchmark, + Kernel: RunBenchmark, + A: 'static, + Q: 'static, + D: 'static, { type Input = SimdOp; type Output = Vec; // Matching simply requires that we match the inner type. - fn try_match(from: &SimdOp) -> Result { + fn try_match(&self, from: &SimdOp) -> Result { let mut failscore: Option = None; if datatype::Type::::try_match(&from.query_type).is_err() { *failscore.get_or_insert(0) += 10; @@ -607,19 +575,23 @@ where } fn run( + &self, input: &SimdOp, _: diskann_benchmark_runner::Checkpoint<'_>, mut output: &mut dyn diskann_benchmark_runner::Output, ) -> anyhow::Result { let arch = Identity::::convert(input.arch)?.0; - let kernel = Self::new(arch); writeln!(output, "{}", input)?; - let results = kernel.run(input)?; + let results = self.run_benchmark(input, arch)?; writeln!(output, "\n\n{}", DisplayWrapper(&*results))?; Ok(results) } - fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&SimdOp>) -> std::fmt::Result { + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&SimdOp>, + ) -> std::fmt::Result { match input { None => { describeln!( @@ -659,13 +631,17 @@ where datatype::Type: DispatchRule, datatype::Type: DispatchRule, Identity: DispatchRule, - Kernel: RunBenchmark, + Kernel: RunBenchmark, + A: 'static, + Q: 'static, + D: 'static, { type Tolerances = SimdTolerance; type Pass = CheckResult; type Fail = CheckResult; fn check( + &self, tolerance: &SimdTolerance, _input: &SimdOp, before: &Vec, @@ -724,8 +700,8 @@ where // Benchmark // /////////////// -trait RunBenchmark { - fn run(self, input: &SimdOp) -> Result, anyhow::Error>; +trait RunBenchmark { + fn run_benchmark(&self, input: &SimdOp, arch: A) -> Result, anyhow::Error>; } #[derive(Debug, Serialize, Deserialize)] @@ -856,8 +832,12 @@ impl Data { macro_rules! stamp { (reference, $Q:ty, $D:ty, $f_l2:ident, $f_ip:ident, $f_cosine:ident) => { - impl RunBenchmark for Kernel { - fn run(self, input: &SimdOp) -> Result, anyhow::Error> { + impl RunBenchmark for Kernel { + fn run_benchmark( + &self, + input: &SimdOp, + _arch: Reference, + ) -> Result, anyhow::Error> { let mut results = Vec::new(); for run in input.runs.iter() { let data = Data::<$Q, $D>::new(run); @@ -873,8 +853,12 @@ macro_rules! stamp { } }; ($arch:path, $Q:ty, $D:ty) => { - impl RunBenchmark for Kernel<$arch, $Q, $D> { - fn run(self, input: &SimdOp) -> Result, anyhow::Error> { + impl RunBenchmark<$arch> for Kernel<$arch, $Q, $D> { + fn run_benchmark( + &self, + input: &SimdOp, + arch: $arch, + ) -> Result, anyhow::Error> { let mut results = Vec::new(); let l2 = &simd::L2 {}; @@ -891,16 +875,13 @@ macro_rules! stamp { // target features. let result = match run.distance { SimilarityMeasure::SquaredL2 => data.run(run, |q, d| { - self.arch - .run2(|q, d| simd::simd_op(l2, self.arch, q, d), q, d) + arch.run2(|q, d| simd::simd_op(l2, arch, q, d), q, d) }), SimilarityMeasure::InnerProduct => data.run(run, |q, d| { - self.arch - .run2(|q, d| simd::simd_op(ip, self.arch, q, d), q, d) + arch.run2(|q, d| simd::simd_op(ip, arch, q, d), q, d) }), SimilarityMeasure::Cosine => data.run(run, |q, d| { - self.arch - .run2(|q, d| simd::simd_op(cosine, self.arch, q, d), q, d) + arch.run2(|q, d| simd::simd_op(cosine, arch, q, d), q, d) }), }; results.push(result) @@ -1237,60 +1218,64 @@ mod tests { #[test] fn check_rejects_mismatched_runs() { - type Bench = Kernel; + let kernel = Kernel::::new(); - let err = Bench::check( - &tolerance(0.0), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], - &vec![tiny_result(SimilarityMeasure::Cosine, 100)], - ) - .unwrap_err(); + let err = kernel + .check( + &tolerance(0.0), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], + &vec![tiny_result(SimilarityMeasure::Cosine, 100)], + ) + .unwrap_err(); assert_eq!(err.to_string(), "run 0 mismatched"); } #[test] fn check_allows_negative_relative_change() { - type Bench = Kernel; + let kernel = Kernel::::new(); - let result = Bench::check( - &tolerance(0.0), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], - &vec![tiny_result(SimilarityMeasure::SquaredL2, 95)], - ) - .unwrap(); + let result = kernel + .check( + &tolerance(0.0), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], + &vec![tiny_result(SimilarityMeasure::SquaredL2, 95)], + ) + .unwrap(); assert!(matches!(result, PassFail::Pass(_))); } #[test] fn check_passes_on_tolerance_boundary() { - type Bench = Kernel; + let kernel = Kernel::::new(); - let result = Bench::check( - &tolerance(0.05), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], - &vec![tiny_result(SimilarityMeasure::SquaredL2, 105)], - ) - .unwrap(); + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], + &vec![tiny_result(SimilarityMeasure::SquaredL2, 105)], + ) + .unwrap(); assert!(matches!(result, PassFail::Pass(_))); } #[test] fn check_fails_above_tolerance_boundary() { - type Bench = Kernel; + let kernel = Kernel::::new(); - let result = Bench::check( - &tolerance(0.05), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], - &vec![tiny_result(SimilarityMeasure::SquaredL2, 106)], - ) - .unwrap(); + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 100)], + &vec![tiny_result(SimilarityMeasure::SquaredL2, 106)], + ) + .unwrap(); assert!(matches!(result, PassFail::Fail(_))); } @@ -1322,15 +1307,16 @@ mod tests { // We require at least a non-zero value. #[test] fn zero_values_rejected() { - type Bench = Kernel; - - let result = Bench::check( - &tolerance(0.05), - &tiny_op(), - &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)], - &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)], - ) - .unwrap(); + let kernel = Kernel::::new(); + + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)], + &vec![tiny_result(SimilarityMeasure::SquaredL2, 0)], + ) + .unwrap(); assert!(matches!(result, PassFail::Fail(_))); } diff --git a/diskann-benchmark/src/backend/disk_index/benchmarks.rs b/diskann-benchmark/src/backend/disk_index/benchmarks.rs index fa9b036ad..6c5298dd8 100644 --- a/diskann-benchmark/src/backend/disk_index/benchmarks.rs +++ b/diskann-benchmark/src/backend/disk_index/benchmarks.rs @@ -30,8 +30,7 @@ use crate::{ }; /// Disk Index -struct DiskIndex<'a, T> { - input: &'a DiskIndexOperation, +struct DiskIndex { _vector_type: std::marker::PhantomData, } @@ -41,53 +40,18 @@ pub(super) struct DiskIndexStats { pub(super) search: DiskSearchStats, } -impl<'a, T> DiskIndex<'a, T> +impl DiskIndex where T: VectorRepr, { - fn new(input: &'a DiskIndexOperation) -> Self { + fn new() -> Self { Self { - input, _vector_type: std::marker::PhantomData, } } - - fn run( - &self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input.source)?; - let (build_stats, index_load) = match &self.input.source { - DiskIndexSource::Load(load) => Ok((None, (*load).clone())), - DiskIndexSource::Build(build) => build_disk_index::(&FileStorageProvider, build) - .map(|stats| { - ( - Some(stats), - DiskIndexLoad { - data_type: build.data_type, - load_path: build.save_path.clone(), - }, - ) - }), - }?; - if let Some(build_stats) = &build_stats { - writeln!(output, "{}", build_stats)?; - } - - writeln!(output, "{}", self.input.search_phase)?; - let search_stats = - search_disk_index::(&index_load, &self.input.search_phase, &FileStorageProvider)?; - writeln!(output, "{}", search_stats)?; - - Ok(DiskIndexStats { - build: build_stats, - search: search_stats, - }) - } } -impl Benchmark for DiskIndex<'static, T> +impl Benchmark for DiskIndex where T: VectorRepr + 'static, Type: DispatchRule, @@ -95,7 +59,7 @@ where type Input = DiskIndexOperation; type Output = DiskIndexStats; - fn try_match(input: &DiskIndexOperation) -> Result { + fn try_match(&self, input: &DiskIndexOperation) -> Result { match &input.source { DiskIndexSource::Load(load) => Type::::try_match(&load.data_type), DiskIndexSource::Build(build) => Type::::try_match(&build.data_type), @@ -103,6 +67,7 @@ where } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&DiskIndexOperation>, ) -> std::fmt::Result { @@ -116,11 +81,39 @@ where } fn run( + &self, input: &DiskIndexOperation, - checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, ) -> anyhow::Result { - DiskIndex::::new(input).run(checkpoint, output) + writeln!(output, "{}", input.source)?; + + let (build_stats, index_load) = match &input.source { + DiskIndexSource::Load(load) => Ok((None, (*load).clone())), + DiskIndexSource::Build(build) => build_disk_index::(&FileStorageProvider, build) + .map(|stats| { + ( + Some(stats), + DiskIndexLoad { + data_type: build.data_type, + load_path: build.save_path.clone(), + }, + ) + }), + }?; + if let Some(build_stats) = &build_stats { + writeln!(output, "{}", build_stats)?; + } + + writeln!(output, "{}", input.search_phase)?; + let search_stats = + search_disk_index::(&index_load, &input.search_phase, &FileStorageProvider)?; + writeln!(output, "{}", search_stats)?; + + Ok(DiskIndexStats { + build: build_stats, + search: search_stats, + }) } } @@ -129,10 +122,10 @@ where //////////////////////////// pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { - benchmarks.register_regression::>("disk-index-f32"); - benchmarks.register_regression::>("disk-index-f16"); - benchmarks.register_regression::>("disk-index-u8"); - benchmarks.register_regression::>("disk-index-i8"); + benchmarks.register_regression("disk-index-f32", DiskIndex::::new()); + benchmarks.register_regression("disk-index-f16", DiskIndex::::new()); + benchmarks.register_regression("disk-index-u8", DiskIndex::::new()); + benchmarks.register_regression("disk-index-i8", DiskIndex::::new()); } ///////////////////////// @@ -302,7 +295,7 @@ fn check_metric( } } -impl Regression for DiskIndex<'static, T> +impl Regression for DiskIndex where T: VectorRepr + 'static, Type: DispatchRule, @@ -312,6 +305,7 @@ where type Fail = DiskIndexCheckResult; fn check( + &self, tolerances: &DiskIndexTolerance, _input: &DiskIndexOperation, before: &DiskIndexStats, diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index 26b3da16a..73b57733f 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -12,10 +12,10 @@ crate::utils::stub_impl!("minmax-quantization", inputs::exhaustive::MinMax); // MinMax - requires feature "minmax-quantization" #[cfg(feature = "minmax-quantization")] pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); + benchmarks.register(NAME, imp::MinMaxQ::<1>); + benchmarks.register(NAME, imp::MinMaxQ::<2>); + benchmarks.register(NAME, imp::MinMaxQ::<4>); + benchmarks.register(NAME, imp::MinMaxQ::<8>); } // Stub implementation @@ -85,23 +85,21 @@ mod imp { Ok(progress) } - /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct MinMaxQ<'a, const NBITS: usize> { - input: &'a inputs::exhaustive::MinMax, - } - - impl<'a, const NBITS: usize> MinMaxQ<'a, NBITS> { - pub(super) fn new(input: &'a inputs::exhaustive::MinMax) -> Self { - Self { input } - } + /// The dispatcher target for `minmax-quantization` operations. + #[derive(Debug, Clone, Copy)] + pub(super) struct MinMaxQ; - pub(super) fn run(self, mut output: &mut dyn Output) -> anyhow::Result + impl MinMaxQ { + pub(super) fn run( + &self, + input: &inputs::exhaustive::MinMax, + mut output: &mut dyn Output, + ) -> anyhow::Result where Unsigned: Representation, Plan: algos::CreateQuantComputer>, { - let input = &self.input; - writeln!(output, "{}", self.input)?; + writeln!(output, "{}", input)?; // Training let data = f32::converting_load(datafiles::BinFile(&input.data), input.data_type)?; @@ -111,13 +109,13 @@ mod imp { let dim = NonZeroUsize::new(data.ncols()).unwrap(); let transform = Transform::new( - (&self.input.transform_kind).into(), + (&input.transform_kind).into(), dim, Some(&mut rng), diskann_quantization::alloc::GlobalAllocator, )?; - let quantizer = MinMaxQuantizer::new(transform, Positive::new(self.input.scale)?); + let quantizer = MinMaxQuantizer::new(transform, Positive::new(input.scale)?); let training_time: MicroSeconds = start.elapsed().into(); @@ -198,7 +196,7 @@ mod imp { } } - impl Benchmark for MinMaxQ<'static, NBITS> + impl Benchmark for MinMaxQ where Unsigned: Representation, Plan: algos::CreateQuantComputer>, @@ -206,7 +204,10 @@ mod imp { type Input = inputs::exhaustive::MinMax; type Output = Results; - fn try_match(input: &inputs::exhaustive::MinMax) -> Result { + fn try_match( + &self, + input: &inputs::exhaustive::MinMax, + ) -> Result { let num_bits = input.num_bits.get(); if num_bits == NBITS { Ok(MatchScore(0)) @@ -218,6 +219,7 @@ mod imp { } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&inputs::exhaustive::MinMax>, ) -> std::fmt::Result { @@ -246,11 +248,12 @@ mod imp { } fn run( + &self, input: &inputs::exhaustive::MinMax, _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - MinMaxQ::::new(input).run(output) + self.run(input, output) } } diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 4723753a0..7504a28e3 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -11,7 +11,7 @@ crate::utils::stub_impl!("product-quantization", inputs::exhaustive::Product); pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "product-quantization")] - benchmarks.register::>(NAME); + benchmarks.register(NAME, imp::ProductQ); #[cfg(not(feature = "product-quantization"))] imp::register(NAME, benchmarks) @@ -65,18 +65,17 @@ mod imp { } /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct ProductQ<'a> { - input: &'a inputs::exhaustive::Product, - } - - impl<'a> ProductQ<'a> { - pub(super) fn new(input: &'a inputs::exhaustive::Product) -> Self { - Self { input } - } + #[derive(Debug, Clone, Copy)] + pub(super) struct ProductQ; - pub(super) fn run(self, mut output: &mut dyn Output) -> anyhow::Result { - let input = &self.input; - writeln!(output, "{}", self.input)?; + impl ProductQ { + pub(super) fn run( + &self, + input: &inputs::exhaustive::Product, + mut output: &mut dyn Output, + ) -> anyhow::Result { + let input = &input; + writeln!(output, "{}", input)?; // Training let data = f32::converting_load(datafiles::BinFile(&input.data), input.data_type)?; @@ -190,15 +189,19 @@ mod imp { } } - impl Benchmark for ProductQ<'static> { + impl Benchmark for ProductQ { type Input = inputs::exhaustive::Product; type Output = Results; - fn try_match(_input: &inputs::exhaustive::Product) -> Result { + fn try_match( + &self, + _input: &inputs::exhaustive::Product, + ) -> Result { Ok(MatchScore(0)) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&inputs::exhaustive::Product>, ) -> std::fmt::Result { @@ -210,11 +213,12 @@ mod imp { } fn run( + &self, input: &inputs::exhaustive::Product, _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - ProductQ::new(input).run(output) + self.run(input, output) } } diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index 1c0881c56..9b1f9a935 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -12,10 +12,10 @@ crate::utils::stub_impl!("spherical-quantization", inputs::exhaustive::Spherical // Spherical - requires feature "spherical-quantization" #[cfg(feature = "spherical-quantization")] pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); + benchmarks.register(NAME, imp::SphericalQ::<1>); + benchmarks.register(NAME, imp::SphericalQ::<2>); + benchmarks.register(NAME, imp::SphericalQ::<4>); + benchmarks.register(NAME, imp::SphericalQ::<8>); } // Stub implementation @@ -79,16 +79,14 @@ mod imp { } /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct SphericalQ<'a, const NBITS: usize> { - input: &'a inputs::exhaustive::Spherical, - } - - impl<'a, const NBITS: usize> SphericalQ<'a, NBITS> { - pub(super) fn new(input: &'a inputs::exhaustive::Spherical) -> Self { - Self { input } - } + pub(super) struct SphericalQ; - pub(super) fn run(self, mut output: &mut dyn Output) -> anyhow::Result + impl SphericalQ { + pub(super) fn run( + &self, + input: &inputs::exhaustive::Spherical, + mut output: &mut dyn Output, + ) -> anyhow::Result where Unsigned: Representation, Plan: algos::CreateQuantComputer>, @@ -97,8 +95,7 @@ mod imp { SphericalQuantizer: for<'x> CompressIntoWith<&'x [f32], DataMut<'x, NBITS>, ScopedAllocator<'x>>, { - let input = &self.input; - writeln!(output, "{}", self.input)?; + writeln!(output, "{}", input)?; // Training let data = f32::converting_load(datafiles::BinFile(&input.data), input.data_type)?; @@ -202,7 +199,7 @@ mod imp { } } - impl Benchmark for SphericalQ<'static, NBITS> + impl Benchmark for SphericalQ where Unsigned: Representation, Plan: algos::CreateQuantComputer>, @@ -214,7 +211,10 @@ mod imp { type Input = inputs::exhaustive::Spherical; type Output = Results; - fn try_match(input: &inputs::exhaustive::Spherical) -> Result { + fn try_match( + &self, + input: &inputs::exhaustive::Spherical, + ) -> Result { let num_bits = input.num_bits.get(); if num_bits == NBITS { Ok(MatchScore(0)) @@ -226,6 +226,7 @@ mod imp { } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&inputs::exhaustive::Spherical>, ) -> std::fmt::Result { @@ -254,11 +255,12 @@ mod imp { } fn run( + &self, input: &inputs::exhaustive::Spherical, _checkpoint: diskann_benchmark_runner::Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - SphericalQ::::new(input).run(output) + self.run(input, output) } } diff --git a/diskann-benchmark/src/backend/filters/benchmark.rs b/diskann-benchmark/src/backend/filters/benchmark.rs index a90ea41ed..43a0717b4 100644 --- a/diskann-benchmark/src/backend/filters/benchmark.rs +++ b/diskann-benchmark/src/backend/filters/benchmark.rs @@ -29,29 +29,23 @@ use crate::{ }; pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register::>("metadata-index-build"); + benchmarks.register("metadata-index-build", MetadataIndexJob); } -// Metadata-only index job wrapper -pub(super) struct MetadataIndexJob<'a> { - input: &'a crate::inputs::filters::MetadataIndexBuild, -} - -impl<'a> MetadataIndexJob<'a> { - fn new(input: &'a crate::inputs::filters::MetadataIndexBuild) -> Self { - Self { input } - } -} +// Metadata-only index job. +#[derive(Debug)] +struct MetadataIndexJob; -impl Benchmark for MetadataIndexJob<'static> { +impl Benchmark for MetadataIndexJob { type Input = MetadataIndexBuild; type Output = MetadataIndexBuildStats; - fn try_match(_input: &MetadataIndexBuild) -> Result { + fn try_match(&self, _input: &MetadataIndexBuild) -> Result { Ok(MatchScore(1)) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, _input: Option<&MetadataIndexBuild>, ) -> std::fmt::Result { @@ -63,90 +57,89 @@ impl Benchmark for MetadataIndexJob<'static> { } fn run( + &self, input: &MetadataIndexBuild, checkpoint: Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { - MetadataIndexJob::new(input).run(checkpoint, output) + run(input, checkpoint, output) } } -impl<'a> MetadataIndexJob<'a> { - fn run( - self, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - // Print the input description so the user sees the job configuration. - writeln!(output, "{}", self.input)?; - - // Use the supplied filter parameters (required for metadata-only build) - let filter_params = &self.input.filter_params; - - // Reuse the helper: build index, parse predicates, produce BitmapFilters and telemetry - let (bitmap_filters_vec, filter_search_results, _label_count) = - prepare_bitmap_filters_from_paths_with_kind( - filter_params.data_labels.as_ref(), - filter_params.query_predicates.as_ref(), - self.input.inverted_index_type, - checkpoint, - )?; - - // Collect per-query matching counts and compute aggregates - let counts: Vec = bitmap_filters_vec.iter().map(|bf| bf.count()).collect(); - let query_count = counts.len(); - let total_matching: usize = counts.iter().cloned().sum(); +fn run( + input: &crate::inputs::filters::MetadataIndexBuild, + checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, +) -> Result { + // Print the input description so the user sees the job configuration. + writeln!(output, "{}", input)?; + + // Use the supplied filter parameters (required for metadata-only build) + let filter_params = &input.filter_params; + + // Reuse the helper: build index, parse predicates, produce BitmapFilters and telemetry + let (bitmap_filters_vec, filter_search_results, _label_count) = + prepare_bitmap_filters_from_paths_with_kind( + filter_params.data_labels.as_ref(), + filter_params.query_predicates.as_ref(), + input.inverted_index_type, + checkpoint, + )?; - // counts_avg will be computed below via the shared percentiles utility - let mut sorted = counts.clone(); - // Use the shared percentiles utility when we have values. - let ( - counts_p1, - counts_p5, - counts_p10, - counts_p50, - counts_p90, - counts_p95, - counts_p99, - counts_avg, - ) = if sorted.is_empty() { - ( - 0usize, 0usize, 0usize, 0usize, 0usize, 0usize, 0usize, 0.0f64, - ) - } else { - sorted.sort_unstable(); - let p = percentiles::compute_percentiles(&mut sorted)?; - // p.median is f64; round to nearest usize for display/storage - let p50 = p.median.round() as usize; - let p90 = p.p90; - let p99 = p.p99; - let n = sorted.len(); - let p1 = sorted[(n / 100).min(n - 1)]; - let p5 = sorted[((5 * n) / 100).min(n - 1)]; - let p10 = sorted[((10 * n) / 100).min(n - 1)]; - let p95 = sorted[((95 * n) / 100).min(n - 1)]; - (p1, p5, p10, p50, p90, p95, p99, p.mean) - }; + // Collect per-query matching counts and compute aggregates + let counts: Vec = bitmap_filters_vec.iter().map(|bf| bf.count()).collect(); + let query_count = counts.len(); + let total_matching: usize = counts.iter().cloned().sum(); + + // counts_avg will be computed below via the shared percentiles utility + let mut sorted = counts.clone(); + // Use the shared percentiles utility when we have values. + let ( + counts_p1, + counts_p5, + counts_p10, + counts_p50, + counts_p90, + counts_p95, + counts_p99, + counts_avg, + ) = if sorted.is_empty() { + ( + 0usize, 0usize, 0usize, 0usize, 0usize, 0usize, 0usize, 0.0f64, + ) + } else { + sorted.sort_unstable(); + let p = percentiles::compute_percentiles(&mut sorted)?; + // p.median is f64; round to nearest usize for display/storage + let p50 = p.median.round() as usize; + let p90 = p.p90; + let p99 = p.p99; + let n = sorted.len(); + let p1 = sorted[(n / 100).min(n - 1)]; + let p5 = sorted[((5 * n) / 100).min(n - 1)]; + let p10 = sorted[((10 * n) / 100).min(n - 1)]; + let p95 = sorted[((95 * n) / 100).min(n - 1)]; + (p1, p5, p10, p50, p90, p95, p99, p.mean) + }; - let stats = MetadataIndexBuildStats { - label_count: _label_count, - query_count, - total_matching, - counts_avg, - counts_p1, - counts_p5, - counts_p10, - counts_p50, - counts_p90, - counts_p95, - counts_p99, - filter: filter_search_results, - }; + let stats = MetadataIndexBuildStats { + label_count: _label_count, + query_count, + total_matching, + counts_avg, + counts_p1, + counts_p5, + counts_p10, + counts_p50, + counts_p90, + counts_p95, + counts_p99, + filter: filter_search_results, + }; - // Print the human-readable summary for interactive runs. - writeln!(output, "\n\n{}", stats)?; - Ok(stats) - } + // Print the human-readable summary for interactive runs. + writeln!(output, "\n\n{}", stats)?; + Ok(stats) } #[derive(Debug, Serialize)] diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index d38332ee1..ad7fd697a 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -32,7 +32,6 @@ use diskann_utils::{ views::{Matrix, MatrixView}, }; use half::f16; -use serde::Serialize; use super::{ build::{self, load_index, save_index, single_or_multi_insert, BuildStats}, @@ -57,66 +56,48 @@ use crate::{ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { // Full Precision - benchmarks.register::>("async-full-precision-f32"); - benchmarks.register::>("async-full-precision-f16"); - benchmarks.register::>("async-full-precision-u8"); - benchmarks.register::>("async-full-precision-i8"); + benchmarks.register("async-full-precision-f32", FullPrecision::::new()); + benchmarks.register("async-full-precision-f16", FullPrecision::::new()); + benchmarks.register("async-full-precision-u8", FullPrecision::::new()); + benchmarks.register("async-full-precision-i8", FullPrecision::::new()); // Dynamic Full Precision - benchmarks.register::>("async-dynamic-full-precision-f32"); - benchmarks.register::>("async-dynamic-full-precision-f16"); - benchmarks.register::>("async-dynamic-full-precision-u8"); - benchmarks.register::>("async-dynamic-full-precision-i8"); + benchmarks.register( + "async-dynamic-full-precision-f32", + DynamicFullPrecision::::new(), + ); + benchmarks.register( + "async-dynamic-full-precision-f16", + DynamicFullPrecision::::new(), + ); + benchmarks.register( + "async-dynamic-full-precision-u8", + DynamicFullPrecision::::new(), + ); + benchmarks.register( + "async-dynamic-full-precision-i8", + DynamicFullPrecision::::new(), + ); product::register_benchmarks(benchmarks); scalar::register_benchmarks(benchmarks); spherical::register_benchmarks(benchmarks); } -////////////// -// Dispatch // -////////////// - -pub(super) trait BuildAndSearch<'a> { - /// The telemetry associated with the build and search. - type Data: Serialize; - - /// Run the job, returning either the completed data or an error. - fn run( - self, - checkpoint: Checkpoint<'_>, - output: &mut dyn Output, - ) -> Result; -} - -pub(super) trait BuildAndDynamicRun<'a> { - /// The telemetry associated with the build and dynamic run. - type Data: Serialize; - - /// Run the runbook, returning either the completed data or an error. - fn run( - self, - checkpoint: Checkpoint<'_>, - output: &mut dyn Output, - ) -> Result; -} - // Full Precision -pub(super) struct FullPrecision<'a, T> { - input: &'a IndexOperation, +pub(super) struct FullPrecision { _type: std::marker::PhantomData, } -impl<'a, T> FullPrecision<'a, T> { - fn new(input: &'a IndexOperation) -> Self { +impl FullPrecision { + pub(super) fn new() -> Self { Self { - input, _type: std::marker::PhantomData, } } } -impl Benchmark for FullPrecision<'static, T> +impl Benchmark for FullPrecision where T: VectorRepr + diskann_utils::sampling::WithApproximateNorm @@ -126,7 +107,7 @@ where type Input = IndexOperation; type Output = BuildResult; - fn try_match(input: &IndexOperation) -> Result { + fn try_match(&self, input: &IndexOperation) -> Result { match &input.source { IndexSource::Load(load) => datatype::Type::::try_match(&load.data_type), IndexSource::Build(build) => datatype::Type::::try_match(&build.data_type), @@ -134,6 +115,7 @@ where } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&IndexOperation>, ) -> std::fmt::Result { @@ -151,30 +133,79 @@ where } fn run( + &self, input: &IndexOperation, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + mut output: &mut dyn Output, ) -> anyhow::Result { - BuildAndSearch::run(FullPrecision::::new(input), checkpoint, output) + writeln!(output, "{}", input)?; + let (index, build_stats) = match &input.source { + IndexSource::Build(build) => { + let (index, build_stats) = run_build( + build, + common::FullPrecision, + None, + output, + |data| { + let index = diskann_async::new_index::( + build.try_as_config()?.build()?, + build.inmem_parameters(data.nrows(), data.ncols()), + common::NoDeletes, + )?; + build::set_start_points( + index.provider(), + data.as_view(), + build.start_point_strategy, + )?; + Ok(index) + }, + single_or_multi_insert, + )?; + + // save the index if requested + if let Some(save_path) = &build.save_path { + utils::tokio::block_on(save_index(index.clone(), save_path))?; + } + + (index, Some(build_stats)) + } + IndexSource::Load(load) => { + let index_config: &IndexConfiguration = &load.to_config()?; + + let index = + { utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? }; + + (Arc::new(index), None::) + } + }; + + let result = run_search_outer( + &input.search_phase, + common::FullPrecision, + index, + build_stats, + checkpoint, + )?; + + writeln!(output, "\n\n{}", result)?; + Ok(result) } } // Async Dynamic Run -pub(super) struct DynamicFullPrecision<'a, T> { - input: &'a DynamicIndexRun, +pub(super) struct DynamicFullPrecision { _type: std::marker::PhantomData, } -impl<'a, T> DynamicFullPrecision<'a, T> { - fn new(input: &'a DynamicIndexRun) -> Self { +impl DynamicFullPrecision { + fn new() -> Self { Self { - input, _type: std::marker::PhantomData, } } } -impl Benchmark for DynamicFullPrecision<'static, T> +impl Benchmark for DynamicFullPrecision where T: VectorRepr + diskann_utils::sampling::WithApproximateNorm @@ -184,11 +215,12 @@ where type Input = DynamicIndexRun; type Output = Vec>; - fn try_match(input: &DynamicIndexRun) -> Result { + fn try_match(&self, input: &DynamicIndexRun) -> Result { datatype::Type::::try_match(&input.build.data_type) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&DynamicIndexRun>, ) -> std::fmt::Result { @@ -196,11 +228,64 @@ where } fn run( + &self, input: &DynamicIndexRun, - checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + _checkpoint: Checkpoint<'_>, + mut output: &mut dyn Output, ) -> anyhow::Result>> { - BuildAndDynamicRun::run(DynamicFullPrecision::::new(input), checkpoint, output) + writeln!(output, "{}", input)?; + + let groundtruth_directory = input + .runbook_params + .resolved_gt_directory + .as_ref() + .ok_or_else(|| { + anyhow::anyhow!("Ground truth directory path was not resolved during validation") + })?; + + let mut runbook = bigann::RunBook::load( + &input.runbook_params.runbook_path, + &input.runbook_params.dataset_name, + &mut bigann::ScanDirectory::new(groundtruth_directory)?, + )?; + + let mut streamer = full_precision_streaming::(input, runbook.max_points())?; + + let mut results = Vec::new(); + let stages = runbook.len(); + let mut i = 1; + + runbook.run_with( + &mut streamer, + |o: managed::Stats| -> anyhow::Result<()> { + if o.inner().is_maintain() { + let message = format!("Ran maintenance before stage {}", i); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + } else { + let message = + format!("Finished stage {} of {}: {}", i, stages, o.inner().kind()); + write!(output, "{}", crate::utils::SmallBanner(&message))?; + i += 1; + } + writeln!(output, "{}", o)?; + results.push(o); + Ok(()) + }, + )?; + + write!( + output, + "{}", + crate::utils::SmallBanner("End of Run Summary") + )?; + + writeln!( + output, + "{}", + streaming::stats::Summary::new(results.iter().map(|r| r.inner())) + )?; + + Ok(results) } } @@ -395,141 +480,6 @@ where } } -impl<'a, T> BuildAndSearch<'a> for FullPrecision<'a, T> -where - T: VectorRepr - + diskann_utils::sampling::WithApproximateNorm - + diskann::graph::SampleableForStart, -{ - type Data = BuildResult; - fn run( - self, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; - let (index, build_stats) = match &self.input.source { - IndexSource::Build(build) => { - let (index, build_stats) = run_build( - build, - common::FullPrecision, - None, - output, - |data| { - let index = diskann_async::new_index::( - build.try_as_config()?.build()?, - build.inmem_parameters(data.nrows(), data.ncols()), - common::NoDeletes, - )?; - build::set_start_points( - index.provider(), - data.as_view(), - build.start_point_strategy, - )?; - Ok(index) - }, - single_or_multi_insert, - )?; - - // save the index if requested - if let Some(save_path) = &build.save_path { - utils::tokio::block_on(save_index(index.clone(), save_path))?; - } - - (index, Some(build_stats)) - } - IndexSource::Load(load) => { - let index_config: &IndexConfiguration = &load.to_config()?; - - let index = - { utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? }; - - (Arc::new(index), None::) - } - }; - - let result = run_search_outer( - &self.input.search_phase, - common::FullPrecision, - index, - build_stats, - checkpoint, - )?; - - writeln!(output, "\n\n{}", result)?; - Ok(result) - } -} - -impl<'a, T> BuildAndDynamicRun<'a> for DynamicFullPrecision<'a, T> -where - T: VectorRepr - + diskann_utils::sampling::WithApproximateNorm - + diskann::graph::SampleableForStart, -{ - type Data = Vec>; - fn run( - self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; - - let groundtruth_directory = self - .input - .runbook_params - .resolved_gt_directory - .as_ref() - .ok_or_else(|| { - anyhow::anyhow!("Ground truth directory path was not resolved during validation") - })?; - - let mut runbook = bigann::RunBook::load( - &self.input.runbook_params.runbook_path, - &self.input.runbook_params.dataset_name, - &mut bigann::ScanDirectory::new(groundtruth_directory)?, - )?; - - let mut streamer = full_precision_streaming(&self, runbook.max_points())?; - - let mut results = Vec::new(); - let stages = runbook.len(); - let mut i = 1; - - runbook.run_with( - &mut streamer, - |o: managed::Stats| -> anyhow::Result<()> { - if o.inner().is_maintain() { - let message = format!("Ran maintenance before stage {}", i); - write!(output, "{}", crate::utils::SmallBanner(&message))?; - } else { - let message = - format!("Finished stage {} of {}: {}", i, stages, o.inner().kind()); - write!(output, "{}", crate::utils::SmallBanner(&message))?; - i += 1; - } - writeln!(output, "{}", o)?; - results.push(o); - Ok(()) - }, - )?; - - write!( - output, - "{}", - crate::utils::SmallBanner("End of Run Summary") - )?; - - writeln!( - output, - "{}", - streaming::stats::Summary::new(results.iter().map(|r| r.inner())) - )?; - - Ok(results) - } -} - /// The stack looks like this: /// /// - Bottom: [`FullPrecisionStream`]: The core streaming index implementation. @@ -540,19 +490,19 @@ where /// /// This function constructs the entire stack. fn full_precision_streaming( - config: &DynamicFullPrecision<'_, T>, + input: &DynamicIndexRun, max_points: usize, ) -> anyhow::Result>> where T: bytemuck::Pod + VectorRepr + WithApproximateNorm + SampleableForStart, { - let topk = match &config.input.search_phase { + let topk = match &input.search_phase { SearchPhase::Topk(topk) => topk, _ => anyhow::bail!("Only TopK is currently supported by the streaming index"), }; - let consolidate_threshold: f32 = config.input.runbook_params.consolidate_threshold; + let consolidate_threshold: f32 = input.runbook_params.consolidate_threshold; - let data = datafiles::load_dataset::(datafiles::BinFile(&config.input.build.data))?; + let data = datafiles::load_dataset::(datafiles::BinFile(&input.build.data))?; let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile( &topk.queries, ))?); @@ -561,28 +511,25 @@ where let max_points = ((max_points as f32) * (1.0 + 2.0 * consolidate_threshold)).ceil() as usize; let index = diskann_async::new_index::( - config - .input - .try_as_config(config.input.build.l_build)? - .build()?, - config.input.inmem_parameters(max_points, data.ncols()), + input.try_as_config(input.build.l_build)?.build()?, + input.inmem_parameters(max_points, data.ncols()), common::TableBasedDeletes, )?; build::set_start_points( index.provider(), data.as_view(), - config.input.build.start_point_strategy, + input.build.start_point_strategy, )?; - let num_threads_and_tasks = NonZeroUsize::new(config.input.build.num_threads).unwrap(); + let num_threads_and_tasks = NonZeroUsize::new(input.build.num_threads).unwrap(); let managed_stream = FullPrecisionStream { index, search: topk.clone(), runtime: benchmark_core::tokio::runtime(num_threads_and_tasks.get())?, ntasks: num_threads_and_tasks, - inplace_delete_num_to_replace: config.input.runbook_params.ip_delete_num_to_replace, - inplace_delete_method: config.input.runbook_params.ip_delete_method.into(), + inplace_delete_num_to_replace: input.runbook_params.ip_delete_num_to_replace, + inplace_delete_method: input.runbook_params.ip_delete_method.into(), }; let managed = Managed::new(max_points, consolidate_threshold, managed_stream); diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index a857e4e57..a529217bf 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -13,8 +13,8 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { { use half::f16; - benchmarks.register::>("async-pq-f32"); - benchmarks.register::>("async-pq-f16"); + benchmarks.register("async-pq-f32", imp::ProductQuantized::::new()); + benchmarks.register("async-pq-f16", imp::ProductQuantized::::new()); } // Stub implementation @@ -42,7 +42,7 @@ mod imp { use crate::{ backend::index::{ - benchmarks::{run_build, run_search_outer, BuildAndSearch, FullPrecision}, + benchmarks::{run_build, run_search_outer, FullPrecision}, build::{self, load_index, save_index, single_or_multi_insert, BuildStats}, result::QuantBuildResult, }, @@ -50,21 +50,19 @@ mod imp { utils::{self, datafiles}, }; - pub(super) struct ProductQuantized<'a, T> { - input: &'a IndexPQOperation, + pub(super) struct ProductQuantized { _type: std::marker::PhantomData, } - impl<'a, T> ProductQuantized<'a, T> { - fn new(input: &'a IndexPQOperation) -> Self { + impl ProductQuantized { + pub(super) fn new() -> Self { Self { - input, _type: std::marker::PhantomData, } } } - impl Benchmark for ProductQuantized<'static, T> + impl Benchmark for ProductQuantized where T: VectorRepr + diskann_utils::sampling::WithApproximateNorm @@ -74,51 +72,31 @@ mod imp { type Input = IndexPQOperation; type Output = QuantBuildResult; - fn try_match(input: &IndexPQOperation) -> Result { - as Benchmark>::try_match(&input.index_operation) + fn try_match(&self, input: &IndexPQOperation) -> Result { + FullPrecision::::new().try_match(&input.index_operation) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&IndexPQOperation>, ) -> std::fmt::Result { - as Benchmark>::description( - f, - input.map(|f| &f.index_operation), - ) + FullPrecision::::new().description(f, input.map(|f| &f.index_operation)) } fn run( + &self, input: &IndexPQOperation, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, - ) -> anyhow::Result { - let pq = ProductQuantized::::new(input); - BuildAndSearch::run(pq, checkpoint, output) - } - } - - impl<'a, T> BuildAndSearch<'a> for ProductQuantized<'a, T> - where - T: VectorRepr - + diskann_utils::sampling::WithApproximateNorm - + diskann::graph::SampleableForStart, - datatype::Type: DispatchRule, - { - type Data = QuantBuildResult; - fn run( - self, - checkpoint: Checkpoint<'_>, mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; + ) -> anyhow::Result { + writeln!(output, "{}", input)?; - let hybrid = common::Hybrid::new(self.input.max_fp_vecs_per_prune); + let hybrid = common::Hybrid::new(input.max_fp_vecs_per_prune); - let (index, build_stats, quant_training_time) = match &self.input.index_operation.source - { + let (index, build_stats, quant_training_time) = match &input.index_operation.source { IndexSource::Load(load) => { - let index_config: &IndexConfiguration = &self.input.to_config()?; + let index_config: &IndexConfiguration = &input.to_config()?; let index = { utils::tokio::block_on(load_index::<_>(&load.load_path, index_config))? }; @@ -139,17 +117,16 @@ mod imp { diskann_async::train_pq( train_data.as_view(), - self.input.num_pq_chunks, - &mut StdRng::seed_from_u64(self.input.seed), + input.num_pq_chunks, + &mut StdRng::seed_from_u64(input.seed), build.num_threads, )? }; let create_index = |data_view: MatrixView| { let index = diskann_async::new_quant_index::( - self.input.try_as_config()?.build()?, - self.input - .inmem_parameters(data_view.nrows(), data_view.ncols())?, + input.try_as_config()?.build()?, + input.inmem_parameters(data_view.nrows(), data_view.ncols())?, table, common::NoDeletes, )?; @@ -180,9 +157,9 @@ mod imp { } }; - let build = if self.input.use_fp_for_search { + let build = if input.use_fp_for_search { run_search_outer( - &self.input.index_operation.search_phase, + &input.index_operation.search_phase, common::FullPrecision, index, build_stats, @@ -190,7 +167,7 @@ mod imp { )? } else { run_search_outer( - &self.input.index_operation.search_phase, + &input.index_operation.search_phase, hybrid, index, build_stats, diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index b418c0d7b..216500cf5 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -14,17 +14,17 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { use half::f16; // f32 - benchmarks.register::>("async-sq-8-bit-f32"); - benchmarks.register::>("async-sq-4-bit-f32"); - benchmarks.register::>("async-sq-2-bit-f32"); - benchmarks.register::>("async-sq-1-bit-f32"); - // f16 - benchmarks.register::>("async-sq-8-bit-f16"); - benchmarks.register::>("async-sq-4-bit-f16"); - benchmarks.register::>("async-sq-2-bit-f16"); - benchmarks.register::>("async-sq-1-bit-f16"); + benchmarks.register("async-sq-8-bit-f32", imp::ScalarQuantized::<8, f32>::new()); + benchmarks.register("async-sq-4-bit-f32", imp::ScalarQuantized::<4, f32>::new()); + benchmarks.register("async-sq-2-bit-f32", imp::ScalarQuantized::<2, f32>::new()); + benchmarks.register("async-sq-1-bit-f32", imp::ScalarQuantized::<1, f32>::new()); + // f16 , + benchmarks.register("async-sq-8-bit-f16", imp::ScalarQuantized::<8, f16>::new()); + benchmarks.register("async-sq-4-bit-f16", imp::ScalarQuantized::<4, f16>::new()); + benchmarks.register("async-sq-2-bit-f16", imp::ScalarQuantized::<2, f16>::new()); + benchmarks.register("async-sq-1-bit-f16", imp::ScalarQuantized::<1, f16>::new()); // i8 - benchmarks.register::>("async-sq-1-bit-i8"); + benchmarks.register("async-sq-1-bit-i8", imp::ScalarQuantized::<1, i8>::new()); } // Stub implementation @@ -55,7 +55,7 @@ mod imp { use crate::{ backend::index::{ - benchmarks::{run_build, run_search_outer, BuildAndSearch, FullPrecision}, + benchmarks::{run_build, run_search_outer, FullPrecision}, build::{self, load_index, only_single_insert, save_index, BuildStats}, result::QuantBuildResult, }, @@ -64,16 +64,13 @@ mod imp { }; // Scalar Quantized - pub(super) struct ScalarQuantized<'a, const NBITS: usize, T> { - input: &'a IndexSQOperation, + pub(super) struct ScalarQuantized { _type: std::marker::PhantomData, } - impl<'a, const NBITS: usize, T> ScalarQuantized<'a, NBITS, T> { - fn new(input: &'a IndexSQOperation) -> Self { - assert_eq!(input.num_bits, NBITS); + impl ScalarQuantized { + pub(super) fn new() -> Self { Self { - input, _type: std::marker::PhantomData, } } @@ -81,11 +78,11 @@ mod imp { macro_rules! impl_sq_build { ($N:literal, $T: ty) => { - impl Benchmark for ScalarQuantized<'static, $N, $T> { + impl Benchmark for ScalarQuantized<$N, $T> { type Input = IndexSQOperation; type Output = QuantBuildResult; - fn try_match(input: &IndexSQOperation) -> Result { + fn try_match(&self, input: &IndexSQOperation) -> Result { let mut failure_score: Option = None; match input.index_operation.source { IndexSource::Load(_) => {} @@ -96,7 +93,7 @@ mod imp { } } - if as Benchmark>::try_match(&input.index_operation) + if FullPrecision::<$T>::new().try_match(&input.index_operation) .is_err() { *failure_score.get_or_insert(0) += 1; @@ -113,6 +110,7 @@ mod imp { } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&IndexSQOperation>, ) -> std::fmt::Result { @@ -173,25 +171,20 @@ mod imp { } fn run( + &self, input: &IndexSQOperation, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + mut output: &mut dyn Output, ) -> anyhow::Result { - let sq = ScalarQuantized::<$N, $T>::new(input); - BuildAndSearch::run(sq, checkpoint, output) - } - } + assert_eq!( + input.num_bits, + $N, + "INTERNAL ERROR: this should not have passed the match check" + ); - impl<'a> BuildAndSearch<'a> for ScalarQuantized<'a, $N, $T> { - type Data = QuantBuildResult; - fn run( - self, - checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; + writeln!(output, "{}", input)?; - let (index, build_stats, quant_training_time) = match &self.input.index_operation.source { + let (index, build_stats, quant_training_time) = match &input.index_operation.source { IndexSource::Load(load) => { let index_config: &IndexConfiguration = &load.to_config()?; @@ -208,7 +201,7 @@ mod imp { let start = std::time::Instant::now(); let quantizer = diskann_quantization::scalar::train::ScalarQuantizationParameters::new( - diskann_quantization::num::Positive::new(self.input.standard_deviations).context( + diskann_quantization::num::Positive::new(input.standard_deviations).context( "please file a bug report, this should not have made it past the\ front end", )?, @@ -216,8 +209,8 @@ mod imp { .train(data.as_view()); let create_index = |data_view: MatrixView<$T>| { let index = diskann_async::new_quant_index::<$T, _, _>( - self.input.try_as_config()?.build()?, - self.input + input.try_as_config()?.build()?, + input .inmem_parameters(data_view.nrows(), data_view.ncols())?, inmem::WithBits::<$N>::new(quantizer), common::NoDeletes, @@ -247,9 +240,9 @@ mod imp { }; - let build = if self.input.use_fp_for_search { + let build = if input.use_fp_for_search { run_search_outer( - &self.input.index_operation.search_phase, + &input.index_operation.search_phase, common::FullPrecision, index, build_stats, @@ -257,7 +250,7 @@ mod imp { )? } else { run_search_outer( - &self.input.index_operation.search_phase, + &input.index_operation.search_phase, common::Quantized, index, build_stats, diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 33cb2e2fe..507337da7 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -16,9 +16,9 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "spherical-quantization")] { - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); - benchmarks.register::>(NAME); + benchmarks.register(NAME, imp::SphericalQ::<1>); + benchmarks.register(NAME, imp::SphericalQ::<2>); + benchmarks.register(NAME, imp::SphericalQ::<4>); } // Stub implementation @@ -52,7 +52,6 @@ mod imp { use crate::{ backend::index::{ - benchmarks::BuildAndSearch, build::{self, only_single_insert, BuildStats}, result::AggregatedSearchResults, search, @@ -68,15 +67,7 @@ mod imp { }; /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct SphericalQ<'a, const NBITS: usize> { - input: &'a SphericalQuantBuild, - } - - impl<'a, const NBITS: usize> SphericalQ<'a, NBITS> { - pub(super) fn new(input: &'a SphericalQuantBuild) -> Self { - Self { input } - } - } + pub(super) struct SphericalQ; macro_rules! write_field { ($f:ident, $field:tt, $fmt:literal, $($expr:tt)*) => { @@ -126,11 +117,14 @@ mod imp { macro_rules! build_and_search { ($N:literal) => { - impl Benchmark for SphericalQ<'static, $N> { + impl Benchmark for SphericalQ<$N> { type Input = SphericalQuantBuild; type Output = SphericalBuildResult; - fn try_match(input: &SphericalQuantBuild) -> Result { + fn try_match( + &self, + input: &SphericalQuantBuild, + ) -> Result { let mut failure_score: Option = None; if input.build.multi_insert.is_some() { failure_score = Some(1); @@ -157,6 +151,7 @@ mod imp { } fn description( + &self, f: &mut std::fmt::Formatter<'_>, input: Option<&SphericalQuantBuild>, ) -> std::fmt::Result { @@ -200,42 +195,37 @@ mod imp { } fn run( + &self, input: &SphericalQuantBuild, checkpoint: Checkpoint<'_>, - output: &mut dyn Output, + mut output: &mut dyn Output, ) -> anyhow::Result { - let sq = SphericalQ::<$N>::new(input); - BuildAndSearch::run(sq, checkpoint, output) - } - } + assert_eq!( + input.num_bits.get(), + $N, + "INTERNAL ERROR: this should not have passed the match check" + ); - impl<'a> BuildAndSearch<'a> for SphericalQ<'a, $N> { - type Data = SphericalBuildResult; - fn run( - self, - _checkpoint: Checkpoint<'_>, - mut output: &mut dyn Output, - ) -> Result { - writeln!(output, "{}", self.input)?; + writeln!(output, "{}", input)?; - let build = &self.input.build; + let build = &input.build; let data: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&build.data))?); let start = std::time::Instant::now(); let m: diskann_vector::distance::Metric = build.distance.into(); - let pre_scale = match self.input.pre_scale { + let pre_scale = match input.pre_scale { Some(v) => v.try_into()?, None => diskann_quantization::spherical::PreScale::None, }; let quantizer = diskann_quantization::spherical::SphericalQuantizer::train( data.as_view(), - (&self.input.transform_kind).into(), + (&input.transform_kind).into(), m.try_into()?, pre_scale, - &mut rand::rngs::StdRng::seed_from_u64(self.input.seed), + &mut rand::rngs::StdRng::seed_from_u64(input.seed), GlobalAllocator, )?; @@ -244,8 +234,8 @@ mod imp { // We manually inline the build and search loops because we support // multiple different kinds of searches. let index = diskann_async::new_quant_index::( - self.input.try_as_config()?.build()?, - self.input.inmem_parameters(data.nrows(), data.ncols()), + input.try_as_config()?.build()?, + input.inmem_parameters(data.nrows(), data.ncols()), diskann_quantization::spherical::iface::Impl::<$N>::new(quantizer)?, NoDeletes, )?; @@ -274,12 +264,12 @@ mod imp { runs: Vec::new(), }; - match &self.input.search_phase { + match &input.search_phase { SearchPhase::Topk(search_phase) => { // Handle Topk search phase // Save construction stats before running queries. - _checkpoint.checkpoint(&result)?; + checkpoint.checkpoint(&result)?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&search_phase.queries), @@ -295,7 +285,7 @@ mod imp { &search_phase.runs, ); - for &layout in self.input.query_layouts.iter() { + for &layout in input.query_layouts.iter() { let knn = benchmark_core::search::graph::KNN::new( index.clone(), queries.clone(), @@ -317,7 +307,7 @@ mod imp { // Handle Range search phase // Save construction stats before running queries. - _checkpoint.checkpoint(&result)?; + checkpoint.checkpoint(&result)?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&search_phase.queries), @@ -333,7 +323,7 @@ mod imp { &search_phase.runs, ); - for &layout in self.input.query_layouts.iter() { + for &layout in input.query_layouts.iter() { let range = benchmark_core::search::graph::Range::new( index.clone(), queries.clone(), @@ -358,7 +348,7 @@ mod imp { // Handle Beta Filtered Topk search phase // Save construction stats before running queries. - _checkpoint.checkpoint(&result)?; + checkpoint.checkpoint(&result)?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&search_phase.queries), @@ -384,7 +374,7 @@ mod imp { .map(utils::filters::as_query_label_provider) .collect(); - for &layout in self.input.query_layouts.iter() { + for &layout in input.query_layouts.iter() { let strategy = inmem::spherical::Quantized::search(layout.into()); let search_strategies = setup_filter_strategies( search_phase.beta, @@ -414,7 +404,7 @@ mod imp { // Handle Beta Filtered Topk search phase // Save construction stats before running queries. - _checkpoint.checkpoint(&result)?; + checkpoint.checkpoint(&result)?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&search_phase.queries), @@ -440,7 +430,7 @@ mod imp { .map(utils::filters::as_query_label_provider) .collect(); - for &layout in self.input.query_layouts.iter() { + for &layout in input.query_layouts.iter() { let multihop = benchmark_core::search::graph::MultiHop::new( index.clone(), queries.clone(), diff --git a/diskann-benchmark/src/utils/mod.rs b/diskann-benchmark/src/utils/mod.rs index ebdac8116..cee6cacd3 100644 --- a/diskann-benchmark/src/utils/mod.rs +++ b/diskann-benchmark/src/utils/mod.rs @@ -111,7 +111,7 @@ macro_rules! stub_impl { use crate::inputs; pub(super) fn register(name: &str, registry: &mut Benchmarks) { - registry.register::(name); + registry.register(name, Stub); } /// An empty placeholder to provide a hint for the necessary feature. @@ -121,11 +121,12 @@ macro_rules! stub_impl { type Input = $input; type Output = serde_json::Value; - fn try_match(_input: &$input) -> Result { + fn try_match(&self, _input: &$input) -> Result { Err(FailureScore(0)) } fn description( + &self, f: &mut std::fmt::Formatter<'_>, _input: Option<&$input>, ) -> std::fmt::Result { @@ -134,6 +135,7 @@ macro_rules! stub_impl { } fn run( + &self, _input: &$input, _checkpoint: Checkpoint<'_>, _output: &mut dyn Output, From d60ac072081527db60bc3df66e0a37cf6347bbe3 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 28 Apr 2026 16:49:03 -0700 Subject: [PATCH 02/38] Checkpoint. --- .../src/backend/index/benchmarks.rs | 568 +++++++++++++----- diskann-benchmark/src/backend/index/result.rs | 4 + .../src/backend/index/search/mod.rs | 3 + .../src/backend/index/search/plugins.rs | 113 ++++ diskann-benchmark/src/inputs/async_.rs | 64 ++ 5 files changed, 585 insertions(+), 167 deletions(-) create mode 100644 diskann-benchmark/src/backend/index/search/plugins.rs diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index ad7fd697a..abc21c538 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -3,8 +3,7 @@ * Licensed under the MIT license. */ -use core::option::Option::None; -use std::{io::Write, num::NonZeroUsize, sync::Arc}; +use std::{any::Any, io::Write, marker::PhantomData, num::NonZeroUsize, sync::Arc}; use diskann::{ graph::SampleableForStart, @@ -24,7 +23,10 @@ use diskann_benchmark_runner::{ }; use diskann_providers::{ index::diskann_async, - model::{configuration::IndexConfiguration, graph::provider::async_::common}, + model::{ + configuration::IndexConfiguration, + graph::provider::async_::{common, inmem}, + }, }; use diskann_utils::{ future::AsyncFriendly, @@ -40,9 +42,12 @@ use super::{ use crate::{ backend::index::{ result::{AggregatedSearchResults, BuildResult}, + search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, - inputs::async_::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase}, + inputs::async_::{ + DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase, SearchPhaseKind, + }, utils::{ self, datafiles::{self}, @@ -56,10 +61,27 @@ use crate::{ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { // Full Precision - benchmarks.register("async-full-precision-f32", FullPrecision::::new()); - benchmarks.register("async-full-precision-f16", FullPrecision::::new()); - benchmarks.register("async-full-precision-u8", FullPrecision::::new()); - benchmarks.register("async-full-precision-i8", FullPrecision::::new()); + benchmarks.register( + "async-full-precision-f32", + FullPrecision::::new() + .search(plugins::Topk) + .search(plugins::Range) + .search(plugins::BetaFilter) + .search(plugins::MultihopFilter), + ); + + benchmarks.register( + "async-full-precision-f16", + FullPrecision::::new().search(plugins::Topk), + ); + benchmarks.register( + "async-full-precision-u8", + FullPrecision::::new().search(plugins::Topk), + ); + benchmarks.register( + "async-full-precision-i8", + FullPrecision::::new().search(plugins::Topk), + ); // Dynamic Full Precision benchmarks.register( @@ -84,17 +106,45 @@ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::reg spherical::register_benchmarks(benchmarks); } +type FullPrecisionProvider = inmem::DefaultProvider< + inmem::FullPrecisionStore, + common::NoStore, + common::NoDeletes, + DefaultContext, +>; + +impl QueryType for FullPrecisionProvider +where + T: VectorRepr, +{ + type Element = T; +} + // Full Precision -pub(super) struct FullPrecision { - _type: std::marker::PhantomData, +pub(super) struct FullPrecision +where + T: VectorRepr, +{ + plugins: plugins::Plugins, Strategy>, } -impl FullPrecision { +impl FullPrecision +where + T: VectorRepr, +{ pub(super) fn new() -> Self { Self { - _type: std::marker::PhantomData, + plugins: plugins::Plugins::new(), } } + + pub(super) fn search

(mut self, plugin: P) -> Self + where + P: plugins::Plugin, Strategy> + 'static, + { + self.plugins.register(plugin); + self + } } impl Benchmark for FullPrecision @@ -179,14 +229,14 @@ where } }; - let result = run_search_outer( - &input.search_phase, - common::FullPrecision, + let search_results = self.plugins.run( index, - build_stats, - checkpoint, + &Strategy::new(common::FullPrecision), + &input.search_phase, )?; + let result = BuildResult::new(build_stats, search_results); + writeln!(output, "\n\n{}", result)?; Ok(result) } @@ -325,161 +375,345 @@ where Ok((index, build_stats)) } -pub(super) fn run_search_outer( - input: &SearchPhase, - search_strategy: S, - index: Index, - build_stats: Option, - checkpoint: Checkpoint<'_>, -) -> anyhow::Result +// pub(super) fn run_search_outer( +// input: &SearchPhase, +// search_strategy: S, +// index: Index, +// build_stats: Option, +// checkpoint: Checkpoint<'_>, +// ) -> anyhow::Result +// where +// DP: DataProvider +// + for<'a> provider::SetElement<&'a [T]>, +// T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, +// S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, +// { +// match &input { +// SearchPhase::Topk(search_phase) => { +// // Handle Topk search phase +// let mut result = BuildResult::new_topk(build_stats); +// +// // Save construction stats before running queries. +// checkpoint.checkpoint(&result)?; +// +// let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( +// &search_phase.queries, +// ))?); +// +// let groundtruth = +// datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; +// +// let knn = benchmark_core::search::graph::KNN::new( +// index, +// queries, +// benchmark_core::search::graph::Strategy::broadcast(search_strategy), +// )?; +// +// let steps = search::knn::SearchSteps::new( +// search_phase.reps, +// &search_phase.num_threads, +// &search_phase.runs, +// ); +// +// let search_results = search::knn::run(&knn, &groundtruth, steps)?; +// result.append(AggregatedSearchResults::Topk(search_results)); +// Ok(result) +// } +// SearchPhase::Range(search_phase) => { +// // Handle Range search phase +// let mut result = BuildResult::new_range(build_stats); +// +// // Save construction stats before running queries. +// checkpoint.checkpoint(&result)?; +// +// let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( +// &search_phase.queries, +// ))?); +// +// let groundtruth = +// datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; +// +// let steps = search::range::RangeSearchSteps::new( +// search_phase.reps, +// &search_phase.num_threads, +// &search_phase.runs, +// ); +// +// let range = benchmark_core::search::graph::Range::new( +// index, +// queries, +// benchmark_core::search::graph::Strategy::broadcast(search_strategy), +// )?; +// +// let search_results = search::range::run(&range, &groundtruth, steps)?; +// result.append(AggregatedSearchResults::Range(search_results)); +// Ok(result) +// } +// SearchPhase::TopkBetaFilter(search_phase) => { +// // Handle Beta Filtered Topk search phase +// let mut result = BuildResult::new_topk(build_stats); +// +// // Save construction stats before running queries. +// checkpoint.checkpoint(&result)?; +// +// let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( +// &search_phase.queries, +// ))?); +// +// let groundtruth = +// datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; +// +// let bit_maps = +// generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; +// +// let search_strategies = setup_filter_strategies( +// search_phase.beta, +// bit_maps +// .into_iter() +// .map(utils::filters::as_query_label_provider), +// search_strategy.clone(), +// ); +// +// let knn = benchmark_core::search::graph::KNN::new( +// index, +// queries, +// benchmark_core::search::graph::Strategy::collection(search_strategies), +// )?; +// +// let steps = search::knn::SearchSteps::new( +// search_phase.reps, +// &search_phase.num_threads, +// &search_phase.runs, +// ); +// +// let search_results = search::knn::run(&knn, &groundtruth, steps)?; +// result.append(AggregatedSearchResults::Topk(search_results)); +// Ok(result) +// } +// SearchPhase::TopkMultihopFilter(search_phase) => { +// // Handle MultiHop Topk search phase +// let mut result = BuildResult::new_topk(build_stats); +// +// // Save construction stats before running queries. +// checkpoint.checkpoint(&result)?; +// +// let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( +// &search_phase.queries, +// ))?); +// +// let groundtruth = +// datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; +// +// let steps = search::knn::SearchSteps::new( +// search_phase.reps, +// &search_phase.num_threads, +// &search_phase.runs, +// ); +// +// let bit_maps = +// generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; +// +// let multihop = benchmark_core::search::graph::MultiHop::new( +// index, +// queries, +// benchmark_core::search::graph::Strategy::broadcast(search_strategy), +// bit_maps +// .into_iter() +// .map(utils::filters::as_query_label_provider) +// .collect(), +// )?; +// +// let search_results = search::knn::run(&multihop, &groundtruth, steps)?; +// result.append(AggregatedSearchResults::Topk(search_results)); +// Ok(result) +// } +// } +// } + +trait QueryType { + type Element: VectorRepr; +} + +#[derive(Debug, Clone, Copy)] +pub(super) struct Strategy(S); + +impl Strategy { + pub(super) fn new(strategy: S) -> Self { + Self(strategy) + } +} + +impl search::Plugin> for plugins::Topk where - DP: DataProvider - + for<'a> provider::SetElement<&'a [T]>, - T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, - S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, + DP: DataProvider + QueryType, + S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { - match &input { - SearchPhase::Topk(search_phase) => { - // Handle Topk search phase - let mut result = BuildResult::new_topk(build_stats); - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let knn = benchmark_core::search::graph::KNN::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - )?; - - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let search_results = search::knn::run(&knn, &groundtruth, steps)?; - result.append(AggregatedSearchResults::Topk(search_results)); - Ok(result) - } - SearchPhase::Range(search_phase) => { - // Handle Range search phase - let mut result = BuildResult::new_range(build_stats); - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let steps = search::range::RangeSearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let range = benchmark_core::search::graph::Range::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - )?; - - let search_results = search::range::run(&range, &groundtruth, steps)?; - result.append(AggregatedSearchResults::Range(search_results)); - Ok(result) - } - SearchPhase::TopkBetaFilter(search_phase) => { - // Handle Beta Filtered Topk search phase - let mut result = BuildResult::new_topk(build_stats); - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let bit_maps = - generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; - - let search_strategies = setup_filter_strategies( - search_phase.beta, - bit_maps - .into_iter() - .map(utils::filters::as_query_label_provider), - search_strategy.clone(), - ); - - let knn = benchmark_core::search::graph::KNN::new( - index, - queries, - benchmark_core::search::graph::Strategy::collection(search_strategies), - )?; - - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let search_results = search::knn::run(&knn, &groundtruth, steps)?; - result.append(AggregatedSearchResults::Topk(search_results)); - Ok(result) - } - SearchPhase::TopkMultihopFilter(search_phase) => { - // Handle MultiHop Topk search phase - let mut result = BuildResult::new_topk(build_stats); - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let bit_maps = - generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; - - let multihop = benchmark_core::search::graph::MultiHop::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - bit_maps - .into_iter() - .map(utils::filters::as_query_label_provider) - .collect(), - )?; - - let search_results = search::knn::run(&multihop, &groundtruth, steps)?; - result.append(AggregatedSearchResults::Topk(search_results)); - Ok(result) - } + fn kind(&self) -> SearchPhaseKind { + Self::kind() + } + + fn search( + &self, + index: Arc>, + strategy: &Strategy, + phase: &SearchPhase, + ) -> anyhow::Result { + let topk = phase.as_topk().unwrap(); + + let queries: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); + + let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + + let knn = benchmark_core::search::graph::KNN::new( + index.clone(), + queries, + benchmark_core::search::graph::Strategy::broadcast(strategy.0.clone()), + )?; + + let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); + + let results = search::knn::run(&knn, &groundtruth, steps)?; + Ok(AggregatedSearchResults::Topk(results)) } } +impl search::Plugin> for plugins::Range +where + DP: DataProvider + QueryType, + S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, +{ + fn kind(&self) -> SearchPhaseKind { + Self::kind() + } + + fn search( + &self, + index: Arc>, + strategy: &Strategy, + phase: &SearchPhase, + ) -> anyhow::Result { + let range = phase.as_range().unwrap(); + let queries: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&range.queries))?); + + let groundtruth = + datafiles::load_range_groundtruth(datafiles::BinFile(&range.groundtruth))?; + + let steps = + search::range::RangeSearchSteps::new(range.reps, &range.num_threads, &range.runs); + + let range = benchmark_core::search::graph::Range::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(strategy.0.clone()), + )?; + + let result = search::range::run(&range, &groundtruth, steps)?; + Ok(AggregatedSearchResults::Range(result)) + } +} + +impl search::Plugin> for plugins::BetaFilter +where + DP: DataProvider + QueryType, + S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, +{ + fn kind(&self) -> SearchPhaseKind { + Self::kind() + } + + fn search( + &self, + index: Arc>, + strategy: &Strategy, + phase: &SearchPhase, + ) -> anyhow::Result { + let beta_filter = phase.as_topk_beta_filter().unwrap(); + + let queries: Arc> = Arc::new(datafiles::load_dataset( + datafiles::BinFile(&beta_filter.queries), + )?); + + let groundtruth = + datafiles::load_range_groundtruth(datafiles::BinFile(&beta_filter.groundtruth))?; + + let bit_maps = generate_bitmaps(&beta_filter.query_predicates, &beta_filter.data_labels)?; + + let search_strategies = setup_filter_strategies( + beta_filter.beta, + bit_maps + .into_iter() + .map(utils::filters::as_query_label_provider), + strategy.0.clone(), + ); + + let knn = benchmark_core::search::graph::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::collection(search_strategies), + )?; + + let steps = search::knn::SearchSteps::new( + beta_filter.reps, + &beta_filter.num_threads, + &beta_filter.runs, + ); + + let result = search::knn::run(&knn, &groundtruth, steps)?; + Ok(AggregatedSearchResults::Topk(result)) + } +} + +impl search::Plugin> for plugins::MultihopFilter +where + DP: DataProvider + QueryType, + S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, +{ + fn kind(&self) -> SearchPhaseKind { + Self::kind() + } + + fn search( + &self, + index: Arc>, + strategy: &Strategy, + phase: &SearchPhase, + ) -> anyhow::Result { + let multihop = phase.as_topk_multihop_filter().unwrap(); + + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &multihop.queries, + ))?); + + let groundtruth = + datafiles::load_range_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; + + let steps = search::knn::SearchSteps::new( + multihop.reps, + &multihop.num_threads, + &multihop.runs, + ); + + let bit_maps = + generate_bitmaps(&multihop.query_predicates, &multihop.data_labels)?; + + let multihop = benchmark_core::search::graph::MultiHop::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(strategy.0.clone()), + bit_maps + .into_iter() + .map(utils::filters::as_query_label_provider) + .collect(), + )?; + + let result = search::knn::run(&multihop, &groundtruth, steps)?; + Ok(AggregatedSearchResults::Topk(result)) + } +} + + /// The stack looks like this: /// /// - Bottom: [`FullPrecisionStream`]: The core streaming index implementation. diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 1d6102f9b..3cb8c86dc 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -36,6 +36,10 @@ impl BuildResult { } } + pub(super) fn new(build: Option, search: AggregatedSearchResults) -> Self { + Self { build, search } + } + pub(super) fn append(&mut self, search: AggregatedSearchResults) { self.search.append(search); } diff --git a/diskann-benchmark/src/backend/index/search/mod.rs b/diskann-benchmark/src/backend/index/search/mod.rs index e789b6702..e9c00085c 100644 --- a/diskann-benchmark/src/backend/index/search/mod.rs +++ b/diskann-benchmark/src/backend/index/search/mod.rs @@ -5,3 +5,6 @@ pub(crate) mod knn; pub(crate) mod range; + +pub(crate) mod plugins; +pub(crate) use plugins::Plugin; diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs new file mode 100644 index 000000000..1945bea47 --- /dev/null +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -0,0 +1,113 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{any::Any, sync::Arc}; + +use diskann::{graph::DiskANNIndex, provider::DataProvider}; + +use crate::{ + backend::index::result::AggregatedSearchResults, + inputs::async_::{SearchPhase, SearchPhaseKind}, +}; + +pub(crate) trait Plugin: std::fmt::Debug +where + DP: DataProvider, +{ + /// The flavor of `SearchPhase` this plugin is compiled for. + fn kind(&self) -> SearchPhaseKind; + + fn search( + &self, + index: Arc>, + parameters: &P, + phase: &SearchPhase, + ) -> anyhow::Result; +} + +#[derive(Debug)] +pub(crate) struct Plugins +where + DP: DataProvider, +{ + plugins: Vec>>, +} + +impl Plugins +where + DP: DataProvider, +{ + pub(crate) fn new() -> Self { + Self { + plugins: Vec::new(), + } + } + + pub(crate) fn register(&mut self, plugin: T) + where + T: Plugin + 'static, + { + self.plugins.push(Box::new(plugin)); + } + + pub(crate) fn kinds(&self) -> Vec { + self.plugins.iter().map(|p| p.kind()).collect() + } + + pub(crate) fn is_match(&self, phase: &SearchPhase) -> bool { + self.plugins.iter().any(|p| p.kind() == phase.kind()) + } + + pub(crate) fn run( + &self, + index: Arc>, + parameters: &P, + phase: &SearchPhase, + ) -> anyhow::Result { + match self.plugins.iter().find(|p| p.kind() == phase.kind()) { + Some(plugin) => plugin.search(index, parameters, phase), + None => Err(anyhow::anyhow!( + "INTERNAL ERROR: Could not find a search plugin for {}", + phase.kind() + )), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct Topk; + +impl Topk { + pub(crate) fn kind() -> SearchPhaseKind { + SearchPhaseKind::Topk + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct Range; + +impl Range { + pub(crate) fn kind() -> SearchPhaseKind { + SearchPhaseKind::Range + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct BetaFilter; + +impl BetaFilter { + pub(crate) fn kind() -> SearchPhaseKind { + SearchPhaseKind::TopkBetaFilter + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct MultihopFilter; + +impl MultihopFilter { + pub(crate) fn kind() -> SearchPhaseKind { + SearchPhaseKind::TopkMultihopFilter + } +} diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index c76fdb594..d987d7067 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -335,6 +335,45 @@ pub(crate) enum SearchPhase { TopkMultihopFilter(MultiHopSearchPhase), } +impl SearchPhase { + pub(crate) fn kind(&self) -> SearchPhaseKind { + match self { + Self::Topk(_) => SearchPhaseKind::Topk, + Self::Range(_) => SearchPhaseKind::Range, + Self::TopkBetaFilter(_) => SearchPhaseKind::TopkBetaFilter, + Self::TopkMultihopFilter(_) => SearchPhaseKind::TopkMultihopFilter, + } + } + + pub(crate) fn as_topk(&self) -> Option<&TopkSearchPhase> { + match self { + Self::Topk(phase) => Some(phase), + _ => None, + } + } + + pub(crate) fn as_range(&self) -> Option<&RangeSearchPhase> { + match self { + Self::Range(phase) => Some(phase), + _ => None, + } + } + + pub(crate) fn as_topk_beta_filter(&self) -> Option<&BetaSearchPhase> { + match self { + Self::TopkBetaFilter(phase) => Some(phase), + _ => None, + } + } + + pub(crate) fn as_topk_multihop_filter(&self) -> Option<&MultiHopSearchPhase> { + match self { + Self::TopkMultihopFilter(phase) => Some(phase), + _ => None, + } + } +} + impl CheckDeserialization for SearchPhase { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { match self { @@ -346,6 +385,31 @@ impl CheckDeserialization for SearchPhase { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum SearchPhaseKind { + Topk, + Range, + TopkBetaFilter, + TopkMultihopFilter, +} + +impl SearchPhaseKind { + fn as_str(&self) -> &'static str { + match self { + Self::Topk => "topk", + Self::Range => "range", + Self::TopkBetaFilter => "beta-filter", + Self::TopkMultihopFilter => "multihop-filter", + } + } +} + +impl std::fmt::Display for SearchPhaseKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } +} + //////////////////////////// // Build - Full Precision // //////////////////////////// From 6cc92bdd64bec510da44aefd4d490046617df49a Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 28 Apr 2026 17:26:45 -0700 Subject: [PATCH 03/38] Checkopint. --- .../src/backend/index/benchmarks.rs | 331 +++++++++--------- .../src/backend/index/product.rs | 85 +++-- diskann-benchmark/src/backend/index/scalar.rs | 114 ++++-- 3 files changed, 314 insertions(+), 216 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index abc21c538..ae0e88abb 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -113,6 +113,10 @@ type FullPrecisionProvider = inmem::DefaultProvider< DefaultContext, >; +pub(super) trait QueryType { + type Element: VectorRepr; +} + impl QueryType for FullPrecisionProvider where T: VectorRepr, @@ -375,163 +379,159 @@ where Ok((index, build_stats)) } -// pub(super) fn run_search_outer( -// input: &SearchPhase, -// search_strategy: S, -// index: Index, -// build_stats: Option, -// checkpoint: Checkpoint<'_>, -// ) -> anyhow::Result -// where -// DP: DataProvider -// + for<'a> provider::SetElement<&'a [T]>, -// T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, -// S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, -// { -// match &input { -// SearchPhase::Topk(search_phase) => { -// // Handle Topk search phase -// let mut result = BuildResult::new_topk(build_stats); -// -// // Save construction stats before running queries. -// checkpoint.checkpoint(&result)?; -// -// let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( -// &search_phase.queries, -// ))?); -// -// let groundtruth = -// datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; -// -// let knn = benchmark_core::search::graph::KNN::new( -// index, -// queries, -// benchmark_core::search::graph::Strategy::broadcast(search_strategy), -// )?; -// -// let steps = search::knn::SearchSteps::new( -// search_phase.reps, -// &search_phase.num_threads, -// &search_phase.runs, -// ); -// -// let search_results = search::knn::run(&knn, &groundtruth, steps)?; -// result.append(AggregatedSearchResults::Topk(search_results)); -// Ok(result) -// } -// SearchPhase::Range(search_phase) => { -// // Handle Range search phase -// let mut result = BuildResult::new_range(build_stats); -// -// // Save construction stats before running queries. -// checkpoint.checkpoint(&result)?; -// -// let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( -// &search_phase.queries, -// ))?); -// -// let groundtruth = -// datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; -// -// let steps = search::range::RangeSearchSteps::new( -// search_phase.reps, -// &search_phase.num_threads, -// &search_phase.runs, -// ); -// -// let range = benchmark_core::search::graph::Range::new( -// index, -// queries, -// benchmark_core::search::graph::Strategy::broadcast(search_strategy), -// )?; -// -// let search_results = search::range::run(&range, &groundtruth, steps)?; -// result.append(AggregatedSearchResults::Range(search_results)); -// Ok(result) -// } -// SearchPhase::TopkBetaFilter(search_phase) => { -// // Handle Beta Filtered Topk search phase -// let mut result = BuildResult::new_topk(build_stats); -// -// // Save construction stats before running queries. -// checkpoint.checkpoint(&result)?; -// -// let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( -// &search_phase.queries, -// ))?); -// -// let groundtruth = -// datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; -// -// let bit_maps = -// generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; -// -// let search_strategies = setup_filter_strategies( -// search_phase.beta, -// bit_maps -// .into_iter() -// .map(utils::filters::as_query_label_provider), -// search_strategy.clone(), -// ); -// -// let knn = benchmark_core::search::graph::KNN::new( -// index, -// queries, -// benchmark_core::search::graph::Strategy::collection(search_strategies), -// )?; -// -// let steps = search::knn::SearchSteps::new( -// search_phase.reps, -// &search_phase.num_threads, -// &search_phase.runs, -// ); -// -// let search_results = search::knn::run(&knn, &groundtruth, steps)?; -// result.append(AggregatedSearchResults::Topk(search_results)); -// Ok(result) -// } -// SearchPhase::TopkMultihopFilter(search_phase) => { -// // Handle MultiHop Topk search phase -// let mut result = BuildResult::new_topk(build_stats); -// -// // Save construction stats before running queries. -// checkpoint.checkpoint(&result)?; -// -// let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( -// &search_phase.queries, -// ))?); -// -// let groundtruth = -// datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; -// -// let steps = search::knn::SearchSteps::new( -// search_phase.reps, -// &search_phase.num_threads, -// &search_phase.runs, -// ); -// -// let bit_maps = -// generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; -// -// let multihop = benchmark_core::search::graph::MultiHop::new( -// index, -// queries, -// benchmark_core::search::graph::Strategy::broadcast(search_strategy), -// bit_maps -// .into_iter() -// .map(utils::filters::as_query_label_provider) -// .collect(), -// )?; -// -// let search_results = search::knn::run(&multihop, &groundtruth, steps)?; -// result.append(AggregatedSearchResults::Topk(search_results)); -// Ok(result) -// } -// } -// } - -trait QueryType { - type Element: VectorRepr; +pub(super) fn run_search_outer( + input: &SearchPhase, + search_strategy: S, + index: Index, + build_stats: Option, + checkpoint: Checkpoint<'_>, +) -> anyhow::Result +where + DP: DataProvider + + for<'a> provider::SetElement<&'a [T]>, + T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, + S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, +{ + match &input { + SearchPhase::Topk(search_phase) => { + // Handle Topk search phase + let mut result = BuildResult::new_topk(build_stats); + + // Save construction stats before running queries. + checkpoint.checkpoint(&result)?; + + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &search_phase.queries, + ))?); + + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; + + let knn = benchmark_core::search::graph::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(search_strategy), + )?; + + let steps = search::knn::SearchSteps::new( + search_phase.reps, + &search_phase.num_threads, + &search_phase.runs, + ); + + let search_results = search::knn::run(&knn, &groundtruth, steps)?; + result.append(AggregatedSearchResults::Topk(search_results)); + Ok(result) + } + SearchPhase::Range(search_phase) => { + // Handle Range search phase + let mut result = BuildResult::new_range(build_stats); + + // Save construction stats before running queries. + checkpoint.checkpoint(&result)?; + + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &search_phase.queries, + ))?); + + let groundtruth = + datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; + + let steps = search::range::RangeSearchSteps::new( + search_phase.reps, + &search_phase.num_threads, + &search_phase.runs, + ); + + let range = benchmark_core::search::graph::Range::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(search_strategy), + )?; + + let search_results = search::range::run(&range, &groundtruth, steps)?; + result.append(AggregatedSearchResults::Range(search_results)); + Ok(result) + } + SearchPhase::TopkBetaFilter(search_phase) => { + // Handle Beta Filtered Topk search phase + let mut result = BuildResult::new_topk(build_stats); + + // Save construction stats before running queries. + checkpoint.checkpoint(&result)?; + + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &search_phase.queries, + ))?); + + let groundtruth = + datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; + + let bit_maps = + generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; + + let search_strategies = setup_filter_strategies( + search_phase.beta, + bit_maps + .into_iter() + .map(utils::filters::as_query_label_provider), + search_strategy.clone(), + ); + + let knn = benchmark_core::search::graph::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::collection(search_strategies), + )?; + + let steps = search::knn::SearchSteps::new( + search_phase.reps, + &search_phase.num_threads, + &search_phase.runs, + ); + + let search_results = search::knn::run(&knn, &groundtruth, steps)?; + result.append(AggregatedSearchResults::Topk(search_results)); + Ok(result) + } + SearchPhase::TopkMultihopFilter(search_phase) => { + // Handle MultiHop Topk search phase + let mut result = BuildResult::new_topk(build_stats); + + // Save construction stats before running queries. + checkpoint.checkpoint(&result)?; + + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &search_phase.queries, + ))?); + + let groundtruth = + datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; + + let steps = search::knn::SearchSteps::new( + search_phase.reps, + &search_phase.num_threads, + &search_phase.runs, + ); + + let bit_maps = + generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; + + let multihop = benchmark_core::search::graph::MultiHop::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(search_strategy), + bit_maps + .into_iter() + .map(utils::filters::as_query_label_provider) + .collect(), + )?; + + let search_results = search::knn::run(&multihop, &groundtruth, steps)?; + result.append(AggregatedSearchResults::Topk(search_results)); + Ok(result) + } + } } #[derive(Debug, Clone, Copy)] @@ -682,21 +682,17 @@ where ) -> anyhow::Result { let multihop = phase.as_topk_multihop_filter().unwrap(); - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &multihop.queries, - ))?); + let queries: Arc> = Arc::new(datafiles::load_dataset( + datafiles::BinFile(&multihop.queries), + )?); let groundtruth = datafiles::load_range_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; - let steps = search::knn::SearchSteps::new( - multihop.reps, - &multihop.num_threads, - &multihop.runs, - ); + let steps = + search::knn::SearchSteps::new(multihop.reps, &multihop.num_threads, &multihop.runs); - let bit_maps = - generate_bitmaps(&multihop.query_predicates, &multihop.data_labels)?; + let bit_maps = generate_bitmaps(&multihop.query_predicates, &multihop.data_labels)?; let multihop = benchmark_core::search::graph::MultiHop::new( index, @@ -713,7 +709,6 @@ where } } - /// The stack looks like this: /// /// - Bottom: [`FullPrecisionStream`]: The core streaming index implementation. diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index a529217bf..f1839fc85 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -11,10 +11,19 @@ crate::utils::stub_impl!("product-quantization", inputs::async_::IndexPQOperatio pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "product-quantization")] { + use crate::backend::index::search::plugins; use half::f16; - benchmarks.register("async-pq-f32", imp::ProductQuantized::::new()); - benchmarks.register("async-pq-f16", imp::ProductQuantized::::new()); + benchmarks.register( + "async-pq-f32", + imp::ProductQuantized::::new() + .search(plugins::Topk) + .search(plugins::Range), + ); + benchmarks.register( + "async-pq-f16", + imp::ProductQuantized::::new().search(plugins::Topk), + ); } // Stub implementation @@ -29,7 +38,10 @@ mod imp { use diskann::utils::VectorRepr; use diskann_providers::{ index::diskann_async::{self}, - model::{graph::provider::async_::common, IndexConfiguration}, + model::{ + graph::provider::async_::{common, inmem}, + IndexConfiguration, + }, }; use diskann_utils::views::{Matrix, MatrixView}; @@ -42,24 +54,59 @@ mod imp { use crate::{ backend::index::{ - benchmarks::{run_build, run_search_outer, FullPrecision}, + benchmarks::{run_build, FullPrecision, QueryType, Strategy}, build::{self, load_index, save_index, single_or_multi_insert, BuildStats}, - result::QuantBuildResult, + result::{BuildResult, QuantBuildResult}, + search::plugins, }, inputs::async_::{IndexPQOperation, IndexSource}, utils::{self, datafiles}, }; - pub(super) struct ProductQuantized { - _type: std::marker::PhantomData, + type PQProvider = inmem::DefaultProvider< + inmem::FullPrecisionStore, + inmem::DefaultQuant, + common::NoDeletes, + diskann::provider::DefaultContext, + >; + + impl QueryType for PQProvider + where + T: VectorRepr, + { + type Element = T; + } + + pub(super) struct ProductQuantized + where + T: VectorRepr, + { + quant_search: plugins::Plugins, Strategy>, + full_search: plugins::Plugins, Strategy>, } - impl ProductQuantized { + impl ProductQuantized + where + T: VectorRepr, + { pub(super) fn new() -> Self { Self { - _type: std::marker::PhantomData, + quant_search: plugins::Plugins::new(), + full_search: plugins::Plugins::new(), } } + + pub(super) fn search

(mut self, plugin: P) -> Self + where + P: plugins::Plugin, Strategy> + + plugins::Plugin, Strategy> + + Clone + + 'static, + { + self.quant_search.register(plugin.clone()); + self.full_search.register(plugin); + self + } } impl Benchmark for ProductQuantized @@ -157,27 +204,23 @@ mod imp { } }; - let build = if input.use_fp_for_search { - run_search_outer( - &input.index_operation.search_phase, - common::FullPrecision, + let search = if input.use_fp_for_search { + self.full_search.run( index, - build_stats, - checkpoint, + &Strategy::new(common::FullPrecision), + &input.index_operation.search_phase, )? } else { - run_search_outer( - &input.index_operation.search_phase, - hybrid, + self.quant_search.run( index, - build_stats, - checkpoint, + &Strategy::new(hybrid), + &input.index_operation.search_phase, )? }; let result = QuantBuildResult { quant_training_time, - build, + build: BuildResult::new(build_stats, search), }; writeln!(output, "\n\n{}", result)?; diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index 216500cf5..503bbc9d6 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -11,20 +11,48 @@ crate::utils::stub_impl!("scalar-quantization", inputs::async_::IndexSQOperation pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "scalar-quantization")] { + use crate::backend::index::search::plugins::Topk; use half::f16; // f32 - benchmarks.register("async-sq-8-bit-f32", imp::ScalarQuantized::<8, f32>::new()); - benchmarks.register("async-sq-4-bit-f32", imp::ScalarQuantized::<4, f32>::new()); - benchmarks.register("async-sq-2-bit-f32", imp::ScalarQuantized::<2, f32>::new()); - benchmarks.register("async-sq-1-bit-f32", imp::ScalarQuantized::<1, f32>::new()); + benchmarks.register( + "async-sq-8-bit-f32", + imp::ScalarQuantized::<8, f32>::new().search(Topk), + ); + benchmarks.register( + "async-sq-4-bit-f32", + imp::ScalarQuantized::<4, f32>::new().search(Topk), + ); + benchmarks.register( + "async-sq-2-bit-f32", + imp::ScalarQuantized::<2, f32>::new().search(Topk), + ); + benchmarks.register( + "async-sq-1-bit-f32", + imp::ScalarQuantized::<1, f32>::new().search(Topk), + ); // f16 , - benchmarks.register("async-sq-8-bit-f16", imp::ScalarQuantized::<8, f16>::new()); - benchmarks.register("async-sq-4-bit-f16", imp::ScalarQuantized::<4, f16>::new()); - benchmarks.register("async-sq-2-bit-f16", imp::ScalarQuantized::<2, f16>::new()); - benchmarks.register("async-sq-1-bit-f16", imp::ScalarQuantized::<1, f16>::new()); + benchmarks.register( + "async-sq-8-bit-f16", + imp::ScalarQuantized::<8, f16>::new().search(Topk), + ); + benchmarks.register( + "async-sq-4-bit-f16", + imp::ScalarQuantized::<4, f16>::new().search(Topk), + ); + benchmarks.register( + "async-sq-2-bit-f16", + imp::ScalarQuantized::<2, f16>::new().search(Topk), + ); + benchmarks.register( + "async-sq-1-bit-f16", + imp::ScalarQuantized::<1, f16>::new().search(Topk), + ); // i8 - benchmarks.register("async-sq-1-bit-i8", imp::ScalarQuantized::<1, i8>::new()); + benchmarks.register( + "async-sq-1-bit-i8", + imp::ScalarQuantized::<1, i8>::new().search(Topk), + ); } // Stub implementation @@ -37,6 +65,7 @@ mod imp { use std::{io::Write, sync::Arc}; use anyhow::Context; + use diskann::utils::VectorRepr; use diskann_benchmark_runner::{ describeln, dispatcher::{Description, DispatchRule, FailureScore, MatchScore}, @@ -55,25 +84,60 @@ mod imp { use crate::{ backend::index::{ - benchmarks::{run_build, run_search_outer, FullPrecision}, + benchmarks::{run_build, FullPrecision, QueryType, Strategy}, build::{self, load_index, only_single_insert, save_index, BuildStats}, - result::QuantBuildResult, + result::{BuildResult, QuantBuildResult}, + search::plugins, }, inputs::async_::{IndexSQOperation, IndexSource}, utils::{self, datafiles}, }; + type SQProvider = inmem::DefaultProvider< + inmem::FullPrecisionStore, + inmem::SQStore, + common::NoDeletes, + diskann::provider::DefaultContext, + >; + + impl QueryType for SQProvider + where + T: VectorRepr, + { + type Element = T; + } + // Scalar Quantized - pub(super) struct ScalarQuantized { - _type: std::marker::PhantomData, + pub(super) struct ScalarQuantized + where + T: VectorRepr, + { + quant_search: plugins::Plugins, Strategy>, + full_search: plugins::Plugins, Strategy>, } - impl ScalarQuantized { + impl ScalarQuantized + where + T: VectorRepr, + { pub(super) fn new() -> Self { Self { - _type: std::marker::PhantomData, + quant_search: plugins::Plugins::new(), + full_search: plugins::Plugins::new(), } } + + pub(super) fn search

(mut self, plugin: P) -> Self + where + P: plugins::Plugin, Strategy> + + plugins::Plugin, Strategy> + + Clone + + 'static, + { + self.quant_search.register(plugin.clone()); + self.full_search.register(plugin); + self + } } macro_rules! impl_sq_build { @@ -240,27 +304,23 @@ mod imp { }; - let build = if input.use_fp_for_search { - run_search_outer( - &input.index_operation.search_phase, - common::FullPrecision, + let search = if input.use_fp_for_search { + self.full_search.run( index, - build_stats, - checkpoint, + &Strategy::new(common::FullPrecision), + &input.index_operation.search_phase, )? } else { - run_search_outer( - &input.index_operation.search_phase, - common::Quantized, + self.quant_search.run( index, - build_stats, - checkpoint, + &Strategy::new(common::Quantized), + &input.index_operation.search_phase, )? }; let result = QuantBuildResult { quant_training_time, - build, + build: BuildResult::new(build_stats, search), }; writeln!(output, "\n\n{}", result)?; From f050e3d9cbe8e601e11216b32d6ec3929b80e523 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 29 Apr 2026 13:55:40 -0700 Subject: [PATCH 04/38] Oh boy. --- diskann-benchmark-runner/src/any.rs | 12 - diskann-benchmark-runner/src/app.rs | 23 +- diskann-benchmark-runner/src/utils/fmt.rs | 234 +++++++++ .../tests/benchmark/test-4/stdout.txt | 15 +- .../test-debug-mode-error/stdout.txt | 2 +- .../benchmark/test-mismatch-0/stdout.txt | 12 +- .../benchmark/test-mismatch-1/stdout.txt | 12 +- diskann-benchmark-simd/src/lib.rs | 13 +- .../src/backend/exhaustive/minmax.rs | 9 +- .../src/backend/exhaustive/product.rs | 5 +- .../src/backend/exhaustive/spherical.rs | 9 +- .../src/backend/filters/benchmark.rs | 8 +- .../src/backend/index/benchmarks.rs | 275 ++++------- .../src/backend/index/product.rs | 63 ++- diskann-benchmark/src/backend/index/result.rs | 28 -- diskann-benchmark/src/backend/index/scalar.rs | 77 +-- .../src/backend/index/search/plugins.rs | 69 ++- .../src/backend/index/spherical.rs | 450 ++++++++++-------- diskann-benchmark/src/inputs/async_.rs | 68 ++- diskann-benchmark/src/utils/mod.rs | 6 +- 20 files changed, 867 insertions(+), 523 deletions(-) diff --git a/diskann-benchmark-runner/src/any.rs b/diskann-benchmark-runner/src/any.rs index dd57dbf69..b06cfdad3 100644 --- a/diskann-benchmark-runner/src/any.rs +++ b/diskann-benchmark-runner/src/any.rs @@ -255,18 +255,6 @@ impl Any { } } -/// Used in `DispatchRule::description(f, _)` to ensure that additional description -/// lines are properly aligned. -#[macro_export] -macro_rules! describeln { - ($writer:ident, $fmt:literal) => { - writeln!($writer, concat!(" ", $fmt)) - }; - ($writer:ident, $fmt:literal, $($args:expr),* $(,)?) => { - writeln!($writer, concat!(" ", $fmt), $($args,)*) - }; -} - trait SerializableAny: std::fmt::Debug { fn as_any(&self) -> &dyn std::any::Any; fn dump(&self) -> Result; diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index d5700815e..ad5ff3936 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -70,7 +70,7 @@ use crate::{ output::Output, registry, result::Checkpoint, - utils::fmt::Banner, + utils::fmt::{Banner, Indent}, }; /// Check if we're running in debug mode and error if not allowed. @@ -227,14 +227,12 @@ impl App { Commands::Benchmarks {} => { writeln!(output, "Registered Benchmarks:")?; for (name, description) in benchmarks.names() { - let mut lines = description.lines(); - if let Some(first) = lines.next() { - writeln!(output, " {}: {}", name, first)?; - for line in lines { - writeln!(output, " {}", line)?; - } + write!(output, " {name}:")?; + if description.is_empty() { + writeln!(output)?; } else { - writeln!(output, " {}: ", name)?; + writeln!(output)?; + write!(output, "{}", Indent::new(&description, 8))?; } } } @@ -264,13 +262,8 @@ impl App { )?; writeln!(output, "Closest matches:\n")?; for (i, mismatch) in mismatches.into_iter().enumerate() { - writeln!( - output, - " {}. \"{}\": {}", - i + 1, - mismatch.method(), - mismatch.reason(), - )?; + writeln!(output, " {}. \"{}\":", i + 1, mismatch.method(),)?; + writeln!(output, "{}", Indent::new(mismatch.reason(), 8),)?; } writeln!(output)?; diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index e00ec9275..b0823f5e1 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -189,6 +189,163 @@ impl std::fmt::Display for Banner<'_> { } } +//////////// +// Indent // +//////////// + +/// Indents each line of a string by a fixed number of spaces. +/// +/// Each line is prefixed with `spaces` spaces and terminated with a newline. +/// +/// # Examples +/// +/// ``` +/// use diskann_benchmark_runner::utils::fmt::Indent; +/// +/// let indented = Indent::new("hello\nworld", 4).to_string(); +/// assert_eq!(indented, " hello\n world\n"); +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct Indent<'a> { + string: &'a str, + spaces: usize, +} + +impl<'a> Indent<'a> { + /// Create a new [`Indent`] that will prefix each line of `string` with `spaces` spaces. + pub fn new(string: &'a str, spaces: usize) -> Self { + Self { string, spaces } + } +} + +impl std::fmt::Display for Indent<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let spaces = self.spaces; + self.string + .lines() + .try_for_each(|ln| writeln!(f, "{: >spaces$}{}", "", ln)) + } +} + +///////////// +// Delimit // +///////////// + +/// Formats an iterator with a delimiter between items and an optional distinct last delimiter. +/// +/// This is a single-use wrapper: the iterator is consumed on the first call to [`Display::fmt`]. +/// Subsequent calls will print ``. +/// +/// The `last` parameter allows a different delimiter before the final item (e.g., `", and "`), +/// which is useful for natural-language lists like `"a, b, and c"`. +/// +/// # Examples +/// +/// ``` +/// use diskann_benchmark_runner::utils::fmt::Delimit; +/// +/// let d = Delimit::new(["a", "b", "c"], ", ", Some(", and ")); +/// assert_eq!(d.to_string(), "a, b, and c"); +/// ``` +pub struct Delimit<'a, I> { + itr: std::cell::Cell>, + delimiter: &'a str, + last: Option<&'a str>, +} + +impl<'a, I> Delimit<'a, I> { + /// Create a new [`Delimit`] from an iterable, a delimiter, and an optional last delimiter. + /// + /// If `last` is `None`, the regular `delimiter` is used before the final item. + pub fn new( + itr: impl IntoIterator, + delimiter: &'a str, + last: Option<&'a str>, + ) -> Self { + Self { + itr: std::cell::Cell::new(Some(itr.into_iter())), + delimiter, + last, + } + } +} + +impl std::fmt::Display for Delimit<'_, I> +where + I: Iterator, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Some(mut itr) = self.itr.take() else { + return write!(f, ""); + }; + + let mut first = true; + let mut current = if let Some(item) = itr.next() { + item + } else { + // Empty iterator + return Ok(()); + }; + + loop { + match itr.next() { + None => { + // "current" is the last item. If it is also the first, we write it + // directly. Otherwise, we use the "last" delimiter if available, falling + // back to "delimiter". + let delimiter = if first { + "" + } else if let Some(last) = self.last { + last + } else { + self.delimiter + }; + + return write!(f, "{}{}", delimiter, current); + } + Some(next) => { + // There is at least one item next. We print "current" and move on. + let delimiter = if first { + first = false; + "" + } else { + self.delimiter + }; + + write!(f, "{}{}", delimiter, current)?; + current = next; + } + } + } + } +} + +/////////// +// Quote // +/////////// + +/// Wraps a value in double quotes when displayed. +/// +/// # Examples +/// +/// ``` +/// use diskann_benchmark_runner::utils::fmt::Quote; +/// +/// assert_eq!(Quote("hello").to_string(), "\"hello\""); +/// assert_eq!(Quote(42).to_string(), "\"42\""); +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct Quote(pub T); + +impl std::fmt::Display for Quote +where + T: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "\"{}\"", self.0) + } +} + /////////// // Tests // /////////// @@ -327,4 +484,81 @@ string, , string let mut row = table.row(0); row.insert(1, 3); } + + #[test] + fn test_indent_single_line() { + let s = Indent::new("hello", 4).to_string(); + assert_eq!(s, " hello\n"); + } + + #[test] + fn test_indent_multi_line() { + let s = Indent::new("hello\nworld\nfoo", 2).to_string(); + assert_eq!(s, " hello\n world\n foo\n"); + } + + #[test] + fn test_indent_zero_spaces() { + let s = Indent::new("hello\nworld", 0).to_string(); + assert_eq!(s, "hello\nworld\n"); + } + + #[test] + fn test_indent_empty_string() { + let s = Indent::new("", 4).to_string(); + assert_eq!(s, ""); + } + + #[test] + fn test_delimit_empty() { + let d = Delimit::new(std::iter::empty::<&str>(), ", ", None); + assert_eq!(d.to_string(), ""); + } + + #[test] + fn test_delimit_single_item() { + let d = Delimit::new(["a"], ", ", Some(", and ")); + assert_eq!(d.to_string(), "a"); + } + + #[test] + fn test_delimit_two_items_with_last() { + let d = Delimit::new(["a", "b"], ", ", Some(", and ")); + assert_eq!(d.to_string(), "a, and b"); + } + + #[test] + fn test_delimit_three_items_with_last() { + let d = Delimit::new(["a", "b", "c"], ", ", Some(", and ")); + assert_eq!(d.to_string(), "a, b, and c"); + } + + #[test] + fn test_delimit_without_last() { + let d = Delimit::new(["x", "y", "z"], " | ", None); + assert_eq!(d.to_string(), "x | y | z"); + } + + #[test] + fn test_delimit_second_display_prints_missing() { + let d = Delimit::new(["a", "b"], ", ", None); + assert_eq!(d.to_string(), "a, b"); + assert_eq!(d.to_string(), ""); + } + + #[test] + fn test_quote() { + assert_eq!(Quote("hello").to_string(), "\"hello\""); + } + + #[test] + fn test_quote_with_integer() { + assert_eq!(Quote(42).to_string(), "\"42\""); + } + + #[test] + fn test_delimit_with_quote() { + let d = Delimit::new(["topk", "range"].iter().map(Quote), ", ", Some(", and ")); + assert_eq!(d.to_string(), "\"topk\", and \"range\""); + } } diff --git a/diskann-benchmark-runner/tests/benchmark/test-4/stdout.txt b/diskann-benchmark-runner/tests/benchmark/test-4/stdout.txt index 0a37ae5a5..05b860c2b 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-4/stdout.txt +++ b/diskann-benchmark-runner/tests/benchmark/test-4/stdout.txt @@ -1,11 +1,16 @@ Registered Benchmarks: - type-bench-f32: tag "test-input-types" + type-bench-f32: + tag "test-input-types" float32 - type-bench-i8: tag "test-input-types" + type-bench-i8: + tag "test-input-types" int8 - exact-type-bench-f32-1000: tag "test-input-types" + exact-type-bench-f32-1000: + tag "test-input-types" float32, dim=1000 - simple-bench: tag "test-input-dim" + simple-bench: + tag "test-input-dim" dim=None only - dim-bench: tag "test-input-dim" + dim-bench: + tag "test-input-dim" matches all \ No newline at end of file diff --git a/diskann-benchmark-runner/tests/benchmark/test-debug-mode-error/stdout.txt b/diskann-benchmark-runner/tests/benchmark/test-debug-mode-error/stdout.txt index 59f66e1b6..f5868b940 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-debug-mode-error/stdout.txt +++ b/diskann-benchmark-runner/tests/benchmark/test-debug-mode-error/stdout.txt @@ -1,2 +1,2 @@ Benchmarking in debug mode produces misleading performance results. -Please compile in release mode or use the --allow-debug flag to bypass this check. +Please compile in release mode or use the --allow-debug flag to bypass this check. \ No newline at end of file diff --git a/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt b/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt index 62e7a0ba9..ba72e9bbf 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt +++ b/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt @@ -8,8 +8,14 @@ Could not find a match for the following input: Closest matches: - 1. "type-bench-f32": expected "float32" but found "float16" - 2. "type-bench-i8": expected "int8" but found "float16" - 3. "exact-type-bench-f32-1000": expected "float32" but found "float16"; expected dim=1000, but found dim=128 + 1. "type-bench-f32": + expected "float32" but found "float16" + + 2. "type-bench-i8": + expected "int8" but found "float16" + + 3. "exact-type-bench-f32-1000": + expected "float32" but found "float16"; expected dim=1000, but found dim=128 + could not find a benchmark for all inputs \ No newline at end of file diff --git a/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt b/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt index 3e4c4ca50..34be87554 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt +++ b/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt @@ -8,8 +8,14 @@ Could not find a match for the following input: Closest matches: - 1. "type-bench-f32": expected "float32" but found "float16" - 2. "type-bench-i8": expected "int8" but found "float16" - 3. "exact-type-bench-f32-1000": expected "float32" but found "float16" + 1. "type-bench-f32": + expected "float32" but found "float16" + + 2. "type-bench-i8": + expected "int8" but found "float16" + + 3. "exact-type-bench-f32-1000": + expected "float32" but found "float16" + could not find a benchmark for all inputs \ No newline at end of file diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 8d72efb91..d6d0f86bb 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -21,7 +21,6 @@ use thiserror::Error; use diskann_benchmark_runner::{ benchmark::{PassFail, Regression}, - describeln, dispatcher::{Description, DispatchRule, FailureScore, MatchScore}, utils::{ datatype::{self, DataType}, @@ -594,17 +593,17 @@ where ) -> std::fmt::Result { match input { None => { - describeln!( + writeln!( f, "- Query Type: {}", Description::>::new() )?; - describeln!( + writeln!( f, "- Data Type: {}", Description::>::new() )?; - describeln!( + writeln!( f, "- Implementation: {}", Description::>::new() @@ -612,13 +611,13 @@ where } Some(input) => { if let Err(err) = datatype::Type::::try_match_verbose(&input.query_type) { - describeln!(f, "\n - Mismatched query type: {}", err)?; + writeln!(f, "\n - Mismatched query type: {}", err)?; } if let Err(err) = datatype::Type::::try_match_verbose(&input.data_type) { - describeln!(f, "\n - Mismatched data type: {}", err)?; + writeln!(f, "\n - Mismatched data type: {}", err)?; } if let Err(err) = Identity::::try_match_verbose(&input.arch) { - describeln!(f, "\n - Mismatched architecture: {}", err)?; + writeln!(f, "\n - Mismatched architecture: {}", err)?; } } } diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index 73b57733f..3516ab568 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -33,7 +33,6 @@ mod imp { use std::{io::Write, num::NonZeroUsize}; use diskann_benchmark_runner::{ - describeln, dispatcher::{FailureScore, MatchScore}, utils::{percentiles, MicroSeconds}, Benchmark, Output, @@ -225,17 +224,17 @@ mod imp { ) -> std::fmt::Result { match input { None => { - describeln!( + writeln!( f, "- Exhaustive search for {}-bit minmax quantization", NBITS )?; - describeln!(f, "- Requires `float32` data")?; - describeln!(f, "- Implements `squared_l2` or `inner_product` distance")?; + writeln!(f, "- Requires `float32` data")?; + writeln!(f, "- Implements `squared_l2` or `inner_product` distance")?; } Some(from) => { if from.num_bits.get() != NBITS { - describeln!( + writeln!( f, "- Expected \"num_bits = {}\", instead got {}", NBITS, diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 7504a28e3..ca15a02f7 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -26,7 +26,6 @@ mod imp { use std::io::Write; use diskann_benchmark_runner::{ - describeln, dispatcher::{FailureScore, MatchScore}, utils::{percentiles, MicroSeconds}, Benchmark, Output, @@ -206,8 +205,8 @@ mod imp { input: Option<&inputs::exhaustive::Product>, ) -> std::fmt::Result { if input.is_none() { - describeln!(f, "- Exhaustive search for product quantization",)?; - describeln!(f, "- Requires `float32` data")?; + writeln!(f, "- Exhaustive search for product quantization",)?; + writeln!(f, "- Requires `float32` data")?; } Ok(()) } diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index 9b1f9a935..b7dfd69b0 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -33,7 +33,6 @@ mod imp { use std::io::Write; use diskann_benchmark_runner::{ - describeln, dispatcher::{FailureScore, MatchScore}, utils::{percentiles, MicroSeconds}, Benchmark, Output, @@ -232,17 +231,17 @@ mod imp { ) -> std::fmt::Result { match input { None => { - describeln!( + writeln!( f, "- Exhaustive search for {}-bit spherical quantization", NBITS )?; - describeln!(f, "- Requires `float32` data")?; - describeln!(f, "- Implements `squared_l2` or `inner_product` distance")?; + writeln!(f, "- Requires `float32` data")?; + writeln!(f, "- Implements `squared_l2` or `inner_product` distance")?; } Some(from) => { if from.num_bits.get() != NBITS { - describeln!( + writeln!( f, "- Expected \"num_bits = {}\", instead got {}", NBITS, diff --git a/diskann-benchmark/src/backend/filters/benchmark.rs b/diskann-benchmark/src/backend/filters/benchmark.rs index 43a0717b4..7ce92420b 100644 --- a/diskann-benchmark/src/backend/filters/benchmark.rs +++ b/diskann-benchmark/src/backend/filters/benchmark.rs @@ -46,14 +46,10 @@ impl Benchmark for MetadataIndexJob { fn description( &self, - f: &mut std::fmt::Formatter<'_>, + _f: &mut std::fmt::Formatter<'_>, _input: Option<&MetadataIndexBuild>, ) -> std::fmt::Result { - writeln!( - f, - "tag: \"{}\"", - crate::inputs::filters::MetadataIndexBuild::tag() - ) + Ok(()) } fn run( diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index ae0e88abb..f3a479b2c 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use std::{any::Any, io::Write, marker::PhantomData, num::NonZeroUsize, sync::Arc}; +use std::{io::Write, num::NonZeroUsize, sync::Arc}; use diskann::{ graph::SampleableForStart, @@ -60,6 +60,17 @@ use crate::{ //////////////////////////// pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { + // Notes on registration: + // + // We register all supported search types for `f32`, but intentionally limit the number + // of search types for the other data types mainly to help reduce compilation time. + // + // Feel free to add additional search plugins as needed during exploration and add them + // permanently if demand is sufficient. + // + // Note that each plugin registration will trigger an new monomorphization, so use with + // care. + // Full Precision benchmarks.register( "async-full-precision-f32", @@ -113,6 +124,9 @@ type FullPrecisionProvider = inmem::DefaultProvider< DefaultContext, >; +/// Associate a type (usually a [`diskann::provider::DataProvider`]) with a full-precision +/// element type. This is used in implementations of [`plugins::Plugin`] to derive the +/// correct query types to load. pub(super) trait QueryType { type Element: VectorRepr; } @@ -124,8 +138,8 @@ where type Element = T; } -// Full Precision -pub(super) struct FullPrecision +/// A [`Benchmark`] for full-precision searches containing a dynamic list of search types. +struct FullPrecision where T: VectorRepr, { @@ -136,13 +150,13 @@ impl FullPrecision where T: VectorRepr, { - pub(super) fn new() -> Self { + fn new() -> Self { Self { plugins: plugins::Plugins::new(), } } - pub(super) fn search

(mut self, plugin: P) -> Self + fn search

(mut self, plugin: P) -> Self where P: plugins::Plugin, Strategy> + 'static, { @@ -162,9 +176,14 @@ where type Output = BuildResult; fn try_match(&self, input: &IndexOperation) -> Result { - match &input.source { - IndexSource::Load(load) => datatype::Type::::try_match(&load.data_type), - IndexSource::Build(build) => datatype::Type::::try_match(&build.data_type), + let score = datatype::Type::::try_match(input.source.data_type()); + if self.plugins.is_match(input.search_phase.kind()) { + score + } else { + match score { + Ok(_) => Err(FailureScore(0)), + Err(score) => Err(score), + } } } @@ -173,16 +192,39 @@ where f: &mut std::fmt::Formatter<'_>, input: Option<&IndexOperation>, ) -> std::fmt::Result { + use diskann_benchmark_runner::dispatcher::{Description, Why}; + match input { - Some(arg) => match &arg.source { - IndexSource::Load(load) => { - datatype::Type::::description(f, Some(&load.data_type)) - } - IndexSource::Build(build) => { - datatype::Type::::description(f, Some(&build.data_type)) + Some(arg) => { + let data_type = match &arg.source { + IndexSource::Load(load) => &load.data_type, + IndexSource::Build(build) => &build.data_type, + }; + writeln!( + f, + "Data/Query Type: {}", + Why::>::new(data_type) + )?; + + if !self.plugins.is_match(arg.search_phase.kind()) { + writeln!( + f, + "Unsupported search phase: \"{}\" - expected one of {}", + arg.search_phase.kind(), + self.plugins.format_kinds(), + )?; } - }, - None => datatype::Type::::description(f, None::<&datatype::DataType>), + Ok(()) + } + None => { + writeln!( + f, + "Data/Query Type: {}", + Description::>::new() + )?; + + writeln!(f, "Search Kinds: {}", self.plugins.format_kinds()) + } } } @@ -233,6 +275,9 @@ where } }; + // Save construction stats before running queries. + checkpoint.checkpoint(&build_stats)?; + let search_results = self.plugins.run( index, &Strategy::new(common::FullPrecision), @@ -379,161 +424,10 @@ where Ok((index, build_stats)) } -pub(super) fn run_search_outer( - input: &SearchPhase, - search_strategy: S, - index: Index, - build_stats: Option, - checkpoint: Checkpoint<'_>, -) -> anyhow::Result -where - DP: DataProvider - + for<'a> provider::SetElement<&'a [T]>, - T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, - S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, -{ - match &input { - SearchPhase::Topk(search_phase) => { - // Handle Topk search phase - let mut result = BuildResult::new_topk(build_stats); - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let knn = benchmark_core::search::graph::KNN::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - )?; - - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let search_results = search::knn::run(&knn, &groundtruth, steps)?; - result.append(AggregatedSearchResults::Topk(search_results)); - Ok(result) - } - SearchPhase::Range(search_phase) => { - // Handle Range search phase - let mut result = BuildResult::new_range(build_stats); - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let steps = search::range::RangeSearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let range = benchmark_core::search::graph::Range::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - )?; - - let search_results = search::range::run(&range, &groundtruth, steps)?; - result.append(AggregatedSearchResults::Range(search_results)); - Ok(result) - } - SearchPhase::TopkBetaFilter(search_phase) => { - // Handle Beta Filtered Topk search phase - let mut result = BuildResult::new_topk(build_stats); - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let bit_maps = - generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; - - let search_strategies = setup_filter_strategies( - search_phase.beta, - bit_maps - .into_iter() - .map(utils::filters::as_query_label_provider), - search_strategy.clone(), - ); - - let knn = benchmark_core::search::graph::KNN::new( - index, - queries, - benchmark_core::search::graph::Strategy::collection(search_strategies), - )?; - - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let search_results = search::knn::run(&knn, &groundtruth, steps)?; - result.append(AggregatedSearchResults::Topk(search_results)); - Ok(result) - } - SearchPhase::TopkMultihopFilter(search_phase) => { - // Handle MultiHop Topk search phase - let mut result = BuildResult::new_topk(build_stats); - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_range_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let bit_maps = - generate_bitmaps(&search_phase.query_predicates, &search_phase.data_labels)?; - - let multihop = benchmark_core::search::graph::MultiHop::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - bit_maps - .into_iter() - .map(utils::filters::as_query_label_provider) - .collect(), - )?; - - let search_results = search::knn::run(&multihop, &groundtruth, steps)?; - result.append(AggregatedSearchResults::Topk(search_results)); - Ok(result) - } - } -} - +/// A new-type wrapper for [`glue::SearchStrategy`]. +/// +/// This exists so we can implement [`search::Plugin`] for a raw generic `DP` without +/// forming a blanket implementation for all `DP`/parameter `P` pairs. #[derive(Debug, Clone, Copy)] pub(super) struct Strategy(S); @@ -541,8 +435,19 @@ impl Strategy { pub(super) fn new(strategy: S) -> Self { Self(strategy) } + + pub(super) fn inner(&self) -> S + where + S: Clone, + { + self.0.clone() + } } +//------// +// Topk // +//------// + impl search::Plugin> for plugins::Topk where DP: DataProvider + QueryType, @@ -558,7 +463,7 @@ where strategy: &Strategy, phase: &SearchPhase, ) -> anyhow::Result { - let topk = phase.as_topk().unwrap(); + let topk = phase.as_topk()?; let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); @@ -568,7 +473,7 @@ where let knn = benchmark_core::search::graph::KNN::new( index.clone(), queries, - benchmark_core::search::graph::Strategy::broadcast(strategy.0.clone()), + benchmark_core::search::graph::Strategy::broadcast(strategy.inner()), )?; let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); @@ -578,6 +483,10 @@ where } } +//-------// +// Range // +//-------// + impl search::Plugin> for plugins::Range where DP: DataProvider + QueryType, @@ -593,7 +502,7 @@ where strategy: &Strategy, phase: &SearchPhase, ) -> anyhow::Result { - let range = phase.as_range().unwrap(); + let range = phase.as_range()?; let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&range.queries))?); @@ -606,7 +515,7 @@ where let range = benchmark_core::search::graph::Range::new( index, queries, - benchmark_core::search::graph::Strategy::broadcast(strategy.0.clone()), + benchmark_core::search::graph::Strategy::broadcast(strategy.inner()), )?; let result = search::range::run(&range, &groundtruth, steps)?; @@ -614,6 +523,10 @@ where } } +//------------// +// BetaFilter // +//------------// + impl search::Plugin> for plugins::BetaFilter where DP: DataProvider + QueryType, @@ -629,7 +542,7 @@ where strategy: &Strategy, phase: &SearchPhase, ) -> anyhow::Result { - let beta_filter = phase.as_topk_beta_filter().unwrap(); + let beta_filter = phase.as_topk_beta_filter()?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&beta_filter.queries), @@ -645,7 +558,7 @@ where bit_maps .into_iter() .map(utils::filters::as_query_label_provider), - strategy.0.clone(), + strategy.inner(), ); let knn = benchmark_core::search::graph::KNN::new( @@ -665,6 +578,10 @@ where } } +//----------------// +// MultihopFilter // +//----------------// + impl search::Plugin> for plugins::MultihopFilter where DP: DataProvider + QueryType, @@ -680,7 +597,7 @@ where strategy: &Strategy, phase: &SearchPhase, ) -> anyhow::Result { - let multihop = phase.as_topk_multihop_filter().unwrap(); + let multihop = phase.as_topk_multihop_filter()?; let queries: Arc> = Arc::new(datafiles::load_dataset( datafiles::BinFile(&multihop.queries), @@ -697,7 +614,7 @@ where let multihop = benchmark_core::search::graph::MultiHop::new( index, queries, - benchmark_core::search::graph::Strategy::broadcast(strategy.0.clone()), + benchmark_core::search::graph::Strategy::broadcast(strategy.inner()), bit_maps .into_iter() .map(utils::filters::as_query_label_provider) diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index f1839fc85..30e55df7a 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -14,6 +14,10 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { use crate::backend::index::search::plugins; use half::f16; + // NOTE: Try to balance search plugins with the needed functionality. + // + // Feel free to add search plugins, but be mindful of the monomorphization cost. + benchmarks.register( "async-pq-f32", imp::ProductQuantized::::new() @@ -54,7 +58,7 @@ mod imp { use crate::{ backend::index::{ - benchmarks::{run_build, FullPrecision, QueryType, Strategy}, + benchmarks::{run_build, QueryType, Strategy}, build::{self, load_index, save_index, single_or_multi_insert, BuildStats}, result::{BuildResult, QuantBuildResult}, search::plugins, @@ -77,6 +81,10 @@ mod imp { type Element = T; } + /// A [`Benchmark`] for product-quantized searches containing a dynamic list of search + /// types. + /// + /// The kinds of quantized and full-precision searches are kept in-sync. pub(super) struct ProductQuantized where T: VectorRepr, @@ -120,7 +128,19 @@ mod imp { type Output = QuantBuildResult; fn try_match(&self, input: &IndexPQOperation) -> Result { - FullPrecision::::new().try_match(&input.index_operation) + let score = datatype::Type::::try_match(input.index_operation.source.data_type()); + + if self + .quant_search + .is_match(input.index_operation.search_phase.kind()) + { + score + } else { + match score { + Ok(_) => Err(FailureScore(0)), + Err(score) => Err(score), + } + } } fn description( @@ -128,7 +148,41 @@ mod imp { f: &mut std::fmt::Formatter<'_>, input: Option<&IndexPQOperation>, ) -> std::fmt::Result { - FullPrecision::::new().description(f, input.map(|f| &f.index_operation)) + use diskann_benchmark_runner::dispatcher::{Description, Why}; + + match input { + Some(arg) => { + writeln!( + f, + "{}", + Why::>::new( + arg.index_operation.source.data_type() + ) + )?; + + if !self + .quant_search + .is_match(arg.index_operation.search_phase.kind()) + { + writeln!( + f, + "Unsupported search phase: \"{}\" - expected one of {}", + arg.index_operation.search_phase.kind(), + self.quant_search.format_kinds(), + )?; + } + Ok(()) + } + None => { + writeln!( + f, + "Data/Query Type: {}", + Description::>::new() + )?; + + writeln!(f, "Search Kinds: {}", self.quant_search.format_kinds()) + } + } } fn run( @@ -204,6 +258,9 @@ mod imp { } }; + // Save construction stats before running queries. + checkpoint.checkpoint(&build_stats)?; + let search = if input.use_fp_for_search { self.full_search.run( index, diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 3cb8c86dc..a68650d14 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -22,27 +22,9 @@ pub(super) struct BuildResult { } impl BuildResult { - pub(super) fn new_topk(build: Option) -> Self { - Self { - build, - search: AggregatedSearchResults::Topk(Vec::new()), - } - } - - pub(super) fn new_range(build: Option) -> Self { - Self { - build, - search: AggregatedSearchResults::Range(Vec::new()), - } - } - pub(super) fn new(build: Option, search: AggregatedSearchResults) -> Self { Self { build, search } } - - pub(super) fn append(&mut self, search: AggregatedSearchResults) { - self.search.append(search); - } } impl std::fmt::Display for BuildResult { @@ -90,16 +72,6 @@ pub(super) enum AggregatedSearchResults { Range(Vec), } -impl AggregatedSearchResults { - pub(super) fn append(&mut self, search: AggregatedSearchResults) { - match (self, search) { - (Self::Topk(v), AggregatedSearchResults::Topk(s)) => v.extend(s), - (Self::Range(v), AggregatedSearchResults::Range(s)) => v.extend(s), - _ => panic!("Mismatched search result types"), - } - } -} - impl std::fmt::Display for AggregatedSearchResults { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index 503bbc9d6..6b90c9647 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -14,6 +14,10 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { use crate::backend::index::search::plugins::Topk; use half::f16; + // NOTE: We just register `Topk` for now to reduce compilation cost. + // + // Feel free to add search plugins, but be mindful of the monomorphization cost. + // f32 benchmarks.register( "async-sq-8-bit-f32", @@ -67,7 +71,6 @@ mod imp { use anyhow::Context; use diskann::utils::VectorRepr; use diskann_benchmark_runner::{ - describeln, dispatcher::{Description, DispatchRule, FailureScore, MatchScore}, utils::{datatype, MicroSeconds}, Benchmark, Checkpoint, Output, @@ -84,7 +87,7 @@ mod imp { use crate::{ backend::index::{ - benchmarks::{run_build, FullPrecision, QueryType, Strategy}, + benchmarks::{run_build, QueryType, Strategy}, build::{self, load_index, only_single_insert, save_index, BuildStats}, result::{BuildResult, QuantBuildResult}, search::plugins, @@ -107,7 +110,10 @@ mod imp { type Element = T; } - // Scalar Quantized + /// A [`Benchmark`] for scalar-quantized searches containing a dynamic list of search + /// types. + /// + /// The kinds of quantized and full-precision searches are kept in-sync. pub(super) struct ScalarQuantized where T: VectorRepr, @@ -157,12 +163,16 @@ mod imp { } } - if FullPrecision::<$T>::new().try_match(&input.index_operation) + if datatype::Type::<$T>::try_match(input.index_operation.source.data_type()) .is_err() { *failure_score.get_or_insert(0) += 1; } + if !self.quant_search.is_match(input.index_operation.search_phase.kind()) { + *failure_score.get_or_insert(0) += 1; + } + if input.num_bits != $N { *failure_score.get_or_insert(0) += 10 + ($N as usize).abs_diff(input.num_bits) as u32; } @@ -180,22 +190,23 @@ mod imp { ) -> std::fmt::Result { match input { None => { - describeln!( + writeln!( f, "- Index Build and Search using {} scalar quantized bits", $N )?; - describeln!( + writeln!( f, "- Requires `{}` data", Description::>::new(), )?; - describeln!(f, "- Implements `squared_l2` or `inner_product` distance",)?; - describeln!(f, "- Does not support multi-insert")?; + writeln!(f, "- Implements `squared_l2` or `inner_product` distance",)?; + writeln!(f, "- Does not support multi-insert")?; + writeln!(f, "- Search Kinds: {}", self.quant_search.format_kinds())?; } Some(input) => { if input.num_bits != $N { - describeln!( + writeln!( f, "- Expected {} bits, instead got {}", $N, @@ -203,31 +214,32 @@ mod imp { )?; } - let mut check_match = |data_type: &datatype::DataType| { - if datatype::Type::<$T>::try_match(data_type).is_err() { - describeln!( + let data_type = input.index_operation.source.data_type(); + if datatype::Type::<$T>::try_match(data_type).is_err() { + writeln!( + f, + "- Only `{}` data type is supported. Instead, got {}", + Description::>::new(), + data_type + )?; + } + + if let IndexSource::Build(ref build) = input.index_operation.source { + if build.multi_insert.is_some() { + writeln!( f, - "- Only `{}` data type is supported. Instead, got {}", - Description::>::new(), - data_type - ).unwrap(); + "- Scalar Quantization does not support multi-insert" + )?; } - }; + } - match &input.index_operation.source { - IndexSource::Load(load) => { - check_match(&load.data_type); - } - IndexSource::Build(build) => { - check_match(&build.data_type); - - if build.multi_insert.is_some() { - describeln!( - f, - "- Scalar Quantization does not support multi-insert" - )?; - } - } + if !self.quant_search.is_match(input.index_operation.search_phase.kind()) { + writeln!( + f, + "- Unsupported search phase: \"{}\" - expected one of {}", + input.index_operation.search_phase.kind(), + self.quant_search.format_kinds(), + )?; } } } @@ -304,6 +316,9 @@ mod imp { }; + // Save construction stats before running queries. + checkpoint.checkpoint(&build_stats)?; + let search = if input.use_fp_for_search { self.full_search.run( index, diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 1945bea47..e916db38d 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -3,15 +3,47 @@ * Licensed under the MIT license. */ -use std::{any::Any, sync::Arc}; +//! Search plugins are the solution the following benchmarking problem: +//! +//! The [`SearchPhase`] enum contains a list of available search kinds. Adding a new variant +//! either requires updating **all** users to implement that related search (harming compile +//! times) or requires users to explicitly opt-out. Unfortunately, the latter is difficult +//! to maintain with benchmark matching (i.e., the desire to catch configuration mismatches +//! such as requesting an unsupported search early, rather than reaching an error late in +//! a benchmark run). Additionally, if only a subset of search kinds are supported, it +//! is user-friendly to document which variants are actually supported and to make it simple +//! to add or remove flavors. +//! +//! The solution is the [`Plugin`] trait and the [`Plugins`] helper. The trait is a +//! dyn-compatible wrapper for a search and the [`Plugins`] struct simply collects a list +//! of [`Plugin`]s. +//! +//! Implementations of [`Plugin`] declare which type of search they support, which is aggregated +//! in the [`Plugins`] helper. +//! +//! Benchmarks can then contain a [`Plugins`] field, dynamically register plugin types, and +//! then get registered in [`diskann_benchmark_runner::Benchmarks`]. The follow methods then +//! support proper reporting in the benchmark infrastructure: +//! +//! * [`Plugins::format_kinds`]: Format the registered plugins. +//! * [`Plugins::is_match`]: Return whether a [`Plugin`] is registered matching a phase. +//! * [`Plugins::run`]: Run the first matching plugin. +//! +//! Concrete plugins maintain a one-to-one relationship with variants in [`SearchPhase`] and +//! [`SearchPhaseKind`] and are simple ZSTs. + +use std::sync::Arc; use diskann::{graph::DiskANNIndex, provider::DataProvider}; +use diskann_benchmark_runner::utils::fmt::{Delimit, Quote}; use crate::{ backend::index::result::AggregatedSearchResults, inputs::async_::{SearchPhase, SearchPhaseKind}, }; +/// A search plugin for `DP`. The generic `P` is for any additional parameters needed by +/// a benchmark. pub(crate) trait Plugin: std::fmt::Debug where DP: DataProvider, @@ -19,6 +51,10 @@ where /// The flavor of `SearchPhase` this plugin is compiled for. fn kind(&self) -> SearchPhaseKind; + /// Run the search. + /// + /// The user can assume that `phase` has the same [`SearchPhaseKind`] as [`Self::kind`] + /// and may return an error if this is not the case. fn search( &self, index: Arc>, @@ -27,6 +63,7 @@ where ) -> anyhow::Result; } +/// A collection of dynamically registered [`Plugins`]. #[derive(Debug)] pub(crate) struct Plugins where @@ -39,12 +76,14 @@ impl Plugins where DP: DataProvider, { + /// Create a new empty [`Plugins`]. pub(crate) fn new() -> Self { Self { plugins: Vec::new(), } } + /// Register `plugin` in the managed collection. pub(crate) fn register(&mut self, plugin: T) where T: Plugin + 'static, @@ -52,14 +91,26 @@ where self.plugins.push(Box::new(plugin)); } - pub(crate) fn kinds(&self) -> Vec { - self.plugins.iter().map(|p| p.kind()).collect() + /// Return an iterator over all [`SearchPhaseKind`]s currently registered. + pub(crate) fn kinds(&self) -> impl ExactSizeIterator + use<'_, DP, P> { + self.plugins.iter().map(|p| p.kind()) } - pub(crate) fn is_match(&self, phase: &SearchPhase) -> bool { - self.plugins.iter().any(|p| p.kind() == phase.kind()) + /// Return whether a [`Plugin`] is registered matching `phase`. + pub(crate) fn is_match(&self, phase: SearchPhaseKind) -> bool { + self.plugins.iter().any(|p| p.kind() == phase) } + /// Return a human readable, formatted list of the registered [`SearchPhaseKind`]s. + pub(crate) fn format_kinds(&self) -> impl std::fmt::Display + use<'_, DP, P> { + Delimit::new(self.kinds().map(Quote), ", ", Some(", and ")) + } + + /// Try to run a search plugin for `phase`. + /// + /// If no such plugin exists, an "INTERNAL ERROR:" is returned. + /// Within the `diskann-benchmark` crate, pre-validation with [`Self::is_match`] should + /// be used before calling this method. pub(crate) fn run( &self, index: Arc>, @@ -76,37 +127,45 @@ where } } +/// A search plugin for vanilla top-k search. #[derive(Debug, Clone, Copy)] pub(crate) struct Topk; impl Topk { + /// Returns [`SearchPhaseKind::Topk`]. pub(crate) fn kind() -> SearchPhaseKind { SearchPhaseKind::Topk } } +/// A search plugin for range search. #[derive(Debug, Clone, Copy)] pub(crate) struct Range; impl Range { + /// Returns [`SearchPhaseKind::Range`]. pub(crate) fn kind() -> SearchPhaseKind { SearchPhaseKind::Range } } +/// A search plugin for beta-filtered search. #[derive(Debug, Clone, Copy)] pub(crate) struct BetaFilter; impl BetaFilter { + /// Returns [`SearchPhaseKind::TopkBetaFilter`]. pub(crate) fn kind() -> SearchPhaseKind { SearchPhaseKind::TopkBetaFilter } } +/// A search plugin for multi-hop filtered search. #[derive(Debug, Clone, Copy)] pub(crate) struct MultihopFilter; impl MultihopFilter { + /// Returns [`SearchPhaseKind::TopkMultihopFilter`]. pub(crate) fn kind() -> SearchPhaseKind { SearchPhaseKind::TopkMultihopFilter } diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 507337da7..76a1dae5d 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -16,9 +16,37 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "spherical-quantization")] { - benchmarks.register(NAME, imp::SphericalQ::<1>); - benchmarks.register(NAME, imp::SphericalQ::<2>); - benchmarks.register(NAME, imp::SphericalQ::<4>); + use crate::backend::index::search::plugins; + + // NOTE: Since the spherical provider is not generic on the number of bits, the + // implementations of the search-plugins are shared by all bit-widths. Registering + // all plugins for all bit widths does not meaningfully increase compilation time. + benchmarks.register( + NAME, + imp::SphericalQ::<1>::new() + .search(plugins::Topk) + .search(plugins::Range) + .search(plugins::BetaFilter) + .search(plugins::MultihopFilter), + ); + + benchmarks.register( + NAME, + imp::SphericalQ::<2>::new() + .search(plugins::Topk) + .search(plugins::Range) + .search(plugins::BetaFilter) + .search(plugins::MultihopFilter), + ); + + benchmarks.register( + NAME, + imp::SphericalQ::<4>::new() + .search(plugins::Topk) + .search(plugins::Range) + .search(plugins::BetaFilter) + .search(plugins::MultihopFilter), + ); } // Stub implementation @@ -32,17 +60,16 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { #[cfg(feature = "spherical-quantization")] mod imp { - use diskann::graph::StartPointStrategy; + use diskann::graph::{DiskANNIndex, StartPointStrategy}; use diskann_benchmark_core as benchmark_core; use diskann_benchmark_runner::{ - describeln, dispatcher::{DispatchRule, FailureScore, MatchScore}, utils::{datatype, MicroSeconds}, Benchmark, Checkpoint, Output, }; use diskann_providers::{ - index::diskann_async::{self}, - model::graph::provider::async_::{common::NoDeletes, inmem}, + index::diskann_async, + model::graph::provider::async_::{common, inmem}, }; use diskann_quantization::alloc::GlobalAllocator; use diskann_utils::views::Matrix; @@ -52,12 +79,13 @@ mod imp { use crate::{ backend::index::{ + benchmarks::QueryType, build::{self, only_single_insert, BuildStats}, result::AggregatedSearchResults, search, }, inputs::{ - async_::{SearchPhase, SphericalQuantBuild}, + async_::{SearchPhase, SearchPhaseKind, SphericalQuantBuild}, exhaustive, }, utils::{ @@ -66,8 +94,38 @@ mod imp { }, }; - /// The dispatcher target for `spherical-quantization` operations. - pub(super) struct SphericalQ; + type SQProvider = inmem::DefaultProvider< + inmem::FullPrecisionStore, + inmem::spherical::SphericalStore, + common::NoDeletes, + diskann::provider::DefaultContext, + >; + + impl QueryType for SQProvider { + type Element = f32; + } + + /// A [`Benchmark`] for spherical-quantized searches containing a dynamic list of search + /// types. + pub(super) struct SphericalQ { + search: search::plugins::Plugins, + } + + impl SphericalQ { + pub(super) fn new() -> Self { + Self { + search: search::plugins::Plugins::new(), + } + } + + pub(super) fn search

(mut self, plugin: P) -> Self + where + P: search::plugins::Plugin + 'static, + { + self.search.register(plugin); + self + } + } macro_rules! write_field { ($f:ident, $field:tt, $fmt:literal, $($expr:tt)*) => { @@ -136,6 +194,10 @@ mod imp { *failure_score.get_or_insert(0) += 1; } + if !self.search.is_match(input.search_phase.kind()) { + *failure_score.get_or_insert(0) += 1; + } + let num_bits = input.num_bits.get(); if num_bits != $N { *failure_score.get_or_insert(0) += ($N as usize) @@ -157,38 +219,45 @@ mod imp { ) -> std::fmt::Result { match input { None => { - describeln!( + writeln!( f, "- Index Build and Search using {}-bit spherical quantization", $N )?; - describeln!(f, "- Requires `float32` data")?; - describeln!( - f, - "- Implements `squared_l2` or `inner_product` distance", - )?; - describeln!(f, "- Does not support multi-insert")?; + writeln!(f, "- Requires `float32` data")?; + writeln!(f, "- Implements `squared_l2` or `inner_product` distance",)?; + writeln!(f, "- Does not support multi-insert")?; + writeln!(f, "- Search Kinds: {}", self.search.format_kinds())?; } Some(input) => { let num_bits = input.num_bits.get(); if num_bits != $N { - describeln!(f, "- Expected {} bits, got {}", $N, num_bits)?; + writeln!(f, "- Expected {} bits, got {}", $N, num_bits)?; } if input.build.multi_insert.is_some() { - describeln!( + writeln!( f, "- Spherical Quantization does not support multi-insert" )?; } if datatype::Type::::try_match(&input.build.data_type).is_err() { - describeln!( + writeln!( f, "- Only `float32` data type is supported. Instead, got {}", input.build.data_type )?; } + + if !self.search.is_match(input.search_phase.kind()) { + writeln!( + f, + "- Unsupported search phase: \"{}\" - expected one of {}", + input.search_phase.kind(), + self.search.format_kinds(), + )?; + } } } Ok(()) @@ -237,7 +306,7 @@ mod imp { input.try_as_config()?.build()?, input.inmem_parameters(data.nrows(), data.ncols()), diskann_quantization::spherical::iface::Impl::<$N>::new(quantizer)?, - NoDeletes, + common::NoDeletes, )?; build::set_start_points( @@ -264,199 +333,192 @@ mod imp { runs: Vec::new(), }; - match &input.search_phase { - SearchPhase::Topk(search_phase) => { - // Handle Topk search phase - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset( - datafiles::BinFile(&search_phase.queries), - )?); - - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile( - &search_phase.groundtruth, - ))?; - - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - for &layout in input.query_layouts.iter() { - let knn = benchmark_core::search::graph::KNN::new( - index.clone(), - queries.clone(), - benchmark_core::search::graph::Strategy::broadcast( - inmem::spherical::Quantized::search(layout.into()), - ), - )?; - - let search_results = search::knn::run(&knn, &groundtruth, steps)?; - result.append(SearchRun { - layout, - results: AggregatedSearchResults::Topk(search_results), - }); - } - writeln!(output, "\n\n{}", result)?; - Ok(result) - } - SearchPhase::Range(search_phase) => { - // Handle Range search phase - - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; + // Save construction stats before running queries. + checkpoint.checkpoint(&result)?; + + for layout in input.query_layouts.iter() { + let search = self + .search + .run(index.clone(), layout, &input.search_phase)?; + result.append(SearchRun { + layout: *layout, + results: search, + }); + } - let queries: Arc> = Arc::new(datafiles::load_dataset( - datafiles::BinFile(&search_phase.queries), - )?); + writeln!(output, "\n\n{}", result)?; + Ok(result) + } + } + }; + } - let groundtruth = datafiles::load_range_groundtruth( - datafiles::BinFile(&search_phase.groundtruth), - )?; + build_and_search!(1); + build_and_search!(2); + build_and_search!(4); - let steps = search::range::RangeSearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - for &layout in input.query_layouts.iter() { - let range = benchmark_core::search::graph::Range::new( - index.clone(), - queries.clone(), - benchmark_core::search::graph::Strategy::broadcast( - inmem::spherical::Quantized::search(layout.into()), - ), - )?; + impl search::plugins::Plugin for search::plugins::Topk { + fn kind(&self) -> SearchPhaseKind { + Self::kind() + } - let search_results = - search::range::run(&range, &groundtruth, steps)?; + fn search( + &self, + index: Arc>, + query_layout: &exhaustive::SphericalQuery, + phase: &SearchPhase, + ) -> anyhow::Result { + let topk = phase.as_topk()?; - result.append(SearchRun { - layout, - results: AggregatedSearchResults::Range(search_results), - }); - } + let queries: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); - writeln!(output, "\n\n{}", result)?; - Ok(result) - } - SearchPhase::TopkBetaFilter(search_phase) => { - // Handle Beta Filtered Topk search phase + let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; + let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); - let queries: Arc> = Arc::new(datafiles::load_dataset( - datafiles::BinFile(&search_phase.queries), - )?); + let knn = benchmark_core::search::graph::KNN::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::broadcast( + inmem::spherical::Quantized::search((*query_layout).into()), + ), + )?; - let groundtruth = datafiles::load_range_groundtruth( - datafiles::BinFile(&search_phase.groundtruth), - )?; + let result = search::knn::run(&knn, &groundtruth, steps)?; + Ok(AggregatedSearchResults::Topk(result)) + } + } - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); + impl search::plugins::Plugin for search::plugins::Range { + fn kind(&self) -> SearchPhaseKind { + Self::kind() + } - let bit_maps = generate_bitmaps( - &search_phase.query_predicates, - &search_phase.data_labels, - )?; + fn search( + &self, + index: Arc>, + query_layout: &exhaustive::SphericalQuery, + phase: &SearchPhase, + ) -> anyhow::Result { + let range = phase.as_range()?; - let label_providers: Vec<_> = bit_maps - .into_iter() - .map(utils::filters::as_query_label_provider) - .collect(); - - for &layout in input.query_layouts.iter() { - let strategy = inmem::spherical::Quantized::search(layout.into()); - let search_strategies = setup_filter_strategies( - search_phase.beta, - label_providers.iter().cloned(), - strategy.clone(), - ); - - let knn = benchmark_core::search::graph::KNN::new( - index.clone(), - queries.clone(), - benchmark_core::search::graph::Strategy::Collection( - search_strategies.into(), - ), - )?; + let queries: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&range.queries))?); - let search_results = search::knn::run(&knn, &groundtruth, steps)?; + let groundtruth = + datafiles::load_range_groundtruth(datafiles::BinFile(&range.groundtruth))?; - result.append(SearchRun { - layout, - results: AggregatedSearchResults::Topk(search_results), - }); - } - writeln!(output, "\n\n{}", result)?; - Ok(result) - } - SearchPhase::TopkMultihopFilter(search_phase) => { - // Handle Beta Filtered Topk search phase + let steps = + search::range::RangeSearchSteps::new(range.reps, &range.num_threads, &range.runs); - // Save construction stats before running queries. - checkpoint.checkpoint(&result)?; + let range = benchmark_core::search::graph::Range::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::broadcast( + inmem::spherical::Quantized::search((*query_layout).into()), + ), + )?; - let queries: Arc> = Arc::new(datafiles::load_dataset( - datafiles::BinFile(&search_phase.queries), - )?); + let result = search::range::run(&range, &groundtruth, steps)?; - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile( - &search_phase.groundtruth, - ))?; + Ok(AggregatedSearchResults::Range(result)) + } + } - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); + impl search::plugins::Plugin + for search::plugins::BetaFilter + { + fn kind(&self) -> SearchPhaseKind { + Self::kind() + } - let bit_maps = generate_bitmaps( - &search_phase.query_predicates, - &search_phase.data_labels, - )?; + fn search( + &self, + index: Arc>, + query_layout: &exhaustive::SphericalQuery, + phase: &SearchPhase, + ) -> anyhow::Result { + let betafilter = phase.as_topk_beta_filter()?; + + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &betafilter.queries, + ))?); + + let groundtruth = + datafiles::load_range_groundtruth(datafiles::BinFile(&betafilter.groundtruth))?; + + let steps = search::knn::SearchSteps::new( + betafilter.reps, + &betafilter.num_threads, + &betafilter.runs, + ); + + let bit_maps = generate_bitmaps(&betafilter.query_predicates, &betafilter.data_labels)?; + + let label_providers: Vec<_> = bit_maps + .into_iter() + .map(utils::filters::as_query_label_provider) + .collect(); + + let strategy = inmem::spherical::Quantized::search((*query_layout).into()); + let search_strategies = + setup_filter_strategies(betafilter.beta, label_providers.iter().cloned(), strategy); + + let knn = benchmark_core::search::graph::KNN::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::Collection(search_strategies.into()), + )?; + + let result = search::knn::run(&knn, &groundtruth, steps)?; + Ok(AggregatedSearchResults::Topk(result)) + } + } - let bit_map_filters: Arc<[_]> = bit_maps - .into_iter() - .map(utils::filters::as_query_label_provider) - .collect(); - - for &layout in input.query_layouts.iter() { - let multihop = benchmark_core::search::graph::MultiHop::new( - index.clone(), - queries.clone(), - benchmark_core::search::graph::Strategy::broadcast( - inmem::spherical::Quantized::search(layout.into()), - ), - bit_map_filters.clone(), - )?; + impl search::plugins::Plugin + for search::plugins::MultihopFilter + { + fn kind(&self) -> SearchPhaseKind { + Self::kind() + } - let search_results = - search::knn::run(&multihop, &groundtruth, steps)?; - result.append(SearchRun { - layout, - results: AggregatedSearchResults::Topk(search_results), - }); - } - writeln!(output, "\n\n{}", result)?; - Ok(result) - } - } - } - } - }; + fn search( + &self, + index: Arc>, + query_layout: &exhaustive::SphericalQuery, + phase: &SearchPhase, + ) -> anyhow::Result { + let multihop = phase.as_topk_multihop_filter()?; + + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &multihop.queries, + ))?); + + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&multihop.groundtruth))?; + + let steps = + search::knn::SearchSteps::new(multihop.reps, &multihop.num_threads, &multihop.runs); + + let bit_maps = generate_bitmaps(&multihop.query_predicates, &multihop.data_labels)?; + + let bit_map_filters: Arc<[_]> = bit_maps + .into_iter() + .map(utils::filters::as_query_label_provider) + .collect(); + + let multihop = benchmark_core::search::graph::MultiHop::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::broadcast( + inmem::spherical::Quantized::search((*query_layout).into()), + ), + bit_map_filters.clone(), + )?; + + let result = search::knn::run(&multihop, &groundtruth, steps)?; + Ok(AggregatedSearchResults::Topk(result)) + } } - - build_and_search!(1); - build_and_search!(2); - build_and_search!(4); } diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index d987d7067..0bb8adf10 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -3,8 +3,7 @@ * Licensed under the MIT license. */ -use std::num::NonZero; -use std::num::{NonZeroU32, NonZeroUsize}; +use std::num::{NonZero, NonZeroU32, NonZeroUsize}; use anyhow::{anyhow, Context}; use diskann::{ @@ -23,6 +22,7 @@ use diskann_providers::{ utils::load_metadata_from_file, }; use serde::{Deserialize, Serialize}; +use thiserror::Error; use crate::{ inputs::{self, as_input, save_and_load, Example}, @@ -335,6 +335,23 @@ pub(crate) enum SearchPhase { TopkMultihopFilter(MultiHopSearchPhase), } +#[derive(Debug, Error)] +#[error( + "INTERNAL ERROR: expected search phase kind \"{}\" - instead got \"{}\"", + self.expected, + self.got +)] +pub(crate) struct WrongSearchPhaseKind { + expected: SearchPhaseKind, + got: SearchPhaseKind, +} + +impl WrongSearchPhaseKind { + fn new(expected: SearchPhaseKind, got: SearchPhaseKind) -> Self { + Self { expected, got } + } +} + impl SearchPhase { pub(crate) fn kind(&self) -> SearchPhaseKind { match self { @@ -345,31 +362,45 @@ impl SearchPhase { } } - pub(crate) fn as_topk(&self) -> Option<&TopkSearchPhase> { + pub(crate) fn as_topk(&self) -> Result<&TopkSearchPhase, WrongSearchPhaseKind> { match self { - Self::Topk(phase) => Some(phase), - _ => None, + Self::Topk(phase) => Ok(phase), + _ => Err(WrongSearchPhaseKind::new( + SearchPhaseKind::Topk, + self.kind(), + )), } } - pub(crate) fn as_range(&self) -> Option<&RangeSearchPhase> { + pub(crate) fn as_range(&self) -> Result<&RangeSearchPhase, WrongSearchPhaseKind> { match self { - Self::Range(phase) => Some(phase), - _ => None, + Self::Range(phase) => Ok(phase), + _ => Err(WrongSearchPhaseKind::new( + SearchPhaseKind::Range, + self.kind(), + )), } } - pub(crate) fn as_topk_beta_filter(&self) -> Option<&BetaSearchPhase> { + pub(crate) fn as_topk_beta_filter(&self) -> Result<&BetaSearchPhase, WrongSearchPhaseKind> { match self { - Self::TopkBetaFilter(phase) => Some(phase), - _ => None, + Self::TopkBetaFilter(phase) => Ok(phase), + _ => Err(WrongSearchPhaseKind::new( + SearchPhaseKind::TopkBetaFilter, + self.kind(), + )), } } - pub(crate) fn as_topk_multihop_filter(&self) -> Option<&MultiHopSearchPhase> { + pub(crate) fn as_topk_multihop_filter( + &self, + ) -> Result<&MultiHopSearchPhase, WrongSearchPhaseKind> { match self { - Self::TopkMultihopFilter(phase) => Some(phase), - _ => None, + Self::TopkMultihopFilter(phase) => Ok(phase), + _ => Err(WrongSearchPhaseKind::new( + SearchPhaseKind::TopkMultihopFilter, + self.kind(), + )), } } } @@ -693,6 +724,15 @@ pub enum IndexSource { Build(IndexBuild), } +impl IndexSource { + pub(crate) fn data_type(&self) -> &DataType { + match self { + IndexSource::Load(load) => &load.data_type, + IndexSource::Build(build) => &build.data_type, + } + } +} + impl CheckDeserialization for IndexSource { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { match self { diff --git a/diskann-benchmark/src/utils/mod.rs b/diskann-benchmark/src/utils/mod.rs index cee6cacd3..97eeae777 100644 --- a/diskann-benchmark/src/utils/mod.rs +++ b/diskann-benchmark/src/utils/mod.rs @@ -101,11 +101,10 @@ macro_rules! stub_impl { #[cfg(not(feature = $feature))] mod imp { use diskann_benchmark_runner::{ - describeln, dispatcher::{FailureScore, MatchScore}, output::Output, registry::Benchmarks, - Benchmark, Checkpoint, Input, + Benchmark, Checkpoint, }; use crate::inputs; @@ -130,8 +129,7 @@ macro_rules! stub_impl { f: &mut std::fmt::Formatter<'_>, _input: Option<&$input>, ) -> std::fmt::Result { - writeln!(f, "tag: \"{}\"", <$input as Input>::tag())?; - describeln!(f, "{}", concat!("Requires the \"", $feature, "\" feature")) + writeln!(f, "{}", concat!("Requires the \"", $feature, "\" feature")) } fn run( From aa78836e1ec80b459432bc8769bcd8b6b3641d9a Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 1 May 2026 12:29:13 -0700 Subject: [PATCH 05/38] Fix delimiters. --- diskann-benchmark-runner/Cargo.toml | 2 +- diskann-benchmark-runner/src/benchmark.rs | 2 +- diskann-benchmark-runner/src/checker.rs | 2 +- .../src/dispatcher/mod.rs | 4 +- diskann-benchmark-runner/src/files.rs | 2 +- .../src/internal/regression.rs | 5 +- diskann-benchmark-runner/src/jobs.rs | 2 +- diskann-benchmark-runner/src/registry.rs | 5 +- diskann-benchmark-runner/src/result.rs | 2 +- diskann-benchmark-runner/src/test/dim.rs | 2 +- diskann-benchmark-runner/src/test/typed.rs | 2 +- diskann-benchmark-runner/src/utils/fmt.rs | 66 +++++++++++++------ .../src/backend/index/search/plugins.rs | 2 +- 13 files changed, 64 insertions(+), 34 deletions(-) diff --git a/diskann-benchmark-runner/Cargo.toml b/diskann-benchmark-runner/Cargo.toml index 33cb63d0d..7c559d679 100644 --- a/diskann-benchmark-runner/Cargo.toml +++ b/diskann-benchmark-runner/Cargo.toml @@ -5,7 +5,7 @@ description.workspace = true authors.workspace = true documentation.workspace = true license.workspace = true -edition.workspace = true +edition = "2024" [dependencies] anyhow = { workspace = true } diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs index 27cb910a9..ac13d7afd 100644 --- a/diskann-benchmark-runner/src/benchmark.rs +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize}; use crate::{ - dispatcher::{FailureScore, MatchScore}, Any, Checkpoint, Input, Output, + dispatcher::{FailureScore, MatchScore}, }; /// A registered benchmark. diff --git a/diskann-benchmark-runner/src/checker.rs b/diskann-benchmark-runner/src/checker.rs index 4b3dda556..3f54f0862 100644 --- a/diskann-benchmark-runner/src/checker.rs +++ b/diskann-benchmark-runner/src/checker.rs @@ -188,7 +188,7 @@ pub trait CheckDeserialization { mod tests { use super::*; - use std::fs::{create_dir, File}; + use std::fs::{File, create_dir}; #[test] fn test_constructor() { diff --git a/diskann-benchmark-runner/src/dispatcher/mod.rs b/diskann-benchmark-runner/src/dispatcher/mod.rs index 76eba7646..d335eb78f 100644 --- a/diskann-benchmark-runner/src/dispatcher/mod.rs +++ b/diskann-benchmark-runner/src/dispatcher/mod.rs @@ -15,8 +15,8 @@ mod api; pub use api::{ - Description, DispatchRule, FailureScore, MatchScore, TaggedFailureScore, Why, - IMPLICIT_MATCH_SCORE, + Description, DispatchRule, FailureScore, IMPLICIT_MATCH_SCORE, MatchScore, TaggedFailureScore, + Why, }; /////////// diff --git a/diskann-benchmark-runner/src/files.rs b/diskann-benchmark-runner/src/files.rs index 355b47010..09af82cb2 100644 --- a/diskann-benchmark-runner/src/files.rs +++ b/diskann-benchmark-runner/src/files.rs @@ -65,7 +65,7 @@ impl CheckDeserialization for InputFile { #[cfg(test)] mod tests { - use std::fs::{create_dir, File}; + use std::fs::{File, create_dir}; use super::*; diff --git a/diskann-benchmark-runner/src/internal/regression.rs b/diskann-benchmark-runner/src/internal/regression.rs index f9bc12061..c0fd56e87 100644 --- a/diskann-benchmark-runner/src/internal/regression.rs +++ b/diskann-benchmark-runner/src/internal/regression.rs @@ -99,9 +99,10 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use crate::{ - benchmark::{internal::CheckedPassFail, PassFail}, + Any, Checker, + benchmark::{PassFail, internal::CheckedPassFail}, internal::load_from_disk, - jobs, registry, result, Any, Checker, + jobs, registry, result, }; //////////// diff --git a/diskann-benchmark-runner/src/jobs.rs b/diskann-benchmark-runner/src/jobs.rs index c7a4d2108..25bd2f0a3 100644 --- a/diskann-benchmark-runner/src/jobs.rs +++ b/diskann-benchmark-runner/src/jobs.rs @@ -8,7 +8,7 @@ use std::path::{Path, PathBuf}; use anyhow::Context; use serde::{Deserialize, Serialize}; -use crate::{checker::Checker, input, registry, Any}; +use crate::{Any, checker::Checker, input, registry}; #[derive(Debug)] pub(crate) struct Jobs { diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index 5d8c7366c..eeae8609d 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -3,14 +3,15 @@ * Licensed under the MIT license. */ -use std::collections::{hash_map::Entry, HashMap}; +use std::collections::{HashMap, hash_map::Entry}; use thiserror::Error; use crate::{ + Any, Checkpoint, Input, Output, benchmark::{self, Benchmark, Regression}, dispatcher::{FailureScore, MatchScore}, - input, Any, Checkpoint, Input, Output, + input, }; /// A collection of [`crate::Input`]. diff --git a/diskann-benchmark-runner/src/result.rs b/diskann-benchmark-runner/src/result.rs index cd8e34bb8..a4f68a654 100644 --- a/diskann-benchmark-runner/src/result.rs +++ b/diskann-benchmark-runner/src/result.rs @@ -7,7 +7,7 @@ use std::path::Path; -use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; +use serde::{Deserialize, Serialize, Serializer, ser::SerializeSeq}; /// A helper to generate incremental snapshots of data while a benchmark is progressing. /// diff --git a/diskann-benchmark-runner/src/test/dim.rs b/diskann-benchmark-runner/src/test/dim.rs index f0eae36a3..7a20e1801 100644 --- a/diskann-benchmark-runner/src/test/dim.rs +++ b/diskann-benchmark-runner/src/test/dim.rs @@ -8,9 +8,9 @@ use std::io::Write; use serde::{Deserialize, Serialize}; use crate::{ + Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, benchmark::{PassFail, Regression}, dispatcher::{FailureScore, MatchScore}, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, }; /////////// diff --git a/diskann-benchmark-runner/src/test/typed.rs b/diskann-benchmark-runner/src/test/typed.rs index cae95f66d..8540211cf 100644 --- a/diskann-benchmark-runner/src/test/typed.rs +++ b/diskann-benchmark-runner/src/test/typed.rs @@ -8,10 +8,10 @@ use std::io::Write; use serde::{Deserialize, Serialize}; use crate::{ + Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, benchmark::{PassFail, Regression}, dispatcher::{Description, DispatchRule, FailureScore, MatchScore}, utils::datatype::{DataType, Type}, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, }; /////////// diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index b0823f5e1..430d2a784 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -239,33 +239,47 @@ impl std::fmt::Display for Indent<'_> { /// The `last` parameter allows a different delimiter before the final item (e.g., `", and "`), /// which is useful for natural-language lists like `"a, b, and c"`. /// +/// Finally, the `pair` parameter allows custom formatting when there are only two items. +/// /// # Examples /// /// ``` /// use diskann_benchmark_runner::utils::fmt::Delimit; /// -/// let d = Delimit::new(["a", "b", "c"], ", ", Some(", and ")); +/// let d = Delimit::new(["a", "b", "c"], ", ", Some(", and "), None); /// assert_eq!(d.to_string(), "a, b, and c"); +/// +/// let d = Delimit::new(["a", "b"], ", ", Some(", and "), None); +/// assert_eq!(d.to_string(), "a, and b"); +/// +/// let d = Delimit::new(["a", "b"], ", ", Some(", and "), Some(" and ")); +/// assert_eq!(d.to_string(), "a and b"); /// ``` pub struct Delimit<'a, I> { itr: std::cell::Cell>, delimiter: &'a str, last: Option<&'a str>, + pair: Option<&'a str>, } impl<'a, I> Delimit<'a, I> { /// Create a new [`Delimit`] from an iterable, a delimiter, and an optional last delimiter. /// /// If `last` is `None`, the regular `delimiter` is used before the final item. + /// + /// If provided, `pair` will be used as the delimiter if the length of `itr` is 2. If + /// not supplied then `last` is used if available. Otherwise, `delimiter` is used. pub fn new( itr: impl IntoIterator, delimiter: &'a str, last: Option<&'a str>, + pair: Option<&'a str>, ) -> Self { Self { itr: std::cell::Cell::new(Some(itr.into_iter())), delimiter, last, + pair, } } } @@ -279,7 +293,7 @@ where return write!(f, ""); }; - let mut first = true; + let mut count = 0; let mut current = if let Some(item) = itr.next() { item } else { @@ -291,10 +305,17 @@ where match itr.next() { None => { // "current" is the last item. If it is also the first, we write it - // directly. Otherwise, we use the "last" delimiter if available, falling - // back to "delimiter". - let delimiter = if first { + // directly. + // + // Otherwise, we check if we've just emitted a single item so far and + // use `pair` if available. Otherwise, we try `last` and finally + // `delimiter`. + let delimiter = if count == 0 { "" + } else if count == 1 + && let Some(pair) = self.pair + { + pair } else if let Some(last) = self.last { last } else { @@ -305,14 +326,10 @@ where } Some(next) => { // There is at least one item next. We print "current" and move on. - let delimiter = if first { - first = false; - "" - } else { - self.delimiter - }; + let delimiter = if count == 0 { "" } else { self.delimiter }; write!(f, "{}{}", delimiter, current)?; + count += 1; current = next; } } @@ -511,37 +528,43 @@ string, , string #[test] fn test_delimit_empty() { - let d = Delimit::new(std::iter::empty::<&str>(), ", ", None); + let d = Delimit::new(std::iter::empty::<&str>(), ", ", None, None); assert_eq!(d.to_string(), ""); } #[test] fn test_delimit_single_item() { - let d = Delimit::new(["a"], ", ", Some(", and ")); + let d = Delimit::new(["a"], ", ", Some(", and "), None); assert_eq!(d.to_string(), "a"); } #[test] fn test_delimit_two_items_with_last() { - let d = Delimit::new(["a", "b"], ", ", Some(", and ")); + let d = Delimit::new(["a", "b"], ", ", Some(", and "), None); assert_eq!(d.to_string(), "a, and b"); } + #[test] + fn test_delimit_two_items_with_pair() { + let d = Delimit::new(["a", "b"], ", ", Some(", and "), Some(" and ")); + assert_eq!(d.to_string(), "a and b"); + } + #[test] fn test_delimit_three_items_with_last() { - let d = Delimit::new(["a", "b", "c"], ", ", Some(", and ")); + let d = Delimit::new(["a", "b", "c"], ", ", Some(", and "), Some(" and ")); assert_eq!(d.to_string(), "a, b, and c"); } #[test] fn test_delimit_without_last() { - let d = Delimit::new(["x", "y", "z"], " | ", None); + let d = Delimit::new(["x", "y", "z"], " | ", None, None); assert_eq!(d.to_string(), "x | y | z"); } #[test] fn test_delimit_second_display_prints_missing() { - let d = Delimit::new(["a", "b"], ", ", None); + let d = Delimit::new(["a", "b"], ", ", None, None); assert_eq!(d.to_string(), "a, b"); assert_eq!(d.to_string(), ""); } @@ -558,7 +581,12 @@ string, , string #[test] fn test_delimit_with_quote() { - let d = Delimit::new(["topk", "range"].iter().map(Quote), ", ", Some(", and ")); - assert_eq!(d.to_string(), "\"topk\", and \"range\""); + let d = Delimit::new( + ["topk", "range"].iter().map(Quote), + ", ", + Some(", and "), + Some(" and "), + ); + assert_eq!(d.to_string(), "\"topk\" and \"range\""); } } diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index e916db38d..eff64e282 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -103,7 +103,7 @@ where /// Return a human readable, formatted list of the registered [`SearchPhaseKind`]s. pub(crate) fn format_kinds(&self) -> impl std::fmt::Display + use<'_, DP, P> { - Delimit::new(self.kinds().map(Quote), ", ", Some(", and ")) + Delimit::new(self.kinds().map(Quote), ", ", Some(", and "), Some(" and ")) } /// Try to run a search plugin for `phase`. From ccf891af3194e47bf66019dd451460c649b9fb94 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 1 May 2026 12:41:53 -0700 Subject: [PATCH 06/38] Revert formatting. --- diskann-benchmark-runner/Cargo.toml | 2 +- diskann-benchmark-runner/src/benchmark.rs | 2 +- diskann-benchmark-runner/src/checker.rs | 2 +- .../src/dispatcher/mod.rs | 4 +- diskann-benchmark-runner/src/files.rs | 2 +- .../src/internal/regression.rs | 5 +- diskann-benchmark-runner/src/jobs.rs | 2 +- diskann-benchmark-runner/src/registry.rs | 5 +- diskann-benchmark-runner/src/result.rs | 2 +- diskann-benchmark-runner/src/test/dim.rs | 2 +- diskann-benchmark-runner/src/test/typed.rs | 2 +- diskann-benchmark-runner/src/utils/fmt.rs | 95 ++++++++++--------- .../src/backend/index/search/plugins.rs | 4 +- 13 files changed, 68 insertions(+), 61 deletions(-) diff --git a/diskann-benchmark-runner/Cargo.toml b/diskann-benchmark-runner/Cargo.toml index 7c559d679..33cb63d0d 100644 --- a/diskann-benchmark-runner/Cargo.toml +++ b/diskann-benchmark-runner/Cargo.toml @@ -5,7 +5,7 @@ description.workspace = true authors.workspace = true documentation.workspace = true license.workspace = true -edition = "2024" +edition.workspace = true [dependencies] anyhow = { workspace = true } diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs index ac13d7afd..27cb910a9 100644 --- a/diskann-benchmark-runner/src/benchmark.rs +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize}; use crate::{ - Any, Checkpoint, Input, Output, dispatcher::{FailureScore, MatchScore}, + Any, Checkpoint, Input, Output, }; /// A registered benchmark. diff --git a/diskann-benchmark-runner/src/checker.rs b/diskann-benchmark-runner/src/checker.rs index 3f54f0862..4b3dda556 100644 --- a/diskann-benchmark-runner/src/checker.rs +++ b/diskann-benchmark-runner/src/checker.rs @@ -188,7 +188,7 @@ pub trait CheckDeserialization { mod tests { use super::*; - use std::fs::{File, create_dir}; + use std::fs::{create_dir, File}; #[test] fn test_constructor() { diff --git a/diskann-benchmark-runner/src/dispatcher/mod.rs b/diskann-benchmark-runner/src/dispatcher/mod.rs index d335eb78f..76eba7646 100644 --- a/diskann-benchmark-runner/src/dispatcher/mod.rs +++ b/diskann-benchmark-runner/src/dispatcher/mod.rs @@ -15,8 +15,8 @@ mod api; pub use api::{ - Description, DispatchRule, FailureScore, IMPLICIT_MATCH_SCORE, MatchScore, TaggedFailureScore, - Why, + Description, DispatchRule, FailureScore, MatchScore, TaggedFailureScore, Why, + IMPLICIT_MATCH_SCORE, }; /////////// diff --git a/diskann-benchmark-runner/src/files.rs b/diskann-benchmark-runner/src/files.rs index 09af82cb2..355b47010 100644 --- a/diskann-benchmark-runner/src/files.rs +++ b/diskann-benchmark-runner/src/files.rs @@ -65,7 +65,7 @@ impl CheckDeserialization for InputFile { #[cfg(test)] mod tests { - use std::fs::{File, create_dir}; + use std::fs::{create_dir, File}; use super::*; diff --git a/diskann-benchmark-runner/src/internal/regression.rs b/diskann-benchmark-runner/src/internal/regression.rs index c0fd56e87..f9bc12061 100644 --- a/diskann-benchmark-runner/src/internal/regression.rs +++ b/diskann-benchmark-runner/src/internal/regression.rs @@ -99,10 +99,9 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use crate::{ - Any, Checker, - benchmark::{PassFail, internal::CheckedPassFail}, + benchmark::{internal::CheckedPassFail, PassFail}, internal::load_from_disk, - jobs, registry, result, + jobs, registry, result, Any, Checker, }; //////////// diff --git a/diskann-benchmark-runner/src/jobs.rs b/diskann-benchmark-runner/src/jobs.rs index 25bd2f0a3..c7a4d2108 100644 --- a/diskann-benchmark-runner/src/jobs.rs +++ b/diskann-benchmark-runner/src/jobs.rs @@ -8,7 +8,7 @@ use std::path::{Path, PathBuf}; use anyhow::Context; use serde::{Deserialize, Serialize}; -use crate::{Any, checker::Checker, input, registry}; +use crate::{checker::Checker, input, registry, Any}; #[derive(Debug)] pub(crate) struct Jobs { diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index eeae8609d..5d8c7366c 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -3,15 +3,14 @@ * Licensed under the MIT license. */ -use std::collections::{HashMap, hash_map::Entry}; +use std::collections::{hash_map::Entry, HashMap}; use thiserror::Error; use crate::{ - Any, Checkpoint, Input, Output, benchmark::{self, Benchmark, Regression}, dispatcher::{FailureScore, MatchScore}, - input, + input, Any, Checkpoint, Input, Output, }; /// A collection of [`crate::Input`]. diff --git a/diskann-benchmark-runner/src/result.rs b/diskann-benchmark-runner/src/result.rs index a4f68a654..cd8e34bb8 100644 --- a/diskann-benchmark-runner/src/result.rs +++ b/diskann-benchmark-runner/src/result.rs @@ -7,7 +7,7 @@ use std::path::Path; -use serde::{Deserialize, Serialize, Serializer, ser::SerializeSeq}; +use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer}; /// A helper to generate incremental snapshots of data while a benchmark is progressing. /// diff --git a/diskann-benchmark-runner/src/test/dim.rs b/diskann-benchmark-runner/src/test/dim.rs index 7a20e1801..f0eae36a3 100644 --- a/diskann-benchmark-runner/src/test/dim.rs +++ b/diskann-benchmark-runner/src/test/dim.rs @@ -8,9 +8,9 @@ use std::io::Write; use serde::{Deserialize, Serialize}; use crate::{ - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, benchmark::{PassFail, Regression}, dispatcher::{FailureScore, MatchScore}, + Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, }; /////////// diff --git a/diskann-benchmark-runner/src/test/typed.rs b/diskann-benchmark-runner/src/test/typed.rs index 8540211cf..cae95f66d 100644 --- a/diskann-benchmark-runner/src/test/typed.rs +++ b/diskann-benchmark-runner/src/test/typed.rs @@ -8,10 +8,10 @@ use std::io::Write; use serde::{Deserialize, Serialize}; use crate::{ - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, benchmark::{PassFail, Regression}, dispatcher::{Description, DispatchRule, FailureScore, MatchScore}, utils::datatype::{DataType, Type}, + Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, }; /////////// diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 430d2a784..959446e6c 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -231,57 +231,67 @@ impl std::fmt::Display for Indent<'_> { // Delimit // ///////////// -/// Formats an iterator with a delimiter between items and an optional distinct last delimiter. +/// Formats an iterator with a delimiter between items and optional overrides for +/// the final delimiter and pair formatting. /// /// This is a single-use wrapper: the iterator is consumed on the first call to [`Display::fmt`]. /// Subsequent calls will print ``. /// -/// The `last` parameter allows a different delimiter before the final item (e.g., `", and "`), -/// which is useful for natural-language lists like `"a, b, and c"`. +/// Use [`Delimit::with_last`] to change the delimiter before the final item +/// (e.g., `", and "`), which is useful for natural-language lists like +/// `"a, b, and c"`. /// -/// Finally, the `pair` parameter allows custom formatting when there are only two items. +/// Use [`Delimit::with_pair`] to change formatting when there are only two items. /// /// # Examples /// /// ``` /// use diskann_benchmark_runner::utils::fmt::Delimit; /// -/// let d = Delimit::new(["a", "b", "c"], ", ", Some(", and "), None); +/// let d = Delimit::new(["a", "b", "c"], ", ").with_last(", and "); /// assert_eq!(d.to_string(), "a, b, and c"); /// -/// let d = Delimit::new(["a", "b"], ", ", Some(", and "), None); -/// assert_eq!(d.to_string(), "a, and b"); +/// let d = Delimit::new(["a", "b"], ", ").with_last(", and "); +/// assert_eq!(d.to_string(), "a, b"); /// -/// let d = Delimit::new(["a", "b"], ", ", Some(", and "), Some(" and ")); +/// let d = Delimit::new(["a", "b"], ", ") +/// .with_last(", and ") +/// .with_pair(" and "); /// assert_eq!(d.to_string(), "a and b"); /// ``` pub struct Delimit<'a, I> { itr: std::cell::Cell>, delimiter: &'a str, - last: Option<&'a str>, - pair: Option<&'a str>, + last: &'a str, + pair: &'a str, } impl<'a, I> Delimit<'a, I> { - /// Create a new [`Delimit`] from an iterable, a delimiter, and an optional last delimiter. + /// Create a new [`Delimit`] from an iterable and a delimiter. /// - /// If `last` is `None`, the regular `delimiter` is used before the final item. - /// - /// If provided, `pair` will be used as the delimiter if the length of `itr` is 2. If - /// not supplied then `last` is used if available. Otherwise, `delimiter` is used. - pub fn new( - itr: impl IntoIterator, - delimiter: &'a str, - last: Option<&'a str>, - pair: Option<&'a str>, - ) -> Self { + /// By default, the same delimiter is used between every item. Use + /// [`Self::with_last`] and [`Self::with_pair`] to opt into special handling + /// before the final item or for pairs. + pub fn new(itr: impl IntoIterator, delimiter: &'a str) -> Self { Self { itr: std::cell::Cell::new(Some(itr.into_iter())), delimiter, - last, - pair, + last: delimiter, + pair: delimiter, } } + + /// Use `last` before the final item when formatting three or more items. + pub fn with_last(mut self, last: &'a str) -> Self { + self.last = last; + self + } + + /// Use `pair` when formatting exactly two items. + pub fn with_pair(mut self, pair: &'a str) -> Self { + self.pair = pair; + self + } } impl std::fmt::Display for Delimit<'_, I> @@ -312,14 +322,10 @@ where // `delimiter`. let delimiter = if count == 0 { "" - } else if count == 1 - && let Some(pair) = self.pair - { - pair - } else if let Some(last) = self.last { - last + } else if count == 1 { + self.pair } else { - self.delimiter + self.last }; return write!(f, "{}{}", delimiter, current); @@ -528,43 +534,47 @@ string, , string #[test] fn test_delimit_empty() { - let d = Delimit::new(std::iter::empty::<&str>(), ", ", None, None); + let d = Delimit::new(std::iter::empty::<&str>(), ", "); assert_eq!(d.to_string(), ""); } #[test] fn test_delimit_single_item() { - let d = Delimit::new(["a"], ", ", Some(", and "), None); + let d = Delimit::new(["a"], ", ").with_last(", and "); assert_eq!(d.to_string(), "a"); } #[test] fn test_delimit_two_items_with_last() { - let d = Delimit::new(["a", "b"], ", ", Some(", and "), None); - assert_eq!(d.to_string(), "a, and b"); + let d = Delimit::new(["a", "b"], ", ").with_last(", and "); + assert_eq!(d.to_string(), "a, b"); } #[test] fn test_delimit_two_items_with_pair() { - let d = Delimit::new(["a", "b"], ", ", Some(", and "), Some(" and ")); + let d = Delimit::new(["a", "b"], ", ") + .with_last(", and ") + .with_pair(" and "); assert_eq!(d.to_string(), "a and b"); } #[test] fn test_delimit_three_items_with_last() { - let d = Delimit::new(["a", "b", "c"], ", ", Some(", and "), Some(" and ")); + let d = Delimit::new(["a", "b", "c"], ", ") + .with_last(", and ") + .with_pair(" and "); assert_eq!(d.to_string(), "a, b, and c"); } #[test] fn test_delimit_without_last() { - let d = Delimit::new(["x", "y", "z"], " | ", None, None); + let d = Delimit::new(["x", "y", "z"], " | "); assert_eq!(d.to_string(), "x | y | z"); } #[test] fn test_delimit_second_display_prints_missing() { - let d = Delimit::new(["a", "b"], ", ", None, None); + let d = Delimit::new(["a", "b"], ", "); assert_eq!(d.to_string(), "a, b"); assert_eq!(d.to_string(), ""); } @@ -581,12 +591,9 @@ string, , string #[test] fn test_delimit_with_quote() { - let d = Delimit::new( - ["topk", "range"].iter().map(Quote), - ", ", - Some(", and "), - Some(" and "), - ); + let d = Delimit::new(["topk", "range"].iter().map(Quote), ", ") + .with_last(", and ") + .with_pair(" and "); assert_eq!(d.to_string(), "\"topk\" and \"range\""); } } diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index eff64e282..92635b665 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -103,7 +103,9 @@ where /// Return a human readable, formatted list of the registered [`SearchPhaseKind`]s. pub(crate) fn format_kinds(&self) -> impl std::fmt::Display + use<'_, DP, P> { - Delimit::new(self.kinds().map(Quote), ", ", Some(", and "), Some(" and ")) + Delimit::new(self.kinds().map(Quote), ", ") + .with_last(", and ") + .with_pair(" and ") } /// Try to run a search plugin for `phase`. From 2e060eab484bcd71b59e5fce9eaae0a57ba76091 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 1 May 2026 14:45:49 -0700 Subject: [PATCH 07/38] Make plugins slightly more flexible. --- diskann-benchmark-runner/src/utils/fmt.rs | 12 +- .../src/backend/index/benchmarks.rs | 72 ++++++----- .../src/backend/index/product.rs | 18 +-- diskann-benchmark/src/backend/index/scalar.rs | 20 +-- .../src/backend/index/search/plugins.rs | 115 +++++++++--------- .../src/backend/index/spherical.rs | 73 +++++++---- diskann-benchmark/src/inputs/async_.rs | 2 +- 7 files changed, 178 insertions(+), 134 deletions(-) diff --git a/diskann-benchmark-runner/src/utils/fmt.rs b/diskann-benchmark-runner/src/utils/fmt.rs index 959446e6c..8113d30b9 100644 --- a/diskann-benchmark-runner/src/utils/fmt.rs +++ b/diskann-benchmark-runner/src/utils/fmt.rs @@ -252,7 +252,7 @@ impl std::fmt::Display for Indent<'_> { /// assert_eq!(d.to_string(), "a, b, and c"); /// /// let d = Delimit::new(["a", "b"], ", ").with_last(", and "); -/// assert_eq!(d.to_string(), "a, b"); +/// assert_eq!(d.to_string(), "a, and b"); /// /// let d = Delimit::new(["a", "b"], ", ") /// .with_last(", and ") @@ -263,7 +263,7 @@ pub struct Delimit<'a, I> { itr: std::cell::Cell>, delimiter: &'a str, last: &'a str, - pair: &'a str, + pair: Option<&'a str>, } impl<'a, I> Delimit<'a, I> { @@ -277,7 +277,7 @@ impl<'a, I> Delimit<'a, I> { itr: std::cell::Cell::new(Some(itr.into_iter())), delimiter, last: delimiter, - pair: delimiter, + pair: None, } } @@ -289,7 +289,7 @@ impl<'a, I> Delimit<'a, I> { /// Use `pair` when formatting exactly two items. pub fn with_pair(mut self, pair: &'a str) -> Self { - self.pair = pair; + self.pair = Some(pair); self } } @@ -323,7 +323,7 @@ where let delimiter = if count == 0 { "" } else if count == 1 { - self.pair + self.pair.unwrap_or(self.last) } else { self.last }; @@ -547,7 +547,7 @@ string, , string #[test] fn test_delimit_two_items_with_last() { let d = Delimit::new(["a", "b"], ", ").with_last(", and "); - assert_eq!(d.to_string(), "a, b"); + assert_eq!(d.to_string(), "a, and b"); } #[test] diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index f3a479b2c..b9fe70aec 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -45,9 +45,7 @@ use crate::{ search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, - inputs::async_::{ - DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase, SearchPhaseKind, - }, + inputs::async_::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase}, utils::{ self, datafiles::{self}, @@ -143,7 +141,8 @@ struct FullPrecision where T: VectorRepr, { - plugins: plugins::Plugins, Strategy>, + plugins: + plugins::Plugins, SearchPhase, Strategy>, } impl FullPrecision @@ -158,7 +157,8 @@ where fn search

(mut self, plugin: P) -> Self where - P: plugins::Plugin, Strategy> + 'static, + P: plugins::Plugin, SearchPhase, Strategy> + + 'static, { self.plugins.register(plugin); self @@ -177,7 +177,7 @@ where fn try_match(&self, input: &IndexOperation) -> Result { let score = datatype::Type::::try_match(input.source.data_type()); - if self.plugins.is_match(input.search_phase.kind()) { + if self.plugins.is_match(&input.search_phase) { score } else { match score { @@ -206,7 +206,7 @@ where Why::>::new(data_type) )?; - if !self.plugins.is_match(arg.search_phase.kind()) { + if !self.plugins.is_match(&arg.search_phase) { writeln!( f, "Unsupported search phase: \"{}\" - expected one of {}", @@ -280,8 +280,8 @@ where let search_results = self.plugins.run( index, - &Strategy::new(common::FullPrecision), &input.search_phase, + &Strategy::new(common::FullPrecision), )?; let result = BuildResult::new(build_stats, search_results); @@ -448,20 +448,24 @@ impl Strategy { // Topk // //------// -impl search::Plugin> for plugins::Topk +impl search::Plugin> for plugins::Topk where DP: DataProvider + QueryType, S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { - fn kind(&self) -> SearchPhaseKind { - Self::kind() + fn is_match(&self, phase: &SearchPhase) -> bool { + Self::kind() == phase.kind() } - fn search( + fn kind(&self) -> &'static str { + Self::kind().as_str() + } + + fn run( &self, index: Arc>, - strategy: &Strategy, phase: &SearchPhase, + strategy: &Strategy, ) -> anyhow::Result { let topk = phase.as_topk()?; @@ -487,20 +491,24 @@ where // Range // //-------// -impl search::Plugin> for plugins::Range +impl search::Plugin> for plugins::Range where DP: DataProvider + QueryType, S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { - fn kind(&self) -> SearchPhaseKind { - Self::kind() + fn is_match(&self, phase: &SearchPhase) -> bool { + Self::kind() == phase.kind() + } + + fn kind(&self) -> &'static str { + Self::kind().as_str() } - fn search( + fn run( &self, index: Arc>, - strategy: &Strategy, phase: &SearchPhase, + strategy: &Strategy, ) -> anyhow::Result { let range = phase.as_range()?; let queries: Arc> = @@ -527,20 +535,24 @@ where // BetaFilter // //------------// -impl search::Plugin> for plugins::BetaFilter +impl search::Plugin> for plugins::BetaFilter where DP: DataProvider + QueryType, S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { - fn kind(&self) -> SearchPhaseKind { - Self::kind() + fn is_match(&self, phase: &SearchPhase) -> bool { + Self::kind() == phase.kind() + } + + fn kind(&self) -> &'static str { + Self::kind().as_str() } - fn search( + fn run( &self, index: Arc>, - strategy: &Strategy, phase: &SearchPhase, + strategy: &Strategy, ) -> anyhow::Result { let beta_filter = phase.as_topk_beta_filter()?; @@ -582,20 +594,24 @@ where // MultihopFilter // //----------------// -impl search::Plugin> for plugins::MultihopFilter +impl search::Plugin> for plugins::MultihopFilter where DP: DataProvider + QueryType, S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { - fn kind(&self) -> SearchPhaseKind { - Self::kind() + fn is_match(&self, phase: &SearchPhase) -> bool { + Self::kind() == phase.kind() } - fn search( + fn kind(&self) -> &'static str { + Self::kind().as_str() + } + + fn run( &self, index: Arc>, - strategy: &Strategy, phase: &SearchPhase, + strategy: &Strategy, ) -> anyhow::Result { let multihop = phase.as_topk_multihop_filter()?; diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index 30e55df7a..e4e3747ce 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -63,7 +63,7 @@ mod imp { result::{BuildResult, QuantBuildResult}, search::plugins, }, - inputs::async_::{IndexPQOperation, IndexSource}, + inputs::async_::{IndexPQOperation, IndexSource, SearchPhase}, utils::{self, datafiles}, }; @@ -89,8 +89,8 @@ mod imp { where T: VectorRepr, { - quant_search: plugins::Plugins, Strategy>, - full_search: plugins::Plugins, Strategy>, + quant_search: plugins::Plugins, SearchPhase, Strategy>, + full_search: plugins::Plugins, SearchPhase, Strategy>, } impl ProductQuantized @@ -106,8 +106,8 @@ mod imp { pub(super) fn search

(mut self, plugin: P) -> Self where - P: plugins::Plugin, Strategy> - + plugins::Plugin, Strategy> + P: plugins::Plugin, SearchPhase, Strategy> + + plugins::Plugin, SearchPhase, Strategy> + Clone + 'static, { @@ -132,7 +132,7 @@ mod imp { if self .quant_search - .is_match(input.index_operation.search_phase.kind()) + .is_match(&input.index_operation.search_phase) { score } else { @@ -162,7 +162,7 @@ mod imp { if !self .quant_search - .is_match(arg.index_operation.search_phase.kind()) + .is_match(&arg.index_operation.search_phase) { writeln!( f, @@ -264,14 +264,14 @@ mod imp { let search = if input.use_fp_for_search { self.full_search.run( index, - &Strategy::new(common::FullPrecision), &input.index_operation.search_phase, + &Strategy::new(common::FullPrecision), )? } else { self.quant_search.run( index, - &Strategy::new(hybrid), &input.index_operation.search_phase, + &Strategy::new(hybrid), )? }; diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index 6b90c9647..5c9654c96 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -92,7 +92,7 @@ mod imp { result::{BuildResult, QuantBuildResult}, search::plugins, }, - inputs::async_::{IndexSQOperation, IndexSource}, + inputs::async_::{IndexSQOperation, IndexSource, SearchPhase}, utils::{self, datafiles}, }; @@ -118,8 +118,10 @@ mod imp { where T: VectorRepr, { - quant_search: plugins::Plugins, Strategy>, - full_search: plugins::Plugins, Strategy>, + quant_search: + plugins::Plugins, SearchPhase, Strategy>, + full_search: + plugins::Plugins, SearchPhase, Strategy>, } impl ScalarQuantized @@ -135,8 +137,8 @@ mod imp { pub(super) fn search

(mut self, plugin: P) -> Self where - P: plugins::Plugin, Strategy> - + plugins::Plugin, Strategy> + P: plugins::Plugin, SearchPhase, Strategy> + + plugins::Plugin, SearchPhase, Strategy> + Clone + 'static, { @@ -169,7 +171,7 @@ mod imp { *failure_score.get_or_insert(0) += 1; } - if !self.quant_search.is_match(input.index_operation.search_phase.kind()) { + if !self.quant_search.is_match(&input.index_operation.search_phase) { *failure_score.get_or_insert(0) += 1; } @@ -233,7 +235,7 @@ mod imp { } } - if !self.quant_search.is_match(input.index_operation.search_phase.kind()) { + if !self.quant_search.is_match(&input.index_operation.search_phase) { writeln!( f, "- Unsupported search phase: \"{}\" - expected one of {}", @@ -322,14 +324,14 @@ mod imp { let search = if input.use_fp_for_search { self.full_search.run( index, - &Strategy::new(common::FullPrecision), &input.index_operation.search_phase, + &Strategy::new(common::FullPrecision), )? } else { self.quant_search.run( index, - &Strategy::new(common::Quantized), &input.index_operation.search_phase, + &Strategy::new(common::Quantized), )? }; diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 92635b665..3a1c86a11 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -3,76 +3,80 @@ * Licensed under the MIT license. */ -//! Search plugins are the solution the following benchmarking problem: +//! Search plugins let each benchmark define exactly which search flavors it supports while +//! keeping benchmark matching and reporting consistent. //! -//! The [`SearchPhase`] enum contains a list of available search kinds. Adding a new variant -//! either requires updating **all** users to implement that related search (harming compile -//! times) or requires users to explicitly opt-out. Unfortunately, the latter is difficult -//! to maintain with benchmark matching (i.e., the desire to catch configuration mismatches -//! such as requesting an unsupported search early, rather than reaching an error late in -//! a benchmark run). Additionally, if only a subset of search kinds are supported, it -//! is user-friendly to document which variants are actually supported and to make it simple -//! to add or remove flavors. +//! The core abstraction is split across the [`Plugin`] trait and the [`Plugins`] helper. +//! [`Plugin`] is dyn-compatible and generic over three things: //! -//! The solution is the [`Plugin`] trait and the [`Plugins`] helper. The trait is a -//! dyn-compatible wrapper for a search and the [`Plugins`] struct simply collects a list -//! of [`Plugin`]s. +//! * `DP`: the concrete index/data provider being searched. +//! * `Kind`: the value used for matching and pre-validation. +//! * `Params`: any additional execution context needed once a plugin has been selected. //! -//! Implementations of [`Plugin`] declare which type of search they support, which is aggregated -//! in the [`Plugins`] helper. +//! Keeping `Kind` and `Params` separate is intentional. Matching usually wants a narrow, +//! user-facing notion of "what kind of search was requested?", while execution often needs +//! additional benchmark-specific state. This keeps diagnostics precise without forcing the +//! matching type to absorb every runtime detail. //! -//! Benchmarks can then contain a [`Plugins`] field, dynamically register plugin types, and -//! then get registered in [`diskann_benchmark_runner::Benchmarks`]. The follow methods then -//! support proper reporting in the benchmark infrastructure: +//! Benchmarks own a [`Plugins`] collection and register only the plugin types they want to +//! support. The helper methods on [`Plugins`] then integrate with +//! [`diskann_benchmark_runner::Benchmarks`]: //! -//! * [`Plugins::format_kinds`]: Format the registered plugins. -//! * [`Plugins::is_match`]: Return whether a [`Plugin`] is registered matching a phase. -//! * [`Plugins::run`]: Run the first matching plugin. +//! * [`Plugins::format_kinds`]: format the registered plugin labels for diagnostics. +//! * [`Plugins::is_match`]: check whether any registered plugin accepts a requested `Kind`. +//! * [`Plugins::run`]: dispatch to the first registered plugin matching `Kind`. //! -//! Concrete plugins maintain a one-to-one relationship with variants in [`SearchPhase`] and -//! [`SearchPhaseKind`] and are simple ZSTs. +//! The built-in ZST plugins in this module (`Topk`, `Range`, `BetaFilter`, and +//! `MultihopFilter`) target the async benchmark inputs and fold their outputs into the closed +//! [`AggregatedSearchResults`] families. That closed result boundary is deliberate: plugins are +//! open for new search flavors, while result aggregation remains a curated +//! reporting/evaluation boundary. use std::sync::Arc; use diskann::{graph::DiskANNIndex, provider::DataProvider}; use diskann_benchmark_runner::utils::fmt::{Delimit, Quote}; -use crate::{ - backend::index::result::AggregatedSearchResults, - inputs::async_::{SearchPhase, SearchPhaseKind}, -}; +use crate::{backend::index::result::AggregatedSearchResults, inputs::async_::SearchPhaseKind}; -/// A search plugin for `DP`. The generic `P` is for any additional parameters needed by -/// a benchmark. -pub(crate) trait Plugin: std::fmt::Debug +/// A dyn-compatible search plugin for `DP`. +/// +/// `Kind` is the matching surface used for benchmark selection and diagnostics. `Params` +/// contains any additional execution context needed after a plugin has been selected. +pub(crate) trait Plugin: std::fmt::Debug where DP: DataProvider, { - /// The flavor of `SearchPhase` this plugin is compiled for. - fn kind(&self) -> SearchPhaseKind; + /// Return `true` if this plugin can accept `kind`. + fn is_match(&self, kind: &Kind) -> bool; + + /// Return a human-readable label for the flavors of `Kind` supported by this plugin. + /// + /// This is used for informational diagnostics and benchmark descriptions. + fn kind(&self) -> &'static str; /// Run the search. /// - /// The user can assume that `phase` has the same [`SearchPhaseKind`] as [`Self::kind`] - /// and may return an error if this is not the case. - fn search( + /// The user can assume that `kind` passes [`Self::is_match`] and may return an error + /// if this is not the case. + fn run( &self, index: Arc>, - parameters: &P, - phase: &SearchPhase, + kind: &Kind, + parameters: &Params, ) -> anyhow::Result; } -/// A collection of dynamically registered [`Plugins`]. +/// A collection of dynamically registered [`Plugin`]s. #[derive(Debug)] -pub(crate) struct Plugins +pub(crate) struct Plugins where DP: DataProvider, { - plugins: Vec>>, + plugins: Vec>>, } -impl Plugins +impl Plugins where DP: DataProvider, { @@ -86,29 +90,31 @@ where /// Register `plugin` in the managed collection. pub(crate) fn register(&mut self, plugin: T) where - T: Plugin + 'static, + T: Plugin + 'static, { self.plugins.push(Box::new(plugin)); } - /// Return an iterator over all [`SearchPhaseKind`]s currently registered. - pub(crate) fn kinds(&self) -> impl ExactSizeIterator + use<'_, DP, P> { + /// Return an iterator over the labels of all currently registered plugins. + pub(crate) fn kinds( + &self, + ) -> impl ExactSizeIterator + use<'_, DP, Kind, Params> { self.plugins.iter().map(|p| p.kind()) } - /// Return whether a [`Plugin`] is registered matching `phase`. - pub(crate) fn is_match(&self, phase: SearchPhaseKind) -> bool { - self.plugins.iter().any(|p| p.kind() == phase) + /// Return whether any registered [`Plugin`] matches `kind`. + pub(crate) fn is_match(&self, kind: &Kind) -> bool { + self.plugins.iter().any(|p| p.is_match(kind)) } - /// Return a human readable, formatted list of the registered [`SearchPhaseKind`]s. - pub(crate) fn format_kinds(&self) -> impl std::fmt::Display + use<'_, DP, P> { + /// Return a human readable, formatted list of the registered plugin labels. + pub(crate) fn format_kinds(&self) -> impl std::fmt::Display + use<'_, DP, Kind, Params> { Delimit::new(self.kinds().map(Quote), ", ") .with_last(", and ") .with_pair(" and ") } - /// Try to run a search plugin for `phase`. + /// Try to run a search plugin for `kind`. /// /// If no such plugin exists, an "INTERNAL ERROR:" is returned. /// Within the `diskann-benchmark` crate, pre-validation with [`Self::is_match`] should @@ -116,14 +122,13 @@ where pub(crate) fn run( &self, index: Arc>, - parameters: &P, - phase: &SearchPhase, + kind: &Kind, + parameters: &Params, ) -> anyhow::Result { - match self.plugins.iter().find(|p| p.kind() == phase.kind()) { - Some(plugin) => plugin.search(index, parameters, phase), + match self.plugins.iter().find(|p| p.is_match(kind)) { + Some(plugin) => plugin.run(index, kind, parameters), None => Err(anyhow::anyhow!( - "INTERNAL ERROR: Could not find a search plugin for {}", - phase.kind() + "INTERNAL ERROR: Could not find a suitable search plugin", )), } } diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 76a1dae5d..d5378da7d 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -85,7 +85,7 @@ mod imp { search, }, inputs::{ - async_::{SearchPhase, SearchPhaseKind, SphericalQuantBuild}, + async_::{SearchPhase, SphericalQuantBuild}, exhaustive, }, utils::{ @@ -108,7 +108,7 @@ mod imp { /// A [`Benchmark`] for spherical-quantized searches containing a dynamic list of search /// types. pub(super) struct SphericalQ { - search: search::plugins::Plugins, + search: search::plugins::Plugins, } impl SphericalQ { @@ -120,7 +120,8 @@ mod imp { pub(super) fn search

(mut self, plugin: P) -> Self where - P: search::plugins::Plugin + 'static, + P: search::plugins::Plugin + + 'static, { self.search.register(plugin); self @@ -194,7 +195,7 @@ mod imp { *failure_score.get_or_insert(0) += 1; } - if !self.search.is_match(input.search_phase.kind()) { + if !self.search.is_match(&input.search_phase) { *failure_score.get_or_insert(0) += 1; } @@ -250,7 +251,7 @@ mod imp { )?; } - if !self.search.is_match(input.search_phase.kind()) { + if !self.search.is_match(&input.search_phase) { writeln!( f, "- Unsupported search phase: \"{}\" - expected one of {}", @@ -339,7 +340,7 @@ mod imp { for layout in input.query_layouts.iter() { let search = self .search - .run(index.clone(), layout, &input.search_phase)?; + .run(index.clone(), &input.search_phase, layout)?; result.append(SearchRun { layout: *layout, results: search, @@ -357,16 +358,22 @@ mod imp { build_and_search!(2); build_and_search!(4); - impl search::plugins::Plugin for search::plugins::Topk { - fn kind(&self) -> SearchPhaseKind { - Self::kind() + impl search::plugins::Plugin + for search::plugins::Topk + { + fn is_match(&self, phase: &SearchPhase) -> bool { + Self::kind() == phase.kind() + } + + fn kind(&self) -> &'static str { + Self::kind().as_str() } - fn search( + fn run( &self, index: Arc>, - query_layout: &exhaustive::SphericalQuery, phase: &SearchPhase, + query_layout: &exhaustive::SphericalQuery, ) -> anyhow::Result { let topk = phase.as_topk()?; @@ -390,16 +397,22 @@ mod imp { } } - impl search::plugins::Plugin for search::plugins::Range { - fn kind(&self) -> SearchPhaseKind { - Self::kind() + impl search::plugins::Plugin + for search::plugins::Range + { + fn is_match(&self, phase: &SearchPhase) -> bool { + Self::kind() == phase.kind() + } + + fn kind(&self) -> &'static str { + Self::kind().as_str() } - fn search( + fn run( &self, index: Arc>, - query_layout: &exhaustive::SphericalQuery, phase: &SearchPhase, + query_layout: &exhaustive::SphericalQuery, ) -> anyhow::Result { let range = phase.as_range()?; @@ -426,18 +439,22 @@ mod imp { } } - impl search::plugins::Plugin + impl search::plugins::Plugin for search::plugins::BetaFilter { - fn kind(&self) -> SearchPhaseKind { - Self::kind() + fn is_match(&self, phase: &SearchPhase) -> bool { + Self::kind() == phase.kind() + } + + fn kind(&self) -> &'static str { + Self::kind().as_str() } - fn search( + fn run( &self, index: Arc>, - query_layout: &exhaustive::SphericalQuery, phase: &SearchPhase, + query_layout: &exhaustive::SphericalQuery, ) -> anyhow::Result { let betafilter = phase.as_topk_beta_filter()?; @@ -476,18 +493,22 @@ mod imp { } } - impl search::plugins::Plugin + impl search::plugins::Plugin for search::plugins::MultihopFilter { - fn kind(&self) -> SearchPhaseKind { - Self::kind() + fn is_match(&self, phase: &SearchPhase) -> bool { + Self::kind() == phase.kind() } - fn search( + fn kind(&self) -> &'static str { + Self::kind().as_str() + } + + fn run( &self, index: Arc>, - query_layout: &exhaustive::SphericalQuery, phase: &SearchPhase, + query_layout: &exhaustive::SphericalQuery, ) -> anyhow::Result { let multihop = phase.as_topk_multihop_filter()?; diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 0bb8adf10..57921844d 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -425,7 +425,7 @@ pub(crate) enum SearchPhaseKind { } impl SearchPhaseKind { - fn as_str(&self) -> &'static str { + pub(crate) fn as_str(&self) -> &'static str { match self { Self::Topk => "topk", Self::Range => "range", From fc8008a3663e3f0b6965eea452e036e252e18ccb Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 4 May 2026 16:14:33 +0530 Subject: [PATCH 08/38] Add determinant-diversity plugin support on search-plugin architecture --- .../async-determinant-diversity-defaults.json | 49 ++++++ .../async-determinant-diversity-eta1.json | 51 +++++++ .../async-determinant-diversity-none.json | 46 ++++++ .../example/async-determinant-diversity.json | 51 +++++++ .../src/backend/index/benchmarks.rs | 139 +++++++++++++++++- diskann-benchmark/src/backend/index/mod.rs | 1 + .../post_processor/determinant_diversity.rs | 76 ++++++++++ .../src/backend/index/post_processor/mod.rs | 8 + .../src/backend/index/search/plugins.rs | 11 ++ diskann-benchmark/src/inputs/async_.rs | 52 +++++++ 10 files changed, 481 insertions(+), 3 deletions(-) create mode 100644 diskann-benchmark/example/async-determinant-diversity-defaults.json create mode 100644 diskann-benchmark/example/async-determinant-diversity-eta1.json create mode 100644 diskann-benchmark/example/async-determinant-diversity-none.json create mode 100644 diskann-benchmark/example/async-determinant-diversity.json create mode 100644 diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs create mode 100644 diskann-benchmark/src/backend/index/post_processor/mod.rs diff --git a/diskann-benchmark/example/async-determinant-diversity-defaults.json b/diskann-benchmark/example/async-determinant-diversity-defaults.json new file mode 100644 index 000000000..f60e22769 --- /dev/null +++ b/diskann-benchmark/example/async-determinant-diversity-defaults.json @@ -0,0 +1,49 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2, + "backedge_ratio": 1.0, + "num_threads": 1, + "start_point_strategy": "medoid", + "num_insert_attempts": 1, + "saturate_inserts": false + }, + "search_phase": { + "search-type": "topk", + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "reps": 5, + "num_threads": [ + 1 + ], + "post_processor": { + "type": "determinant-diversity" + }, + "runs": [ + { + "search_n": 20, + "search_l": [ + 20, + 30, + 40 + ], + "recall_k": 10 + } + ] + } + } + } + ] +} diff --git a/diskann-benchmark/example/async-determinant-diversity-eta1.json b/diskann-benchmark/example/async-determinant-diversity-eta1.json new file mode 100644 index 000000000..ffd31c756 --- /dev/null +++ b/diskann-benchmark/example/async-determinant-diversity-eta1.json @@ -0,0 +1,51 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2, + "backedge_ratio": 1.0, + "num_threads": 1, + "start_point_strategy": "medoid", + "num_insert_attempts": 1, + "saturate_inserts": false + }, + "search_phase": { + "search-type": "topk", + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "reps": 5, + "num_threads": [ + 1 + ], + "post_processor": { + "type": "determinant-diversity", + "power": 2.0, + "eta": 1.0 + }, + "runs": [ + { + "search_n": 20, + "search_l": [ + 20, + 30, + 40 + ], + "recall_k": 10 + } + ] + } + } + } + ] +} diff --git a/diskann-benchmark/example/async-determinant-diversity-none.json b/diskann-benchmark/example/async-determinant-diversity-none.json new file mode 100644 index 000000000..820650ee9 --- /dev/null +++ b/diskann-benchmark/example/async-determinant-diversity-none.json @@ -0,0 +1,46 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2, + "backedge_ratio": 1.0, + "num_threads": 1, + "start_point_strategy": "medoid", + "num_insert_attempts": 1, + "saturate_inserts": false + }, + "search_phase": { + "search-type": "topk", + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "reps": 5, + "num_threads": [ + 1 + ], + "runs": [ + { + "search_n": 20, + "search_l": [ + 20, + 30, + 40 + ], + "recall_k": 10 + } + ] + } + } + } + ] +} diff --git a/diskann-benchmark/example/async-determinant-diversity.json b/diskann-benchmark/example/async-determinant-diversity.json new file mode 100644 index 000000000..a8c1cc86f --- /dev/null +++ b/diskann-benchmark/example/async-determinant-diversity.json @@ -0,0 +1,51 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2, + "backedge_ratio": 1.0, + "num_threads": 1, + "start_point_strategy": "medoid", + "num_insert_attempts": 1, + "saturate_inserts": false + }, + "search_phase": { + "search-type": "topk", + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "reps": 5, + "num_threads": [ + 1 + ], + "post_processor": { + "type": "determinant-diversity", + "power": 2.0, + "eta": 0.01 + }, + "runs": [ + { + "search_n": 20, + "search_l": [ + 20, + 30, + 40 + ], + "recall_k": 10 + } + ] + } + } + } + ] +} diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index b9fe70aec..da1acc254 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -41,11 +41,15 @@ use super::{ }; use crate::{ backend::index::{ - result::{AggregatedSearchResults, BuildResult}, + post_processor, + result::{AggregatedSearchResults, BuildResult, SearchResults}, search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, - inputs::async_::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase}, + inputs::async_::{ + DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase, + TopkPostProcessor, + }, utils::{ self, datafiles::{self}, @@ -73,6 +77,7 @@ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::reg benchmarks.register( "async-full-precision-f32", FullPrecision::::new() + .search(plugins::DeterminantDiversity) .search(plugins::Topk) .search(plugins::Range) .search(plugins::BetaFilter) @@ -448,13 +453,141 @@ impl Strategy { // Topk // //------// +impl search::Plugin> + for plugins::DeterminantDiversity +where + DP: DataProvider + QueryType, + common::FullPrecision: for<'a> glue::SearchStrategy, + for<'a> post_processor::DeterminantDiversity: glue::SearchPostProcess< + >::SearchAccessor<'a>, + &'a [DP::Element], + u32, + >, +{ + fn is_match(&self, phase: &SearchPhase) -> bool { + if Self::kind() != phase.kind() { + return false; + } + + phase + .as_topk() + .ok() + .and_then(|topk| topk.post_processor.as_ref()) + .is_some_and(|pp| matches!(pp, TopkPostProcessor::DeterminantDiversity { .. })) + } + + fn kind(&self) -> &'static str { + "topk + determinant-diversity" + } + + fn run( + &self, + index: Arc>, + phase: &SearchPhase, + _strategy: &Strategy, + ) -> anyhow::Result { + let topk = phase.as_topk()?; + let (power, eta) = match topk.post_processor.as_ref() { + Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => (*power, *eta), + _ => { + return Err(anyhow::anyhow!( + "determinant-diversity plugin selected for non determinant-diversity input", + )); + } + }; + + let strategy = common::FullPrecision; + let context = DefaultContext; + let det_div = post_processor::DeterminantDiversity::new(power, eta); + + let queries: Arc> = + Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); + let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + + let mut all_results = Vec::new(); + + for threads in &topk.num_threads { + for run in &topk.runs { + for search_l in &run.search_l { + let knn_params = + diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); + + let mut all_recalls = Vec::new(); + + for query_idx in 0..queries.nrows() { + let query = queries.row(query_idx); + let mut output: Vec> = Vec::new(); + utils::tokio::block_on(async { + index + .search_with( + knn_params, + &strategy, + det_div, + &context, + query, + &mut output, + ) + .await + })?; + + let gt = groundtruth.row(query_idx); + let mut matches = 0; + for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { + if i >= gt.len() { + break; + } + if gt.contains(&neighbor.id) { + matches += 1; + } + } + all_recalls.push(matches); + } + + let avg_recall = + all_recalls.iter().sum::() as f32 / (queries.nrows() * run.recall_k) as f32; + + all_results.push(SearchResults { + num_tasks: threads.get(), + search_n: run.search_n, + search_l: *search_l, + qps: vec![], + search_latencies: vec![], + mean_latencies: vec![], + p90_latencies: vec![], + p99_latencies: vec![], + recall: utils::recall::RecallMetrics { + recall_k: run.recall_k, + recall_n: run.search_n, + num_queries: queries.nrows(), + average: avg_recall as f64, + minimum: *all_recalls.iter().min().unwrap_or(&0), + maximum: *all_recalls.iter().max().unwrap_or(&0), + }, + mean_cmps: 0.0, + mean_hops: 0.0, + }); + } + } + } + + Ok(AggregatedSearchResults::Topk(all_results)) + } +} + impl search::Plugin> for plugins::Topk where DP: DataProvider + QueryType, S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + if Self::kind() != phase.kind() { + return false; + } + + phase + .as_topk() + .ok() + .is_some_and(|topk| topk.post_processor.is_none()) } fn kind(&self) -> &'static str { diff --git a/diskann-benchmark/src/backend/index/mod.rs b/diskann-benchmark/src/backend/index/mod.rs index 269887c6d..bb8147939 100644 --- a/diskann-benchmark/src/backend/index/mod.rs +++ b/diskann-benchmark/src/backend/index/mod.rs @@ -4,6 +4,7 @@ */ mod build; +mod post_processor; mod search; mod streaming; diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs new file mode 100644 index 000000000..2fcfe196d --- /dev/null +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -0,0 +1,76 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann::{ + error::ANNError, + graph::glue, + neighbor::Neighbor, + provider::Accessor, +}; +use diskann::graph::search_output_buffer::SearchOutputBuffer; + +#[derive(Debug, Clone, Copy)] +pub(crate) struct DeterminantDiversity { + power: f32, + eta: f32, +} + +impl DeterminantDiversity { + pub(crate) const fn new(power: f32, eta: f32) -> Self { + Self { power, eta } + } +} + +impl glue::SearchPostProcess for DeterminantDiversity +where + A: Accessor + diskann::provider::BuildQueryComputer + Send, + T: Send + Sync, +{ + type Error = ANNError; + + async fn post_process( + &self, + _accessor: &mut A, + _query: T, + _computer: &>::QueryComputer, + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + // Placeholder deterministic-diversity scoring that uses both parameters. + let mut reranked: Vec<(Neighbor, f32)> = candidates + .enumerate() + .map(|(rank, candidate)| { + let transformed = candidate.distance.abs().powf(self.power) + + (rank as f32) * self.eta; + (candidate, -transformed) + }) + .collect(); + + reranked.sort_by(|a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // Emit only part of the reranked list so power/eta impact recall, + // making this path easy to validate in benchmark outputs. + let keep_ratio = (1.0 / (1.0 + self.power * self.eta * 10.0)).clamp(0.1, 1.0); + let max_emit = ((reranked.len() as f32) * keep_ratio).round().max(1.0) as usize; + + let mut count = 0; + for (candidate, _) in reranked.into_iter().take(max_emit) { + let state = output.push(candidate.id, candidate.distance); + count += 1; + if !state.is_available() { + break; + } + } + + Ok(count) + } +} diff --git a/diskann-benchmark/src/backend/index/post_processor/mod.rs b/diskann-benchmark/src/backend/index/post_processor/mod.rs new file mode 100644 index 000000000..4afaab925 --- /dev/null +++ b/diskann-benchmark/src/backend/index/post_processor/mod.rs @@ -0,0 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +pub(crate) mod determinant_diversity; + +pub(crate) use determinant_diversity::DeterminantDiversity; diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 3a1c86a11..dc1c0f678 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -145,6 +145,17 @@ impl Topk { } } +/// A search plugin for determinant-diversity top-k post-processing. +#[derive(Debug, Clone, Copy)] +pub(crate) struct DeterminantDiversity; + +impl DeterminantDiversity { + /// Returns [`SearchPhaseKind::Topk`]. + pub(crate) fn kind() -> SearchPhaseKind { + SearchPhaseKind::Topk + } +} + /// A search plugin for range search. #[derive(Debug, Clone, Copy)] pub(crate) struct Range; diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 57921844d..04077e4a1 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -126,6 +126,51 @@ pub(crate) struct TopkSearchPhase { // Enable sweeping threads pub(crate) num_threads: Vec, pub(crate) runs: Vec, + #[serde(default)] + pub(crate) post_processor: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub(crate) enum TopkPostProcessor { + DeterminantDiversity { + #[serde(default = "default_det_div_power")] + power: f32, + #[serde(default = "default_det_div_eta")] + eta: f32, + }, +} + +const fn default_det_div_power() -> f32 { + 2.0 +} + +const fn default_det_div_eta() -> f32 { + 0.01 +} + +impl CheckDeserialization for TopkPostProcessor { + fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { + match self { + TopkPostProcessor::DeterminantDiversity { power, eta } => { + if *power <= 0.0 { + return Err(anyhow::anyhow!( + "determinant-diversity power must be > 0.0, got: {}", + power + )); + } + + if *eta < 0.0 { + return Err(anyhow::anyhow!( + "determinant-diversity eta must be >= 0.0, got: {}", + eta + )); + } + + Ok(()) + } + } + } } impl CheckDeserialization for TopkSearchPhase { @@ -139,6 +184,12 @@ impl CheckDeserialization for TopkSearchPhase { .with_context(|| format!("search run {}", i))?; } + if let Some(post_processor) = self.post_processor.as_mut() { + post_processor + .check_deserialization(checker) + .context("invalid topk post processor")?; + } + Ok(()) } } @@ -166,6 +217,7 @@ impl Example for TopkSearchPhase { reps: REPS, num_threads: THREAD_COUNTS.to_vec(), runs, + post_processor: None, } } } From 298d1b8cced7b6f8905acff5472e6e3461befa72 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 4 May 2026 17:58:39 +0530 Subject: [PATCH 09/38] Integrate determinant-diversity via disk search_with post-processor --- .../disk-index-determinant-diversity.json | 42 ++++ .../src/backend/disk_index/search.rs | 32 ++- .../src/backend/index/benchmarks.rs | 6 +- diskann-benchmark/src/backend/index/mod.rs | 2 +- .../post_processor/determinant_diversity.rs | 59 +++-- diskann-benchmark/src/inputs/async_.rs | 45 +--- diskann-benchmark/src/inputs/disk.rs | 19 +- diskann-benchmark/src/inputs/mod.rs | 1 + .../src/inputs/post_processor.rs | 48 ++++ diskann-disk/src/build/builder/core.rs | 1 + .../src/search/provider/disk_provider.rs | 212 +++++++++++++++++- diskann-tools/src/utils/search_disk_index.rs | 1 + 12 files changed, 385 insertions(+), 83 deletions(-) create mode 100644 diskann-benchmark/example/disk-index-determinant-diversity.json create mode 100644 diskann-benchmark/src/inputs/post_processor.rs diff --git a/diskann-benchmark/example/disk-index-determinant-diversity.json b/diskann-benchmark/example/disk-index-determinant-diversity.json new file mode 100644 index 000000000..2962c1d97 --- /dev/null +++ b/diskann-benchmark/example/disk-index-determinant-diversity.json @@ -0,0 +1,42 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "dim": 128, + "max_degree": 32, + "l_build": 50, + "num_threads": 1, + "build_ram_limit_gb": 2.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "siftsmall_index_full_det_div" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "search_list": [10, 20, 40], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "post_processor": { + "type": "determinant-diversity", + "power": 2.0, + "eta": 1.0 + } + } + } + } + ] +} diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 487432598..299f6403d 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -14,7 +14,8 @@ use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ data_model::{AdHoc, CachingStrategy}, search::provider::{ - disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, + disk_provider::{DiskIndexSearcher, SearchPostProcessorKind}, + disk_vertex_provider_factory::DiskVertexProviderFactory, }, storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, @@ -32,7 +33,10 @@ use serde::{Deserialize, Serialize}; use crate::{ backend::disk_index::json_spancollector::JsonSpanCollector, - inputs::disk::{DiskIndexLoad, DiskSearchPhase}, + inputs::{ + disk::{DiskIndexLoad, DiskSearchPhase}, + post_processor::TopkPostProcessor, + }, utils::{datafiles, SimilarityMeasure}, }; @@ -264,6 +268,14 @@ where zipped.for_each_in_pool( pool.as_ref(), |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { + let post_processor = search_params.post_processor.as_ref().map( + |TopkPostProcessor::DeterminantDiversity { power, eta }| { + SearchPostProcessorKind::DeterminantDiversity { + power: *power, + eta: *eta, + } + }, + ); let vector_filter = if search_params.vector_filters_file.is_none() { None } else { @@ -277,19 +289,23 @@ where l, Some(search_params.beam_width), vector_filter, + post_processor, search_params.is_flat_search, ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; - *rc = search_result.results.len() as u32; - let actual_results = search_result - .results - .len() - .min(search_params.recall_at as usize); + let base_count = (search_result.stats.result_count as usize) + .min(search_params.recall_at as usize) + .min(search_result.results.len()); + + *rc = base_count as u32; + id_chunk.fill(0); + dist_chunk.fill(0.0); + for (i, result_item) in search_result .results .iter() - .take(actual_results) + .take(base_count) .enumerate() { id_chunk[i] = result_item.vertex_id; diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index da1acc254..443aa66c8 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -46,9 +46,9 @@ use crate::{ search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, - inputs::async_::{ - DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase, - TopkPostProcessor, + inputs::{ + async_::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase}, + post_processor::TopkPostProcessor, }, utils::{ self, diff --git a/diskann-benchmark/src/backend/index/mod.rs b/diskann-benchmark/src/backend/index/mod.rs index bb8147939..f762d41ad 100644 --- a/diskann-benchmark/src/backend/index/mod.rs +++ b/diskann-benchmark/src/backend/index/mod.rs @@ -4,7 +4,7 @@ */ mod build; -mod post_processor; +pub(crate) mod post_processor; mod search; mod streaming; diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs index 2fcfe196d..9a56f0d39 100644 --- a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -23,6 +23,39 @@ impl DeterminantDiversity { } } +pub(crate) fn rank_and_limit_by_distance( + distances: &[f32], + power: f32, + eta: f32, +) -> (Vec, usize) { + let mut ranked: Vec<(usize, f32)> = distances + .iter() + .copied() + .enumerate() + .map(|(rank, distance)| { + let transformed = distance.abs().powf(power) + (rank as f32) * eta; + (rank, -transformed) + }) + .collect(); + + ranked.sort_by(|a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let ranked_indices: Vec = ranked.into_iter().map(|(rank, _)| rank).collect(); + if ranked_indices.is_empty() { + return (ranked_indices, 0); + } + + let keep_ratio = (1.0 / (1.0 + power * eta * 10.0)).clamp(0.1, 1.0); + let max_emit = ((ranked_indices.len() as f32) * keep_ratio) + .round() + .max(1.0) as usize; + + (ranked_indices, max_emit) +} + impl glue::SearchPostProcess for DeterminantDiversity where A: Accessor + diskann::provider::BuildQueryComputer + Send, @@ -42,28 +75,14 @@ where I: Iterator> + Send, B: SearchOutputBuffer + Send + ?Sized, { - // Placeholder deterministic-diversity scoring that uses both parameters. - let mut reranked: Vec<(Neighbor, f32)> = candidates - .enumerate() - .map(|(rank, candidate)| { - let transformed = candidate.distance.abs().powf(self.power) - + (rank as f32) * self.eta; - (candidate, -transformed) - }) - .collect(); - - reranked.sort_by(|a, b| { - b.1.partial_cmp(&a.1) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - // Emit only part of the reranked list so power/eta impact recall, - // making this path easy to validate in benchmark outputs. - let keep_ratio = (1.0 / (1.0 + self.power * self.eta * 10.0)).clamp(0.1, 1.0); - let max_emit = ((reranked.len() as f32) * keep_ratio).round().max(1.0) as usize; + let candidates: Vec> = candidates.collect(); + let distances: Vec = candidates.iter().map(|c| c.distance).collect(); + let (ranked_indices, max_emit) = + rank_and_limit_by_distance(&distances, self.power, self.eta); let mut count = 0; - for (candidate, _) in reranked.into_iter().take(max_emit) { + for rank in ranked_indices.into_iter().take(max_emit) { + let candidate = &candidates[rank]; let state = output.push(candidate.id, candidate.distance); count += 1; if !state.is_available() { diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 04077e4a1..8f6ed151a 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::{ - inputs::{self, as_input, save_and_load, Example}, + inputs::{self, as_input, post_processor::TopkPostProcessor, save_and_load, Example}, utils::SimilarityMeasure, }; @@ -130,49 +130,6 @@ pub(crate) struct TopkSearchPhase { pub(crate) post_processor: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "kebab-case")] -pub(crate) enum TopkPostProcessor { - DeterminantDiversity { - #[serde(default = "default_det_div_power")] - power: f32, - #[serde(default = "default_det_div_eta")] - eta: f32, - }, -} - -const fn default_det_div_power() -> f32 { - 2.0 -} - -const fn default_det_div_eta() -> f32 { - 0.01 -} - -impl CheckDeserialization for TopkPostProcessor { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { - match self { - TopkPostProcessor::DeterminantDiversity { power, eta } => { - if *power <= 0.0 { - return Err(anyhow::anyhow!( - "determinant-diversity power must be > 0.0, got: {}", - power - )); - } - - if *eta < 0.0 { - return Err(anyhow::anyhow!( - "determinant-diversity eta must be >= 0.0, got: {}", - eta - )); - } - - Ok(()) - } - } - } -} - impl CheckDeserialization for TopkSearchPhase { fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { // Check the validity of the input files. diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 2951d1fe4..43ab2c9df 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -15,7 +15,7 @@ use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, ge use serde::{Deserialize, Serialize}; use crate::{ - inputs::{as_input, Example}, + inputs::{as_input, post_processor::TopkPostProcessor, Example}, utils::SimilarityMeasure, }; @@ -85,6 +85,8 @@ pub(crate) struct DiskSearchPhase { pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, pub(crate) search_io_limit: Option, + #[serde(default)] + pub(crate) post_processor: Option, } ///////// @@ -234,6 +236,12 @@ impl CheckDeserialization for DiskSearchPhase { anyhow::bail!("search_io_limit must be positive if specified"); } } + + if let Some(pp) = self.post_processor.as_mut() { + pp.check_deserialization(checker) + .context("invalid disk search post processor")?; + } + Ok(()) } } @@ -272,6 +280,7 @@ impl Example for DiskIndexOperation { vector_filters_file: None, num_nodes_to_cache: None, search_io_limit: None, + post_processor: None, }; Self { @@ -397,6 +406,14 @@ impl DiskSearchPhase { Some(lim) => write_field!(f, "Search IO Limit", format!("{lim}"))?, None => write_field!(f, "Search IO Limit", "none (defaults to `usize::MAX`)")?, } + match &self.post_processor { + Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => { + write_field!(f, "Post Processor", "determinant-diversity")?; + write_field!(f, "DetDiv Power", power)?; + write_field!(f, "DetDiv Eta", eta)?; + } + None => write_field!(f, "Post Processor", "none")?, + } Ok(()) } } diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index f5f6c015a..0733fa266 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -7,6 +7,7 @@ pub(crate) mod async_; pub(crate) mod disk; pub(crate) mod exhaustive; pub(crate) mod filters; +pub(crate) mod post_processor; pub(crate) mod save_and_load; pub(crate) fn register_inputs( diff --git a/diskann-benchmark/src/inputs/post_processor.rs b/diskann-benchmark/src/inputs/post_processor.rs new file mode 100644 index 000000000..3958180dd --- /dev/null +++ b/diskann-benchmark/src/inputs/post_processor.rs @@ -0,0 +1,48 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_benchmark_runner::{CheckDeserialization, Checker}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub(crate) enum TopkPostProcessor { + DeterminantDiversity { + #[serde(default = "default_det_div_power")] + power: f32, + #[serde(default = "default_det_div_eta")] + eta: f32, + }, +} + +const fn default_det_div_power() -> f32 { + 2.0 +} + +const fn default_det_div_eta() -> f32 { + 0.01 +} + +impl CheckDeserialization for TopkPostProcessor { + fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { + match self { + TopkPostProcessor::DeterminantDiversity { power, eta } => { + if *power <= 0.0 { + anyhow::bail!( + "determinant-diversity power must be > 0.0, got: {}", + power + ); + } + if *eta < 0.0 { + anyhow::bail!( + "determinant-diversity eta must be >= 0.0, got: {}", + eta + ); + } + Ok(()) + } + } + } +} diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index efb9bf697..cb38cbd1f 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1092,6 +1092,7 @@ pub(crate) mod disk_index_builder_tests { &mut indices, &mut distances, &mut associated_data, + None, &|_| true, false, ); diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 1344605f4..981b4ac4e 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -273,12 +273,66 @@ pub struct RerankAndFilter<'a> { filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), } +#[derive(Clone, Copy)] +pub struct DeterminantDiversityAndFilter<'a> { + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + power: f32, + eta: f32, +} + +#[derive(Clone, Copy)] +pub enum SearchPostProcessorKind { + RerankAndFilter, + DeterminantDiversity { power: f32, eta: f32 }, +} + +#[derive(Clone, Copy)] +pub enum DiskSearchPostProcessor<'a> { + RerankAndFilter(RerankAndFilter<'a>), + DeterminantDiversity(DeterminantDiversityAndFilter<'a>), +} + impl<'a> RerankAndFilter<'a> { - fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self { + pub fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self { Self { filter } } } +impl<'a> DeterminantDiversityAndFilter<'a> { + pub fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), power: f32, eta: f32) -> Self { + Self { filter, power, eta } + } +} + +fn rank_and_limit_by_distance(distances: &[f32], power: f32, eta: f32) -> (Vec, usize) { + let mut ranked: Vec<(usize, f32)> = distances + .iter() + .copied() + .enumerate() + .map(|(rank, distance)| { + let transformed = distance.abs().powf(power) + (rank as f32) * eta; + (rank, -transformed) + }) + .collect(); + + ranked.sort_by(|a, b| { + b.1.partial_cmp(&a.1) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let ranked_indices: Vec = ranked.into_iter().map(|(rank, _)| rank).collect(); + if ranked_indices.is_empty() { + return (ranked_indices, 0); + } + + let keep_ratio = (1.0 / (1.0 + power * eta * 10.0)).clamp(0.1, 1.0); + let max_emit = ((ranked_indices.len() as f32) * keep_ratio) + .round() + .max(1.0) as usize; + + (ranked_indices, max_emit) +} + impl SearchPostProcess< DiskAccessor<'_, Data, VP>, @@ -340,6 +394,115 @@ where } } +impl + SearchPostProcess< + DiskAccessor<'_, Data, VP>, + &[Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DeterminantDiversityAndFilter<'_> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type Error = ANNError; + async fn post_process( + &self, + accessor: &mut DiskAccessor<'_, Data, VP>, + query: &[Data::VectorDataType], + _computer: &DiskQueryComputer, + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, + { + let provider = accessor.provider; + + let mut uncached_ids = Vec::new(); + let mut reranked = candidates + .map(|n| n.id) + .filter(|id| (self.filter)(id)) + .filter_map(|n| { + if let Some(entry) = accessor.scratch.distance_cache.get(&n) { + Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0))) + } else { + uncached_ids.push(n); + None + } + }) + .collect::, _>>()?; + if !uncached_ids.is_empty() { + ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?; + for n in &uncached_ids { + let v = accessor.scratch.vertex_provider.get_vector(n)?; + let d = provider.distance_comparer.evaluate_similarity(query, v); + let a = accessor.scratch.vertex_provider.get_associated_data(n)?; + reranked.push(((*n, *a), d)); + } + } + + reranked + .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + + let distances: Vec = reranked.iter().map(|item| item.1).collect(); + let (ranked_indices, max_emit) = rank_and_limit_by_distance(&distances, self.power, self.eta); + let selected: Vec<_> = ranked_indices + .into_iter() + .take(max_emit) + .map(|rank| reranked[rank]) + .collect(); + + Ok(output.extend(selected)) + } +} + +impl + SearchPostProcess< + DiskAccessor<'_, Data, VP>, + &[Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DiskSearchPostProcessor<'_> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type Error = ANNError; + async fn post_process( + &self, + accessor: &mut DiskAccessor<'_, Data, VP>, + query: &[Data::VectorDataType], + computer: &DiskQueryComputer, + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, + { + match self { + DiskSearchPostProcessor::RerankAndFilter(pp) => { + pp.post_process(accessor, query, computer, candidates, output) + .await + } + DiskSearchPostProcessor::DeterminantDiversity(pp) => { + pp.post_process(accessor, query, computer, candidates, output) + .await + } + } + } +} + impl<'this, Data, ProviderFactory> SearchStrategy, &[Data::VectorDataType]> for DiskSearchStrategy<'this, Data, ProviderFactory> where @@ -917,6 +1080,7 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. + #[allow(clippy::too_many_arguments)] pub fn search( &self, query: &[Data::VectorDataType], @@ -924,6 +1088,7 @@ where search_list_size: u32, beam_width: Option, vector_filter: Option>, + post_processor: Option, is_flat_search: bool, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); @@ -932,6 +1097,18 @@ where let mut associated_data = vec![Data::AssociatedDataType::default(); return_list_size as usize]; + let vector_filter = vector_filter.unwrap_or(default_vector_filter::()); + let post_processor = post_processor.map(|processor| match processor { + SearchPostProcessorKind::RerankAndFilter => { + DiskSearchPostProcessor::RerankAndFilter(RerankAndFilter::new(vector_filter.as_ref())) + } + SearchPostProcessorKind::DeterminantDiversity { power, eta } => { + DiskSearchPostProcessor::DeterminantDiversity( + DeterminantDiversityAndFilter::new(vector_filter.as_ref(), power, eta), + ) + } + }); + let stats = self.search_internal( query, return_list_size as usize, @@ -941,7 +1118,8 @@ where &mut indices, &mut distances, &mut associated_data, - &vector_filter.unwrap_or(default_vector_filter::()), + post_processor, + vector_filter.as_ref(), is_flat_search, )?; @@ -968,7 +1146,7 @@ where /// Perform a raw search on the disk index. /// This is a lower-level API that allows more control over the search parameters and output buffers. #[allow(clippy::too_many_arguments)] - pub(crate) fn search_internal( + pub fn search_internal( &self, query: &[Data::VectorDataType], k_value: usize, @@ -978,6 +1156,7 @@ where indices: &mut [u32], distances: &mut [f32], associated_data: &mut [Data::AssociatedDataType], + post_processor: Option>, vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), is_flat_search: bool, ) -> ANNResult { @@ -1000,10 +1179,18 @@ where &Knn::new(k, l, beam_width)?, &mut result_output_buffer, ))? + } else if let Some(processor) = post_processor { + self.runtime.block_on(self.index.search_with( + Knn::new(k, l, beam_width)?, + &strategy, + processor, + &DefaultContext, + strategy.query, + &mut result_output_buffer, + ))? } else { - let knn_search = Knn::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( - knn_search, + Knn::new(k, l, beam_width)?, &strategy, &DefaultContext, strategy.query, @@ -1400,6 +1587,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &(|_| true), false, ); @@ -1448,7 +1636,15 @@ mod disk_provider_tests { .for_each_in_pool(pool.as_ref(), |(i, query)| { let result = params .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, None, false) + .search( + query, + params.k as u32, + params.l as u32, + beam_width, + None, + None, + false, + ) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1558,6 +1754,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &|_| true, false, ); @@ -1628,6 +1825,7 @@ mod disk_provider_tests { search_list_size, Some(4), None, + None, false, ); assert!(result.is_ok(), "Expected search to succeed"); @@ -1966,6 +2164,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &vector_filter, is_flat_search, ); @@ -1988,6 +2187,7 @@ mod disk_provider_tests { 10, None, // beam_width Some(Box::new(vector_filter)), + None, is_flat_search, ); diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index a0a91fde2..81aabb902 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -259,6 +259,7 @@ where l, Some(parameters.beam_width as usize), Some(vector_filter_function), + None, parameters.is_flat_search, ); From fbd34fdd8034615afd95729a0cddb903502fed63 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 4 May 2026 18:08:57 +0530 Subject: [PATCH 10/38] Restrict determinant-diversity to async full-precision topk --- diskann-benchmark/src/inputs/async_.rs | 49 ++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 8f6ed151a..c4a8a7e4d 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -425,6 +425,14 @@ impl CheckDeserialization for SearchPhase { } } +fn has_topk_determinant_diversity(phase: &SearchPhase) -> bool { + phase + .as_topk() + .ok() + .and_then(|topk| topk.post_processor.as_ref()) + .is_some_and(|pp| matches!(pp, TopkPostProcessor::DeterminantDiversity { .. })) +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum SearchPhaseKind { Topk, @@ -778,6 +786,14 @@ impl CheckDeserialization for IndexOperation { self.source.check_deserialization(checker)?; self.search_phase.check_deserialization(checker)?; + if has_topk_determinant_diversity(&self.search_phase) + && !matches!(self.source.data_type(), DataType::Float32) + { + anyhow::bail!( + "topk determinant-diversity is supported only for full-precision f32 searches" + ); + } + Ok(()) } } @@ -856,7 +872,15 @@ impl IndexPQOperation { impl CheckDeserialization for IndexPQOperation { fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.index_operation.check_deserialization(checker) + self.index_operation.check_deserialization(checker)?; + + if has_topk_determinant_diversity(&self.index_operation.search_phase) { + anyhow::bail!( + "topk determinant-diversity is supported only for async full-precision topk, not async-index-build-pq" + ); + } + + Ok(()) } } @@ -942,7 +966,15 @@ impl CheckDeserialization for IndexSQOperation { )); } - self.index_operation.check_deserialization(checker) + self.index_operation.check_deserialization(checker)?; + + if has_topk_determinant_diversity(&self.index_operation.search_phase) { + anyhow::bail!( + "topk determinant-diversity is supported only for async full-precision topk, not async-index-build-sq" + ); + } + + Ok(()) } } @@ -1021,6 +1053,12 @@ impl CheckDeserialization for SphericalQuantBuild { self.build.check_deserialization(checker)?; self.search_phase.check_deserialization(checker)?; + if has_topk_determinant_diversity(&self.search_phase) { + anyhow::bail!( + "topk determinant-diversity is supported only for async full-precision topk, not async-index-build-spherical-quantization" + ); + } + if self.build.save_path.is_some() { return Err(anyhow::anyhow!( "Spherical quantization does not support saving the index" @@ -1296,6 +1334,13 @@ impl CheckDeserialization for DynamicIndexRun { self.build.check_deserialization(checker)?; self.runbook_params.check_deserialization(checker)?; self.search_phase.check_deserialization(checker)?; + + if has_topk_determinant_diversity(&self.search_phase) { + anyhow::bail!( + "topk determinant-diversity is supported only for async full-precision topk, not async-dynamic-index-run" + ); + } + Ok(()) } } From 851174e180dffde5cb863df582d769754e28e132 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 4 May 2026 18:12:19 +0530 Subject: [PATCH 11/38] Keep single async determinant-diversity example JSON --- .../async-determinant-diversity-defaults.json | 49 ------------------ .../async-determinant-diversity-eta1.json | 51 ------------------- .../async-determinant-diversity-none.json | 46 ----------------- 3 files changed, 146 deletions(-) delete mode 100644 diskann-benchmark/example/async-determinant-diversity-defaults.json delete mode 100644 diskann-benchmark/example/async-determinant-diversity-eta1.json delete mode 100644 diskann-benchmark/example/async-determinant-diversity-none.json diff --git a/diskann-benchmark/example/async-determinant-diversity-defaults.json b/diskann-benchmark/example/async-determinant-diversity-defaults.json deleted file mode 100644 index f60e22769..000000000 --- a/diskann-benchmark/example/async-determinant-diversity-defaults.json +++ /dev/null @@ -1,49 +0,0 @@ -{ - "search_directories": [ - "test_data/disk_index_search" - ], - "jobs": [ - { - "type": "async-index-build", - "content": { - "source": { - "index-source": "Build", - "data_type": "float32", - "data": "disk_index_siftsmall_learn_256pts_data.fbin", - "distance": "squared_l2", - "max_degree": 32, - "l_build": 50, - "alpha": 1.2, - "backedge_ratio": 1.0, - "num_threads": 1, - "start_point_strategy": "medoid", - "num_insert_attempts": 1, - "saturate_inserts": false - }, - "search_phase": { - "search-type": "topk", - "queries": "disk_index_sample_query_10pts.fbin", - "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", - "reps": 5, - "num_threads": [ - 1 - ], - "post_processor": { - "type": "determinant-diversity" - }, - "runs": [ - { - "search_n": 20, - "search_l": [ - 20, - 30, - 40 - ], - "recall_k": 10 - } - ] - } - } - } - ] -} diff --git a/diskann-benchmark/example/async-determinant-diversity-eta1.json b/diskann-benchmark/example/async-determinant-diversity-eta1.json deleted file mode 100644 index ffd31c756..000000000 --- a/diskann-benchmark/example/async-determinant-diversity-eta1.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "search_directories": [ - "test_data/disk_index_search" - ], - "jobs": [ - { - "type": "async-index-build", - "content": { - "source": { - "index-source": "Build", - "data_type": "float32", - "data": "disk_index_siftsmall_learn_256pts_data.fbin", - "distance": "squared_l2", - "max_degree": 32, - "l_build": 50, - "alpha": 1.2, - "backedge_ratio": 1.0, - "num_threads": 1, - "start_point_strategy": "medoid", - "num_insert_attempts": 1, - "saturate_inserts": false - }, - "search_phase": { - "search-type": "topk", - "queries": "disk_index_sample_query_10pts.fbin", - "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", - "reps": 5, - "num_threads": [ - 1 - ], - "post_processor": { - "type": "determinant-diversity", - "power": 2.0, - "eta": 1.0 - }, - "runs": [ - { - "search_n": 20, - "search_l": [ - 20, - 30, - 40 - ], - "recall_k": 10 - } - ] - } - } - } - ] -} diff --git a/diskann-benchmark/example/async-determinant-diversity-none.json b/diskann-benchmark/example/async-determinant-diversity-none.json deleted file mode 100644 index 820650ee9..000000000 --- a/diskann-benchmark/example/async-determinant-diversity-none.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "search_directories": [ - "test_data/disk_index_search" - ], - "jobs": [ - { - "type": "async-index-build", - "content": { - "source": { - "index-source": "Build", - "data_type": "float32", - "data": "disk_index_siftsmall_learn_256pts_data.fbin", - "distance": "squared_l2", - "max_degree": 32, - "l_build": 50, - "alpha": 1.2, - "backedge_ratio": 1.0, - "num_threads": 1, - "start_point_strategy": "medoid", - "num_insert_attempts": 1, - "saturate_inserts": false - }, - "search_phase": { - "search-type": "topk", - "queries": "disk_index_sample_query_10pts.fbin", - "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", - "reps": 5, - "num_threads": [ - 1 - ], - "runs": [ - { - "search_n": 20, - "search_l": [ - 20, - 30, - 40 - ], - "recall_k": 10 - } - ] - } - } - } - ] -} From ed0c9187a518cc52f1bdf7826d3616f87054c963 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Tue, 5 May 2026 12:41:11 +0530 Subject: [PATCH 12/38] Improve plugin matching resilience via phase-shape helpers Replace kind()-based string equality checks with explicit is_match() and get() phase-shape helpers on plugin structs. This avoids fragile ordering assumptions and makes each plugin responsible for recognising its own phase shape. --- .../src/backend/index/benchmarks.rs | 48 ++++------------- .../src/backend/index/search/plugins.rs | 53 +++++++++++++------ 2 files changed, 47 insertions(+), 54 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 443aa66c8..6348633db 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -46,10 +46,7 @@ use crate::{ search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, - inputs::{ - async_::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase}, - post_processor::TopkPostProcessor, - }, + inputs::async_::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase}, utils::{ self, datafiles::{self}, @@ -465,15 +462,7 @@ where >, { fn is_match(&self, phase: &SearchPhase) -> bool { - if Self::kind() != phase.kind() { - return false; - } - - phase - .as_topk() - .ok() - .and_then(|topk| topk.post_processor.as_ref()) - .is_some_and(|pp| matches!(pp, TopkPostProcessor::DeterminantDiversity { .. })) + plugins::DeterminantDiversity::is_match(phase) } fn kind(&self) -> &'static str { @@ -486,15 +475,7 @@ where phase: &SearchPhase, _strategy: &Strategy, ) -> anyhow::Result { - let topk = phase.as_topk()?; - let (power, eta) = match topk.post_processor.as_ref() { - Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => (*power, *eta), - _ => { - return Err(anyhow::anyhow!( - "determinant-diversity plugin selected for non determinant-diversity input", - )); - } - }; + let (topk, power, eta) = plugins::DeterminantDiversity::get(phase)?; let strategy = common::FullPrecision; let context = DefaultContext; @@ -580,18 +561,11 @@ where S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - if Self::kind() != phase.kind() { - return false; - } - - phase - .as_topk() - .ok() - .is_some_and(|topk| topk.post_processor.is_none()) + plugins::Topk::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + "topk" } fn run( @@ -630,11 +604,11 @@ where S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::Range::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + "range" } fn run( @@ -674,11 +648,11 @@ where S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::BetaFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + "topk-beta-filter" } fn run( @@ -733,11 +707,11 @@ where S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::MultihopFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + "topk-multihop-filter" } fn run( diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index dc1c0f678..adc7b252e 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -37,8 +37,13 @@ use std::sync::Arc; use diskann::{graph::DiskANNIndex, provider::DataProvider}; use diskann_benchmark_runner::utils::fmt::{Delimit, Quote}; -use crate::{backend::index::result::AggregatedSearchResults, inputs::async_::SearchPhaseKind}; - +use crate::{ + backend::index::result::AggregatedSearchResults, + inputs::{ + async_::{SearchPhase, TopkSearchPhase}, + post_processor::TopkPostProcessor, + }, +}; /// A dyn-compatible search plugin for `DP`. /// /// `Kind` is the matching surface used for benchmark selection and diagnostics. `Params` @@ -139,9 +144,11 @@ where pub(crate) struct Topk; impl Topk { - /// Returns [`SearchPhaseKind::Topk`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Topk + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase + .as_topk() + .ok() + .is_some_and(|topk| topk.post_processor.is_none()) } } @@ -150,9 +157,24 @@ impl Topk { pub(crate) struct DeterminantDiversity; impl DeterminantDiversity { - /// Returns [`SearchPhaseKind::Topk`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Topk + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase + .as_topk() + .ok() + .and_then(|topk| topk.post_processor.as_ref()) + .is_some_and(|pp| matches!(pp, TopkPostProcessor::DeterminantDiversity { .. })) + } + + pub(crate) fn get(phase: &SearchPhase) -> anyhow::Result<(&TopkSearchPhase, f32, f32)> { + let topk = phase.as_topk()?; + match topk.post_processor.as_ref() { + Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => { + Ok((topk, *power, *eta)) + } + _ => Err(anyhow::anyhow!( + "determinant-diversity plugin selected for non determinant-diversity input", + )), + } } } @@ -161,9 +183,8 @@ impl DeterminantDiversity { pub(crate) struct Range; impl Range { - /// Returns [`SearchPhaseKind::Range`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Range + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_range().is_ok() } } @@ -172,9 +193,8 @@ impl Range { pub(crate) struct BetaFilter; impl BetaFilter { - /// Returns [`SearchPhaseKind::TopkBetaFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkBetaFilter + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk_beta_filter().is_ok() } } @@ -183,8 +203,7 @@ impl BetaFilter { pub(crate) struct MultihopFilter; impl MultihopFilter { - /// Returns [`SearchPhaseKind::TopkMultihopFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkMultihopFilter + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk_multihop_filter().is_ok() } } From 3a6aa1af7a18205afa907d1fa035fed5526a0dd5 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Tue, 5 May 2026 12:44:58 +0530 Subject: [PATCH 13/38] Use SearchPhaseKind::as_str in benchmark plugin kinds --- diskann-benchmark/src/backend/index/benchmarks.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 6348633db..3a2d2ec53 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -46,7 +46,7 @@ use crate::{ search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, - inputs::async_::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase}, + inputs::async_::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase, SearchPhaseKind}, utils::{ self, datafiles::{self}, @@ -565,7 +565,7 @@ where } fn kind(&self) -> &'static str { - "topk" + SearchPhaseKind::Topk.as_str() } fn run( @@ -608,7 +608,7 @@ where } fn kind(&self) -> &'static str { - "range" + SearchPhaseKind::Range.as_str() } fn run( @@ -652,7 +652,7 @@ where } fn kind(&self) -> &'static str { - "topk-beta-filter" + SearchPhaseKind::TopkBetaFilter.as_str() } fn run( @@ -711,7 +711,7 @@ where } fn kind(&self) -> &'static str { - "topk-multihop-filter" + SearchPhaseKind::TopkMultihopFilter.as_str() } fn run( From ccfe4d73356208f1f3bb649513560a8bed563655 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Tue, 5 May 2026 12:56:29 +0530 Subject: [PATCH 14/38] remove serde defaults Co-authored-by: Copilot --- diskann-benchmark/src/inputs/async_.rs | 1 - diskann-benchmark/src/inputs/disk.rs | 1 - diskann-benchmark/src/inputs/post_processor.rs | 10 ---------- 3 files changed, 12 deletions(-) diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index c4a8a7e4d..2512ff885 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -126,7 +126,6 @@ pub(crate) struct TopkSearchPhase { // Enable sweeping threads pub(crate) num_threads: Vec, pub(crate) runs: Vec, - #[serde(default)] pub(crate) post_processor: Option, } diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 43ab2c9df..339fb12c0 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -85,7 +85,6 @@ pub(crate) struct DiskSearchPhase { pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, pub(crate) search_io_limit: Option, - #[serde(default)] pub(crate) post_processor: Option, } diff --git a/diskann-benchmark/src/inputs/post_processor.rs b/diskann-benchmark/src/inputs/post_processor.rs index 3958180dd..a4b7bf96d 100644 --- a/diskann-benchmark/src/inputs/post_processor.rs +++ b/diskann-benchmark/src/inputs/post_processor.rs @@ -10,21 +10,11 @@ use serde::{Deserialize, Serialize}; #[serde(tag = "type", rename_all = "kebab-case")] pub(crate) enum TopkPostProcessor { DeterminantDiversity { - #[serde(default = "default_det_div_power")] power: f32, - #[serde(default = "default_det_div_eta")] eta: f32, }, } -const fn default_det_div_power() -> f32 { - 2.0 -} - -const fn default_det_div_eta() -> f32 { - 0.01 -} - impl CheckDeserialization for TopkPostProcessor { fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { match self { From 86883b71b53f90400205995fae776b19409db3fb Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Tue, 5 May 2026 13:26:41 +0530 Subject: [PATCH 15/38] minor merge fix --- .../src/backend/index/search/plugins.rs | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 2dc7b6d39..66ea4e38a 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -40,7 +40,7 @@ use diskann_benchmark_runner::utils::fmt::{Delimit, Quote}; use crate::{ backend::index::result::AggregatedSearchResults, inputs::{ - graph_index::{SearchPhase, SearchPhaseKind, TopkSearchPhase}, + graph_index::{SearchPhase, TopkSearchPhase}, post_processor::TopkPostProcessor, }, }; @@ -149,11 +149,6 @@ where pub(crate) struct Topk; impl Topk { - /// Returns [`SearchPhaseKind::Topk`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Topk - } - pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase .as_topk() @@ -193,11 +188,6 @@ impl DeterminantDiversity { pub(crate) struct Range; impl Range { - /// Returns [`SearchPhaseKind::Range`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Range - } - pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase.as_range().is_ok() } @@ -208,11 +198,6 @@ impl Range { pub(crate) struct TopkBetaFilter; impl TopkBetaFilter { - /// Returns [`SearchPhaseKind::TopkBetaFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkBetaFilter - } - pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase.as_topk_beta_filter().is_ok() } @@ -223,11 +208,6 @@ impl TopkBetaFilter { pub(crate) struct TopkMultihopFilter; impl TopkMultihopFilter { - /// Returns [`SearchPhaseKind::TopkMultihopFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkMultihopFilter - } - pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase.as_topk_multihop_filter().is_ok() } From 554bc7f2b2609b32ec8316d9c1870ed5893f9464 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Tue, 5 May 2026 23:51:09 +0530 Subject: [PATCH 16/38] hook up actual algorithm, replace placeholder. Co-authored-by: Copilot --- .../post_processor/determinant_diversity.rs | 65 ++--- .../src/search/provider/disk_provider.rs | 98 +++---- .../determinant_diversity_post_process.rs | 262 ++++++++++++++++++ .../src/model/graph/provider/async_/mod.rs | 2 + 4 files changed, 318 insertions(+), 109 deletions(-) create mode 100644 diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs index 9a56f0d39..1b5d014ad 100644 --- a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -3,13 +3,14 @@ * Licensed under the MIT license. */ +use diskann::graph::search_output_buffer::SearchOutputBuffer; use diskann::{ error::ANNError, graph::glue, neighbor::Neighbor, provider::Accessor, }; -use diskann::graph::search_output_buffer::SearchOutputBuffer; +use diskann_providers::model::graph::provider::async_::determinant_diversity_post_process; #[derive(Debug, Clone, Copy)] pub(crate) struct DeterminantDiversity { @@ -23,39 +24,6 @@ impl DeterminantDiversity { } } -pub(crate) fn rank_and_limit_by_distance( - distances: &[f32], - power: f32, - eta: f32, -) -> (Vec, usize) { - let mut ranked: Vec<(usize, f32)> = distances - .iter() - .copied() - .enumerate() - .map(|(rank, distance)| { - let transformed = distance.abs().powf(power) + (rank as f32) * eta; - (rank, -transformed) - }) - .collect(); - - ranked.sort_by(|a, b| { - b.1.partial_cmp(&a.1) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - let ranked_indices: Vec = ranked.into_iter().map(|(rank, _)| rank).collect(); - if ranked_indices.is_empty() { - return (ranked_indices, 0); - } - - let keep_ratio = (1.0 / (1.0 + power * eta * 10.0)).clamp(0.1, 1.0); - let max_emit = ((ranked_indices.len() as f32) * keep_ratio) - .round() - .max(1.0) as usize; - - (ranked_indices, max_emit) -} - impl glue::SearchPostProcess for DeterminantDiversity where A: Accessor + diskann::provider::BuildQueryComputer + Send, @@ -76,20 +44,19 @@ where B: SearchOutputBuffer + Send + ?Sized, { let candidates: Vec> = candidates.collect(); - let distances: Vec = candidates.iter().map(|c| c.distance).collect(); - let (ranked_indices, max_emit) = - rank_and_limit_by_distance(&distances, self.power, self.eta); - - let mut count = 0; - for rank in ranked_indices.into_iter().take(max_emit) { - let candidate = &candidates[rank]; - let state = output.push(candidate.id, candidate.distance); - count += 1; - if !state.is_available() { - break; - } - } - - Ok(count) + let embedded: Vec<_> = candidates + .iter() + .map(|c| (c.id, c.distance, vec![c.distance])) + .collect(); + + let reranked = determinant_diversity_post_process( + embedded, + &[0.0], + candidates.len(), + self.eta, + self.power, + ); + + Ok(output.extend(reranked)) } } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 981b4ac4e..f828268bd 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -36,7 +36,10 @@ use diskann::{ }; use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ - model::{compute_pq_distance, compute_pq_distance_for_pq_coordinates}, + model::{ + compute_pq_distance, compute_pq_distance_for_pq_coordinates, + graph::provider::async_::determinant_diversity_post_process, + }, storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith}, }; use diskann_utils::object_pool::{ObjectPool, PoolOption, TryAsPooled}; @@ -304,35 +307,6 @@ impl<'a> DeterminantDiversityAndFilter<'a> { } } -fn rank_and_limit_by_distance(distances: &[f32], power: f32, eta: f32) -> (Vec, usize) { - let mut ranked: Vec<(usize, f32)> = distances - .iter() - .copied() - .enumerate() - .map(|(rank, distance)| { - let transformed = distance.abs().powf(power) + (rank as f32) * eta; - (rank, -transformed) - }) - .collect(); - - ranked.sort_by(|a, b| { - b.1.partial_cmp(&a.1) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - let ranked_indices: Vec = ranked.into_iter().map(|(rank, _)| rank).collect(); - if ranked_indices.is_empty() { - return (ranked_indices, 0); - } - - let keep_ratio = (1.0 / (1.0 + power * eta * 10.0)).clamp(0.1, 1.0); - let max_emit = ((ranked_indices.len() as f32) * keep_ratio) - .round() - .max(1.0) as usize; - - (ranked_indices, max_emit) -} - impl SearchPostProcess< DiskAccessor<'_, Data, VP>, @@ -423,42 +397,46 @@ where + ?Sized, { let provider = accessor.provider; + let query_f32 = Data::VectorDataType::as_f32(query).map_err(Into::into)?; - let mut uncached_ids = Vec::new(); - let mut reranked = candidates - .map(|n| n.id) + let candidate_ids: Vec = candidates + .map(|candidate| candidate.id) .filter(|id| (self.filter)(id)) - .filter_map(|n| { - if let Some(entry) = accessor.scratch.distance_cache.get(&n) { - Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0))) - } else { - uncached_ids.push(n); - None - } - }) - .collect::, _>>()?; - if !uncached_ids.is_empty() { - ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?; - for n in &uncached_ids { - let v = accessor.scratch.vertex_provider.get_vector(n)?; - let d = provider.distance_comparer.evaluate_similarity(query, v); - let a = accessor.scratch.vertex_provider.get_associated_data(n)?; - reranked.push(((*n, *a), d)); - } + .collect(); + + if candidate_ids.is_empty() { + return Ok(0); } - reranked - .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &candidate_ids)?; - let distances: Vec = reranked.iter().map(|item| item.1).collect(); - let (ranked_indices, max_emit) = rank_and_limit_by_distance(&distances, self.power, self.eta); - let selected: Vec<_> = ranked_indices - .into_iter() - .take(max_emit) - .map(|rank| reranked[rank]) - .collect(); + let mut candidate_vectors = Vec::with_capacity(candidate_ids.len()); + let mut associated_data = HashMap::with_capacity(candidate_ids.len()); + + for id in candidate_ids { + let vector = accessor.scratch.vertex_provider.get_vector(&id)?; + let distance = provider.distance_comparer.evaluate_similarity(query, vector); + let vector_f32 = Data::VectorDataType::as_f32(vector).map_err(Into::into)?; + let data = accessor.scratch.vertex_provider.get_associated_data(&id)?; + + candidate_vectors.push((id, distance, vector_f32.to_vec())); + associated_data.insert(id, *data); + } + + let reranked = determinant_diversity_post_process( + candidate_vectors, + &query_f32, + usize::MAX, + self.eta, + self.power, + ); - Ok(output.extend(selected)) + Ok(output.extend(reranked.into_iter().filter_map(|(id, distance)| { + associated_data + .get(&id) + .copied() + .map(|data| ((id, data), distance)) + }))) } } diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs new file mode 100644 index 000000000..ba1ab9603 --- /dev/null +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -0,0 +1,262 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_vector::{ + MathematicalValue, PureDistanceFunction, + distance::InnerProduct, +}; + +pub fn determinant_diversity_post_process( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_eta: f32, + determinant_diversity_power: f32, +) -> Vec<(Id, f32)> { + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let candidates: Vec<_> = candidates + .into_iter() + .filter(|(_, _, vector)| vector.len() == query.len()) + .collect(); + + if candidates.is_empty() { + return Vec::new(); + } + + let k = k.min(candidates.len()); + if k == 0 { + return Vec::new(); + } + + if candidates[0].2.is_empty() { + return Vec::new(); + } + + let distance_range = { + let mut min_distance = f32::INFINITY; + let mut max_distance = f32::NEG_INFINITY; + + for (_, distance, _) in &candidates { + min_distance = min_distance.min(*distance); + max_distance = max_distance.max(*distance); + } + + (min_distance, max_distance) + }; + + if determinant_diversity_eta > 0.0 { + post_process_with_eta_f32( + candidates, + k, + determinant_diversity_eta, + determinant_diversity_power, + distance_range, + ) + } else { + post_process_greedy_orthogonalization_f32( + candidates, + k, + determinant_diversity_power, + distance_range, + ) + } +} + +fn post_process_with_eta_f32( + candidates: Vec<(Id, f32, Vec)>, + k: usize, + eta: f32, + power: f32, + distance_range: (f32, f32), +) -> Vec<(Id, f32)> { + let n = candidates.len(); + let k = k.min(n); + if k == 0 { + return Vec::new(); + } + + let inv_sqrt_eta = 1.0 / eta.sqrt(); + let mut residuals = Vec::with_capacity(n); + let mut norms_sq = Vec::with_capacity(n); + + for (_, distance_to_query, v) in &candidates { + let scale = distance_to_similarity(*distance_to_query, distance_range).powf(power) + * inv_sqrt_eta; + let residual: Vec = v.iter().map(|&x| x * scale).collect(); + let norm_sq = dot_product(&residual, &residual); + residuals.push(residual); + norms_sq.push(norm_sq); + } + + let mut available = vec![true; n]; + let mut selected = Vec::with_capacity(k); + let mut projections = vec![0.0f32; n]; + + for _ in 0..k { + let best_idx = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i); + + let Some(selected_index) = best_idx else { + break; + }; + + selected.push(selected_index); + available[selected_index] = false; + + if selected.len() == k { + break; + } + + let best_norm_sq = norms_sq[selected_index]; + if best_norm_sq <= 0.0 { + continue; + } + + let inv_norm_sq = 1.0 / best_norm_sq; + let r_star_copy = residuals[selected_index].clone(); + + for i in 0..n { + if !available[i] { + projections[i] = 0.0; + } else { + projections[i] = dot_product(&residuals[i], &r_star_copy) * inv_norm_sq; + } + } + + for i in 0..n { + if !available[i] { + continue; + } + + let projection = projections[i]; + for (residual, &star) in residuals[i].iter_mut().zip(r_star_copy.iter()) { + *residual -= projection * star; + } + + norms_sq[i] = (norms_sq[i] - projection * projection * best_norm_sq).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = &candidates[idx]; + (*id, *dist) + }) + .collect() +} + +fn post_process_greedy_orthogonalization_f32( + candidates: Vec<(Id, f32, Vec)>, + k: usize, + power: f32, + distance_range: (f32, f32), +) -> Vec<(Id, f32)> { + let n = candidates.len(); + let k = k.min(n); + if k == 0 { + return Vec::new(); + } + + let mut residuals = Vec::with_capacity(n); + let mut norms_sq = Vec::with_capacity(n); + + for (_, distance_to_query, v) in &candidates { + let scale = distance_to_similarity(*distance_to_query, distance_range).powf(power); + let residual: Vec = v.iter().map(|&x| x * scale).collect(); + let norm_sq = dot_product(&residual, &residual); + residuals.push(residual); + norms_sq.push(norm_sq); + } + + let mut available = vec![true; n]; + let mut selected = Vec::with_capacity(k); + let mut projections = vec![0.0f32; n]; + + for _ in 0..k { + let best = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let Some((best_index, _)) = best else { + break; + }; + + let best_norm_sq = norms_sq[best_index]; + selected.push(best_index); + available[best_index] = false; + + if selected.len() == k { + break; + } + + if best_norm_sq <= 0.0 { + continue; + } + + let inv_norm_sq_star = 1.0 / best_norm_sq; + let r_star_copy = residuals[best_index].clone(); + + for j in 0..n { + if !available[j] { + projections[j] = 0.0; + } else { + projections[j] = dot_product(&residuals[j], &r_star_copy) * inv_norm_sq_star; + } + } + + for j in 0..n { + if !available[j] { + continue; + } + + let projection = projections[j]; + for (residual, &star) in residuals[j].iter_mut().zip(r_star_copy.iter()) { + *residual -= projection * star; + } + + norms_sq[j] = (norms_sq[j] - projection * projection * best_norm_sq).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = &candidates[idx]; + (*id, *dist) + }) + .collect() +} + +fn distance_to_similarity(distance: f32, distance_range: (f32, f32)) -> f32 { + let (min_distance, max_distance) = distance_range; + let span = (max_distance - min_distance).max(f32::EPSILON); + + // Distances are lower-is-better in DiskANN distance semantics. + ((max_distance - distance) / span).max(0.0) + f32::EPSILON +} + +#[inline] +fn dot_product(a: &[f32], b: &[f32]) -> f32 { + >>::evaluate(a, b) + .into_inner() +} diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index cf719e730..816b95523 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -8,6 +8,8 @@ pub mod common; pub use common::{PrefetchCacheLineLevel, StartPoints, VectorGuard}; pub(crate) mod postprocess; +mod determinant_diversity_post_process; +pub use determinant_diversity_post_process::determinant_diversity_post_process; pub mod distances; From 8c59e6f6654ba69f15590b7e48906b7bcd833a1c Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 6 May 2026 00:17:47 +0530 Subject: [PATCH 17/38] WIP: Trait bound experiment for async determinant-diversity - HRTB projection issue This commit explores approaches to wire real candidate vectors into async determinant-diversity post-processing. Current state: IN COMPILATION ERROR (intentional for analysis) Attempted approaches: 1. Initial shim-trait FullPrecisionVectorAccessor with async get_full_precision_vector() - Resulted in 'implementation not general enough' at search_with() call 2. Removed explicit for<'a> post_processor::DeterminantDiversity bound - Still fails - the constraint is inherent in search_with() signature itself Root cause analysis: - search_with() requires: PP: for<'a> SearchPostProcess, T, O> - This means post-processor must work for ANY accessor lifetime 'a - But query = queries.row(query_idx) is borrowed for specific loop iteration lifetime - These are fundamentally incompatible - a borrowed value can't satisfy for<'a> generically Compiler errors (3 total): - 'not general enough': implementation needed for or<'a> but found specific '0 - 'does not live long enough': queries lifetime too short for 'static requirement Files modified: - diskann-benchmark/src/backend/index/benchmarks.rs: * Removed explicit for<'a> post_processor::DeterminantDiversity constraint * Narrowed plugin impl to FullPrecisionProvider - diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs: * Added shim trait FullPrecisionVectorAccessor * Async method get_full_precision_vector(&mut self, id) -> impl Future<...> Next steps to investigate: - Move determinant-diversity outside search_with() as post-processing reranking - This avoids HRTB entirely by applying after candidates are returned - Benchmark impact: measure recall/QPS with external reranking vs baseline Related context: - Disk index determinant-diversity works correctly (uses real vectors, shows 51-53% QPS cost) - Shared algorithm fixed (distance-to-similarity scoring direction) - Branch already merged with origin/main --- .../src/backend/index/benchmarks.rs | 18 +++--- .../post_processor/determinant_diversity.rs | 57 +++++++++++++++---- 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index f7e8e4e8c..ce3dd7165 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -451,16 +451,14 @@ impl Strategy { // Topk // //------// -impl search::Plugin> +impl search::Plugin, SearchPhase, Strategy> for plugins::DeterminantDiversity where - DP: DataProvider + QueryType, - common::FullPrecision: for<'a> glue::SearchStrategy, - for<'a> post_processor::DeterminantDiversity: glue::SearchPostProcess< - >::SearchAccessor<'a>, - &'a [DP::Element], - u32, - >, + common::FullPrecision: for<'a> glue::SearchStrategy, &'a [f32]>, + for<'a> , + &'a [f32], + >>::SearchAccessor<'a>: post_processor::determinant_diversity::FullPrecisionVectorAccessor, { fn is_match(&self, phase: &SearchPhase) -> bool { plugins::DeterminantDiversity::is_match(phase) @@ -472,7 +470,7 @@ where fn run( &self, - index: Arc>, + index: Arc>>, phase: &SearchPhase, _strategy: &Strategy, ) -> anyhow::Result { @@ -482,7 +480,7 @@ where let context = DefaultContext; let det_div = post_processor::DeterminantDiversity::new(power, eta); - let queries: Arc> = + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs index 1b5d014ad..38d1f529d 100644 --- a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -3,6 +3,8 @@ * Licensed under the MIT license. */ +use std::future::Future; + use diskann::graph::search_output_buffer::SearchOutputBuffer; use diskann::{ error::ANNError, @@ -10,7 +12,37 @@ use diskann::{ neighbor::Neighbor, provider::Accessor, }; -use diskann_providers::model::graph::provider::async_::determinant_diversity_post_process; +use diskann_providers::model::graph::provider::async_::{ + determinant_diversity_post_process, + inmem, +}; +use diskann_utils::future::AsyncFriendly; + +pub(crate) trait FullPrecisionVectorAccessor: Accessor + Send { + fn get_full_precision_vector( + &mut self, + id: Self::Id, + ) -> impl Future, ANNError>> + Send; +} + +impl FullPrecisionVectorAccessor for inmem::FullAccessor<'_, f32, Q, D, Ctx> +where + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: diskann::provider::ExecutionContext, +{ + fn get_full_precision_vector( + &mut self, + id: Self::Id, + ) -> impl Future, ANNError>> + Send { + async move { + self.get_element(id) + .await + .map(|vector| vector.to_vec()) + .map_err(Into::into) + } + } +} #[derive(Debug, Clone, Copy)] pub(crate) struct DeterminantDiversity { @@ -26,15 +58,15 @@ impl DeterminantDiversity { impl glue::SearchPostProcess for DeterminantDiversity where - A: Accessor + diskann::provider::BuildQueryComputer + Send, - T: Send + Sync, + A: FullPrecisionVectorAccessor + diskann::provider::BuildQueryComputer + Send, + T: AsRef<[f32]> + Send + Sync, { type Error = ANNError; async fn post_process( &self, - _accessor: &mut A, - _query: T, + accessor: &mut A, + query: T, _computer: &>::QueryComputer, candidates: I, output: &mut B, @@ -44,14 +76,19 @@ where B: SearchOutputBuffer + Send + ?Sized, { let candidates: Vec> = candidates.collect(); - let embedded: Vec<_> = candidates - .iter() - .map(|c| (c.id, c.distance, vec![c.distance])) - .collect(); + let mut embedded = Vec::with_capacity(candidates.len()); + + for candidate in &candidates { + embedded.push(( + candidate.id, + candidate.distance, + accessor.get_full_precision_vector(candidate.id).await?, + )); + } let reranked = determinant_diversity_post_process( embedded, - &[0.0], + query.as_ref(), candidates.len(), self.eta, self.power, From b73abc87e18af80f396d64074d350e84b1f48b8d Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 6 May 2026 08:55:38 +0530 Subject: [PATCH 18/38] apply mark's beautiful fix for lifetime issue Co-authored-by: Copilot --- .../example/async-determinant-diversity.json | 2 +- .../src/backend/index/benchmarks.rs | 114 ++++++++++++------ .../post_processor/determinant_diversity.rs | 2 + 3 files changed, 81 insertions(+), 37 deletions(-) diff --git a/diskann-benchmark/example/async-determinant-diversity.json b/diskann-benchmark/example/async-determinant-diversity.json index a8c1cc86f..acb4260ea 100644 --- a/diskann-benchmark/example/async-determinant-diversity.json +++ b/diskann-benchmark/example/async-determinant-diversity.json @@ -4,7 +4,7 @@ ], "jobs": [ { - "type": "async-index-build", + "type": "graph-index-build", "content": { "source": { "index-source": "Build", diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index ce3dd7165..ee0a766e1 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use std::{io::Write, num::NonZeroUsize, sync::Arc}; +use std::{io::Write, num::NonZeroUsize, sync::Arc, time::Instant}; use diskann::{ graph::SampleableForStart, @@ -18,7 +18,7 @@ use diskann_benchmark_core::{ use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, - utils::datatype, + utils::{datatype, MicroSeconds}, Benchmark, Checkpoint, }; use diskann_providers::{ @@ -454,11 +454,11 @@ impl Strategy { impl search::Plugin, SearchPhase, Strategy> for plugins::DeterminantDiversity where - common::FullPrecision: for<'a> glue::SearchStrategy, &'a [f32]>, - for<'a> glue::SearchStrategy< FullPrecisionProvider, &'a [f32], - >>::SearchAccessor<'a>: post_processor::determinant_diversity::FullPrecisionVectorAccessor, + SearchAccessor<'b>: post_processor::determinant_diversity::FullPrecisionVectorAccessor, + >, { fn is_match(&self, phase: &SearchPhase) -> bool { plugins::DeterminantDiversity::is_match(phase) @@ -494,47 +494,89 @@ where let mut all_recalls = Vec::new(); - for query_idx in 0..queries.nrows() { - let query = queries.row(query_idx); - let mut output: Vec> = Vec::new(); - utils::tokio::block_on(async { - index - .search_with( - knn_params, - &strategy, - det_div, - &context, - query, - &mut output, - ) - .await - })?; - - let gt = groundtruth.row(query_idx); - let mut matches = 0; - for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { - if i >= gt.len() { - break; - } - if gt.contains(&neighbor.id) { - matches += 1; + let mut qps = Vec::with_capacity(topk.reps.get()); + let mut search_latencies = Vec::with_capacity(topk.reps.get()); + let mut mean_latencies = Vec::with_capacity(topk.reps.get()); + let mut p90_latencies = Vec::with_capacity(topk.reps.get()); + let mut p99_latencies = Vec::with_capacity(topk.reps.get()); + + for _ in 0..topk.reps.get() { + let search_start = Instant::now(); + let mut per_query_latencies = Vec::with_capacity(queries.nrows()); + + for query_idx in 0..queries.nrows() { + let query = queries.row(query_idx); + let mut output: Vec> = Vec::new(); + + let query_start = Instant::now(); + utils::tokio::block_on(async { + index + .search_with( + knn_params, + &strategy, + det_div, + &context, + query, + &mut output, + ) + .await + })?; + per_query_latencies.push(query_start.elapsed().as_micros() as u64); + + let gt = groundtruth.row(query_idx); + let mut matches = 0; + for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { + if i >= gt.len() { + break; + } + if gt.contains(&neighbor.id) { + matches += 1; + } } + all_recalls.push(matches); + } + + let elapsed: MicroSeconds = search_start.elapsed().into(); + let elapsed_secs = elapsed.as_seconds(); + if elapsed_secs > 0.0 { + qps.push(queries.nrows() as f64 / elapsed_secs); + } else { + qps.push(0.0); } - all_recalls.push(matches); + + per_query_latencies.sort_unstable(); + let len = per_query_latencies.len(); + let p90_idx = ((len as f64 * 0.90).ceil() as usize) + .saturating_sub(1) + .min(len.saturating_sub(1)); + let p99_idx = ((len as f64 * 0.99).ceil() as usize) + .saturating_sub(1) + .min(len.saturating_sub(1)); + + let mean = if len > 0 { + per_query_latencies.iter().sum::() as f64 / len as f64 + } else { + 0.0 + }; + + search_latencies.push(elapsed); + mean_latencies.push(mean); + p90_latencies.push(MicroSeconds::new(*per_query_latencies.get(p90_idx).unwrap_or(&0))); + p99_latencies.push(MicroSeconds::new(*per_query_latencies.get(p99_idx).unwrap_or(&0))); } let avg_recall = all_recalls.iter().sum::() as f32 - / (queries.nrows() * run.recall_k) as f32; + / (queries.nrows() * run.recall_k * topk.reps.get()) as f32; all_results.push(SearchResults { num_tasks: threads.get(), search_n: run.search_n, search_l: *search_l, - qps: vec![], - search_latencies: vec![], - mean_latencies: vec![], - p90_latencies: vec![], - p99_latencies: vec![], + qps, + search_latencies, + mean_latencies, + p90_latencies, + p99_latencies, recall: utils::recall::RecallMetrics { recall_k: run.recall_k, recall_n: run.search_n, diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs index 38d1f529d..0b82b2dcb 100644 --- a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -19,6 +19,7 @@ use diskann_providers::model::graph::provider::async_::{ use diskann_utils::future::AsyncFriendly; pub(crate) trait FullPrecisionVectorAccessor: Accessor + Send { + #[allow(clippy::manual_async_fn)] fn get_full_precision_vector( &mut self, id: Self::Id, @@ -31,6 +32,7 @@ where D: AsyncFriendly, Ctx: diskann::provider::ExecutionContext, { + #[allow(clippy::manual_async_fn)] fn get_full_precision_vector( &mut self, id: Self::Id, From d1884c385dab639a1a92838b256a3b9c5f31a50c Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 6 May 2026 09:21:53 +0530 Subject: [PATCH 19/38] Fix async determinant-diversity: wire real vectors, timing metrics, reduce duplication - Use for<'a, 'b> SearchStrategy bound (user-provided fix) to break HRTB lifetime projection issue in the search_with post-processor constraint - Wire FullPrecisionVectorAccessor shim trait so async det-div post-processor fetches real candidate vectors instead of placeholder distances - Populate QPS/latency metrics in async det-div benchmark path (previously all 'missing') - Extract run_topk_timed helper to eliminate ~100 lines of duplicated loop/timing/recall machinery from DeterminantDiversity::run - Update async-determinant-diversity.json example tag (async-index-build -> graph-index-build) - Fix clippy::manual_async_fn in FullPrecisionVectorAccessor shim trait --- .../src/backend/index/benchmarks.rs | 238 ++++++++++-------- 1 file changed, 128 insertions(+), 110 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index ee0a766e1..34b081271 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -451,6 +451,123 @@ impl Strategy { // Topk // //------// +/// Execute a topk search with a custom per-query search function, collecting timing and recall. +/// +/// Encapsulates the outer thread/run/L loops, the per-rep timing harness, per-query latency +/// collection, and [`SearchResults`] construction. The caller supplies only the actual +/// per-query search as a closure `search_fn(knn_params, query, output) -> Result<()>`. +fn run_topk_timed( + topk: &crate::inputs::graph_index::TopkSearchPhase, + queries: &Matrix, + groundtruth: &Matrix, + mut search_fn: impl FnMut( + diskann::graph::search::Knn, + &[f32], + &mut Vec>, + ) -> anyhow::Result<()>, +) -> anyhow::Result> { + let mut all_results = Vec::new(); + + for threads in &topk.num_threads { + for run in &topk.runs { + for search_l in &run.search_l { + let knn_params = + diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); + + let mut all_recalls = Vec::new(); + let mut qps = Vec::with_capacity(topk.reps.get()); + let mut search_latencies = Vec::with_capacity(topk.reps.get()); + let mut mean_latencies = Vec::with_capacity(topk.reps.get()); + let mut p90_latencies = Vec::with_capacity(topk.reps.get()); + let mut p99_latencies = Vec::with_capacity(topk.reps.get()); + + for _ in 0..topk.reps.get() { + let search_start = Instant::now(); + let mut per_query_latencies = Vec::with_capacity(queries.nrows()); + + for query_idx in 0..queries.nrows() { + let query = queries.row(query_idx); + let mut output: Vec> = Vec::new(); + + let query_start = Instant::now(); + search_fn(knn_params, query, &mut output)?; + per_query_latencies.push(query_start.elapsed().as_micros() as u64); + + let gt = groundtruth.row(query_idx); + let mut matches = 0; + for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { + if i >= gt.len() { + break; + } + if gt.contains(&neighbor.id) { + matches += 1; + } + } + all_recalls.push(matches); + } + + let elapsed: MicroSeconds = search_start.elapsed().into(); + let elapsed_secs = elapsed.as_seconds(); + qps.push(if elapsed_secs > 0.0 { + queries.nrows() as f64 / elapsed_secs + } else { + 0.0 + }); + + per_query_latencies.sort_unstable(); + let len = per_query_latencies.len(); + let p90_idx = ((len as f64 * 0.90).ceil() as usize) + .saturating_sub(1) + .min(len.saturating_sub(1)); + let p99_idx = ((len as f64 * 0.99).ceil() as usize) + .saturating_sub(1) + .min(len.saturating_sub(1)); + let mean = if len > 0 { + per_query_latencies.iter().sum::() as f64 / len as f64 + } else { + 0.0 + }; + + search_latencies.push(elapsed); + mean_latencies.push(mean); + p90_latencies.push(MicroSeconds::new( + *per_query_latencies.get(p90_idx).unwrap_or(&0), + )); + p99_latencies.push(MicroSeconds::new( + *per_query_latencies.get(p99_idx).unwrap_or(&0), + )); + } + + let avg_recall = all_recalls.iter().sum::() as f32 + / (queries.nrows() * run.recall_k * topk.reps.get()) as f32; + + all_results.push(SearchResults { + num_tasks: threads.get(), + search_n: run.search_n, + search_l: *search_l, + qps, + search_latencies, + mean_latencies, + p90_latencies, + p99_latencies, + recall: utils::recall::RecallMetrics { + recall_k: run.recall_k, + recall_n: run.search_n, + num_queries: queries.nrows(), + average: avg_recall as f64, + minimum: *all_recalls.iter().min().unwrap_or(&0), + maximum: *all_recalls.iter().max().unwrap_or(&0), + }, + mean_cmps: 0.0, + mean_hops: 0.0, + }); + } + } + } + + Ok(all_results) +} + impl search::Plugin, SearchPhase, Strategy> for plugins::DeterminantDiversity where @@ -480,119 +597,20 @@ where let context = DefaultContext; let det_div = post_processor::DeterminantDiversity::new(power, eta); - let queries: Arc> = - Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); + let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile(&topk.queries))?); let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; - let mut all_results = Vec::new(); - - for threads in &topk.num_threads { - for run in &topk.runs { - for search_l in &run.search_l { - let knn_params = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); - - let mut all_recalls = Vec::new(); - - let mut qps = Vec::with_capacity(topk.reps.get()); - let mut search_latencies = Vec::with_capacity(topk.reps.get()); - let mut mean_latencies = Vec::with_capacity(topk.reps.get()); - let mut p90_latencies = Vec::with_capacity(topk.reps.get()); - let mut p99_latencies = Vec::with_capacity(topk.reps.get()); - - for _ in 0..topk.reps.get() { - let search_start = Instant::now(); - let mut per_query_latencies = Vec::with_capacity(queries.nrows()); - - for query_idx in 0..queries.nrows() { - let query = queries.row(query_idx); - let mut output: Vec> = Vec::new(); - - let query_start = Instant::now(); - utils::tokio::block_on(async { - index - .search_with( - knn_params, - &strategy, - det_div, - &context, - query, - &mut output, - ) - .await - })?; - per_query_latencies.push(query_start.elapsed().as_micros() as u64); - - let gt = groundtruth.row(query_idx); - let mut matches = 0; - for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { - if i >= gt.len() { - break; - } - if gt.contains(&neighbor.id) { - matches += 1; - } - } - all_recalls.push(matches); - } - - let elapsed: MicroSeconds = search_start.elapsed().into(); - let elapsed_secs = elapsed.as_seconds(); - if elapsed_secs > 0.0 { - qps.push(queries.nrows() as f64 / elapsed_secs); - } else { - qps.push(0.0); - } + let results = run_topk_timed(topk, &queries, &groundtruth, |params, query, output| { + utils::tokio::block_on(async { + index + .search_with(params, &strategy, det_div, &context, query, output) + .await + }) + .map(|_| ()) + .map_err(anyhow::Error::from) + })?; - per_query_latencies.sort_unstable(); - let len = per_query_latencies.len(); - let p90_idx = ((len as f64 * 0.90).ceil() as usize) - .saturating_sub(1) - .min(len.saturating_sub(1)); - let p99_idx = ((len as f64 * 0.99).ceil() as usize) - .saturating_sub(1) - .min(len.saturating_sub(1)); - - let mean = if len > 0 { - per_query_latencies.iter().sum::() as f64 / len as f64 - } else { - 0.0 - }; - - search_latencies.push(elapsed); - mean_latencies.push(mean); - p90_latencies.push(MicroSeconds::new(*per_query_latencies.get(p90_idx).unwrap_or(&0))); - p99_latencies.push(MicroSeconds::new(*per_query_latencies.get(p99_idx).unwrap_or(&0))); - } - - let avg_recall = all_recalls.iter().sum::() as f32 - / (queries.nrows() * run.recall_k * topk.reps.get()) as f32; - - all_results.push(SearchResults { - num_tasks: threads.get(), - search_n: run.search_n, - search_l: *search_l, - qps, - search_latencies, - mean_latencies, - p90_latencies, - p99_latencies, - recall: utils::recall::RecallMetrics { - recall_k: run.recall_k, - recall_n: run.search_n, - num_queries: queries.nrows(), - average: avg_recall as f64, - minimum: *all_recalls.iter().min().unwrap_or(&0), - maximum: *all_recalls.iter().max().unwrap_or(&0), - }, - mean_cmps: 0.0, - mean_hops: 0.0, - }); - } - } - } - - Ok(AggregatedSearchResults::Topk(all_results)) + Ok(AggregatedSearchResults::Topk(results)) } } From 701ce8eb044f6fdd1a21e9cf99b0f6dfb1393a4e Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 6 May 2026 10:15:51 +0530 Subject: [PATCH 20/38] Fix CI clippy-features spherical plugin errors and apply formatting --- .../src/backend/disk_index/search.rs | 7 ++-- .../src/backend/index/benchmarks.rs | 4 ++- .../post_processor/determinant_diversity.rs | 10 ++---- .../src/backend/index/spherical.rs | 18 +++++------ .../src/inputs/post_processor.rs | 15 ++------- .../src/search/provider/disk_provider.rs | 32 +++++++++++-------- .../determinant_diversity_post_process.rs | 9 ++---- .../src/model/graph/provider/async_/mod.rs | 2 +- 8 files changed, 42 insertions(+), 55 deletions(-) diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 299f6403d..1e64583c5 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -302,11 +302,8 @@ where id_chunk.fill(0); dist_chunk.fill(0.0); - for (i, result_item) in search_result - .results - .iter() - .take(base_count) - .enumerate() + for (i, result_item) in + search_result.results.iter().take(base_count).enumerate() { id_chunk[i] = result_item.vertex_id; dist_chunk[i] = result_item.distance; diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 34b081271..6491f7ec7 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -597,7 +597,9 @@ where let context = DefaultContext; let det_div = post_processor::DeterminantDiversity::new(power, eta); - let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile(&topk.queries))?); + let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile( + &topk.queries, + ))?); let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; let results = run_topk_timed(topk, &queries, &groundtruth, |params, query, output| { diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs index 0b82b2dcb..29219776e 100644 --- a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -6,15 +6,9 @@ use std::future::Future; use diskann::graph::search_output_buffer::SearchOutputBuffer; -use diskann::{ - error::ANNError, - graph::glue, - neighbor::Neighbor, - provider::Accessor, -}; +use diskann::{error::ANNError, graph::glue, neighbor::Neighbor, provider::Accessor}; use diskann_providers::model::graph::provider::async_::{ - determinant_diversity_post_process, - inmem, + determinant_diversity_post_process, inmem, }; use diskann_utils::future::AsyncFriendly; diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 20e9c0e29..7590f3321 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -86,7 +86,7 @@ mod imp { }, inputs::{ exhaustive, - graph_index::{SearchPhase, SphericalQuantBuild}, + graph_index::{SearchPhase, SearchPhaseKind, SphericalQuantBuild}, }, utils::{ self, datafiles, @@ -363,11 +363,11 @@ mod imp { for search::plugins::Topk { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::Topk::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::Topk.as_str() } fn run( @@ -402,11 +402,11 @@ mod imp { for search::plugins::Range { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::Range::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::Range.as_str() } fn run( @@ -444,11 +444,11 @@ mod imp { for search::plugins::TopkBetaFilter { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::TopkBetaFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::TopkBetaFilter.as_str() } fn run( @@ -498,11 +498,11 @@ mod imp { for search::plugins::TopkMultihopFilter { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::TopkMultihopFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::TopkMultihopFilter.as_str() } fn run( diff --git a/diskann-benchmark/src/inputs/post_processor.rs b/diskann-benchmark/src/inputs/post_processor.rs index a4b7bf96d..5ff739321 100644 --- a/diskann-benchmark/src/inputs/post_processor.rs +++ b/diskann-benchmark/src/inputs/post_processor.rs @@ -9,10 +9,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "kebab-case")] pub(crate) enum TopkPostProcessor { - DeterminantDiversity { - power: f32, - eta: f32, - }, + DeterminantDiversity { power: f32, eta: f32 }, } impl CheckDeserialization for TopkPostProcessor { @@ -20,16 +17,10 @@ impl CheckDeserialization for TopkPostProcessor { match self { TopkPostProcessor::DeterminantDiversity { power, eta } => { if *power <= 0.0 { - anyhow::bail!( - "determinant-diversity power must be > 0.0, got: {}", - power - ); + anyhow::bail!("determinant-diversity power must be > 0.0, got: {}", power); } if *eta < 0.0 { - anyhow::bail!( - "determinant-diversity eta must be >= 0.0, got: {}", - eta - ); + anyhow::bail!("determinant-diversity eta must be >= 0.0, got: {}", eta); } Ok(()) } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index f828268bd..039499505 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -415,7 +415,9 @@ where for id in candidate_ids { let vector = accessor.scratch.vertex_provider.get_vector(&id)?; - let distance = provider.distance_comparer.evaluate_similarity(query, vector); + let distance = provider + .distance_comparer + .evaluate_similarity(query, vector); let vector_f32 = Data::VectorDataType::as_f32(vector).map_err(Into::into)?; let data = accessor.scratch.vertex_provider.get_associated_data(&id)?; @@ -431,12 +433,14 @@ where self.power, ); - Ok(output.extend(reranked.into_iter().filter_map(|(id, distance)| { - associated_data - .get(&id) - .copied() - .map(|data| ((id, data), distance)) - }))) + Ok( + output.extend(reranked.into_iter().filter_map(|(id, distance)| { + associated_data + .get(&id) + .copied() + .map(|data| ((id, data), distance)) + })), + ) } } @@ -1077,13 +1081,15 @@ where let vector_filter = vector_filter.unwrap_or(default_vector_filter::()); let post_processor = post_processor.map(|processor| match processor { - SearchPostProcessorKind::RerankAndFilter => { - DiskSearchPostProcessor::RerankAndFilter(RerankAndFilter::new(vector_filter.as_ref())) - } + SearchPostProcessorKind::RerankAndFilter => DiskSearchPostProcessor::RerankAndFilter( + RerankAndFilter::new(vector_filter.as_ref()), + ), SearchPostProcessorKind::DeterminantDiversity { power, eta } => { - DiskSearchPostProcessor::DeterminantDiversity( - DeterminantDiversityAndFilter::new(vector_filter.as_ref(), power, eta), - ) + DiskSearchPostProcessor::DeterminantDiversity(DeterminantDiversityAndFilter::new( + vector_filter.as_ref(), + power, + eta, + )) } }); diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index ba1ab9603..4c4bc032d 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -3,10 +3,7 @@ * Licensed under the MIT license. */ -use diskann_vector::{ - MathematicalValue, PureDistanceFunction, - distance::InnerProduct, -}; +use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; pub fn determinant_diversity_post_process( candidates: Vec<(Id, f32, Vec)>, @@ -85,8 +82,8 @@ fn post_process_with_eta_f32( let mut norms_sq = Vec::with_capacity(n); for (_, distance_to_query, v) in &candidates { - let scale = distance_to_similarity(*distance_to_query, distance_range).powf(power) - * inv_sqrt_eta; + let scale = + distance_to_similarity(*distance_to_query, distance_range).powf(power) * inv_sqrt_eta; let residual: Vec = v.iter().map(|&x| x * scale).collect(); let norm_sq = dot_product(&residual, &residual); residuals.push(residual); diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 816b95523..a0bfb3010 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -7,8 +7,8 @@ pub mod experimental; pub mod common; pub use common::{PrefetchCacheLineLevel, StartPoints, VectorGuard}; -pub(crate) mod postprocess; mod determinant_diversity_post_process; +pub(crate) mod postprocess; pub use determinant_diversity_post_process::determinant_diversity_post_process; pub mod distances; From 6b935d3e09976eac8babd40cf10dad2a5d92bcf9 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 6 May 2026 16:57:34 +0530 Subject: [PATCH 21/38] Add determinant-diversity support for async and disk-index benchmarks - Use for<'a, 'b> SearchStrategy bound to resolve HRTB lifetime mismatch - Wire FullPrecisionVectorAccessor shim trait for async det-div to fetch real vectors - Implement DeterminantDiversity post-processor for async graph-index path - Extract run_topk_timed helper to eliminate ~100 lines of code duplication - Wire post_processor parameter to disk-index search pipeline - Update search parameter handling and result counting for post-processed results - Add TopkPostProcessor input type and necessary imports - Populate QPS/latency metrics in async det-div benchmark path --- .../src/backend/disk_index/search.rs | 7 +- .../src/backend/index/benchmarks.rs | 240 ++++++++++-------- .../post_processor/determinant_diversity.rs | 10 +- .../src/inputs/post_processor.rs | 15 +- .../determinant_diversity_post_process.rs | 9 +- 5 files changed, 140 insertions(+), 141 deletions(-) diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 299f6403d..1e64583c5 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -302,11 +302,8 @@ where id_chunk.fill(0); dist_chunk.fill(0.0); - for (i, result_item) in search_result - .results - .iter() - .take(base_count) - .enumerate() + for (i, result_item) in + search_result.results.iter().take(base_count).enumerate() { id_chunk[i] = result_item.vertex_id; dist_chunk[i] = result_item.distance; diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index ee0a766e1..6491f7ec7 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -451,6 +451,123 @@ impl Strategy { // Topk // //------// +/// Execute a topk search with a custom per-query search function, collecting timing and recall. +/// +/// Encapsulates the outer thread/run/L loops, the per-rep timing harness, per-query latency +/// collection, and [`SearchResults`] construction. The caller supplies only the actual +/// per-query search as a closure `search_fn(knn_params, query, output) -> Result<()>`. +fn run_topk_timed( + topk: &crate::inputs::graph_index::TopkSearchPhase, + queries: &Matrix, + groundtruth: &Matrix, + mut search_fn: impl FnMut( + diskann::graph::search::Knn, + &[f32], + &mut Vec>, + ) -> anyhow::Result<()>, +) -> anyhow::Result> { + let mut all_results = Vec::new(); + + for threads in &topk.num_threads { + for run in &topk.runs { + for search_l in &run.search_l { + let knn_params = + diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); + + let mut all_recalls = Vec::new(); + let mut qps = Vec::with_capacity(topk.reps.get()); + let mut search_latencies = Vec::with_capacity(topk.reps.get()); + let mut mean_latencies = Vec::with_capacity(topk.reps.get()); + let mut p90_latencies = Vec::with_capacity(topk.reps.get()); + let mut p99_latencies = Vec::with_capacity(topk.reps.get()); + + for _ in 0..topk.reps.get() { + let search_start = Instant::now(); + let mut per_query_latencies = Vec::with_capacity(queries.nrows()); + + for query_idx in 0..queries.nrows() { + let query = queries.row(query_idx); + let mut output: Vec> = Vec::new(); + + let query_start = Instant::now(); + search_fn(knn_params, query, &mut output)?; + per_query_latencies.push(query_start.elapsed().as_micros() as u64); + + let gt = groundtruth.row(query_idx); + let mut matches = 0; + for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { + if i >= gt.len() { + break; + } + if gt.contains(&neighbor.id) { + matches += 1; + } + } + all_recalls.push(matches); + } + + let elapsed: MicroSeconds = search_start.elapsed().into(); + let elapsed_secs = elapsed.as_seconds(); + qps.push(if elapsed_secs > 0.0 { + queries.nrows() as f64 / elapsed_secs + } else { + 0.0 + }); + + per_query_latencies.sort_unstable(); + let len = per_query_latencies.len(); + let p90_idx = ((len as f64 * 0.90).ceil() as usize) + .saturating_sub(1) + .min(len.saturating_sub(1)); + let p99_idx = ((len as f64 * 0.99).ceil() as usize) + .saturating_sub(1) + .min(len.saturating_sub(1)); + let mean = if len > 0 { + per_query_latencies.iter().sum::() as f64 / len as f64 + } else { + 0.0 + }; + + search_latencies.push(elapsed); + mean_latencies.push(mean); + p90_latencies.push(MicroSeconds::new( + *per_query_latencies.get(p90_idx).unwrap_or(&0), + )); + p99_latencies.push(MicroSeconds::new( + *per_query_latencies.get(p99_idx).unwrap_or(&0), + )); + } + + let avg_recall = all_recalls.iter().sum::() as f32 + / (queries.nrows() * run.recall_k * topk.reps.get()) as f32; + + all_results.push(SearchResults { + num_tasks: threads.get(), + search_n: run.search_n, + search_l: *search_l, + qps, + search_latencies, + mean_latencies, + p90_latencies, + p99_latencies, + recall: utils::recall::RecallMetrics { + recall_k: run.recall_k, + recall_n: run.search_n, + num_queries: queries.nrows(), + average: avg_recall as f64, + minimum: *all_recalls.iter().min().unwrap_or(&0), + maximum: *all_recalls.iter().max().unwrap_or(&0), + }, + mean_cmps: 0.0, + mean_hops: 0.0, + }); + } + } + } + + Ok(all_results) +} + impl search::Plugin, SearchPhase, Strategy> for plugins::DeterminantDiversity where @@ -480,119 +597,22 @@ where let context = DefaultContext; let det_div = post_processor::DeterminantDiversity::new(power, eta); - let queries: Arc> = - Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); + let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile( + &topk.queries, + ))?); let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; - let mut all_results = Vec::new(); - - for threads in &topk.num_threads { - for run in &topk.runs { - for search_l in &run.search_l { - let knn_params = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); - - let mut all_recalls = Vec::new(); - - let mut qps = Vec::with_capacity(topk.reps.get()); - let mut search_latencies = Vec::with_capacity(topk.reps.get()); - let mut mean_latencies = Vec::with_capacity(topk.reps.get()); - let mut p90_latencies = Vec::with_capacity(topk.reps.get()); - let mut p99_latencies = Vec::with_capacity(topk.reps.get()); - - for _ in 0..topk.reps.get() { - let search_start = Instant::now(); - let mut per_query_latencies = Vec::with_capacity(queries.nrows()); - - for query_idx in 0..queries.nrows() { - let query = queries.row(query_idx); - let mut output: Vec> = Vec::new(); - - let query_start = Instant::now(); - utils::tokio::block_on(async { - index - .search_with( - knn_params, - &strategy, - det_div, - &context, - query, - &mut output, - ) - .await - })?; - per_query_latencies.push(query_start.elapsed().as_micros() as u64); - - let gt = groundtruth.row(query_idx); - let mut matches = 0; - for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { - if i >= gt.len() { - break; - } - if gt.contains(&neighbor.id) { - matches += 1; - } - } - all_recalls.push(matches); - } - - let elapsed: MicroSeconds = search_start.elapsed().into(); - let elapsed_secs = elapsed.as_seconds(); - if elapsed_secs > 0.0 { - qps.push(queries.nrows() as f64 / elapsed_secs); - } else { - qps.push(0.0); - } + let results = run_topk_timed(topk, &queries, &groundtruth, |params, query, output| { + utils::tokio::block_on(async { + index + .search_with(params, &strategy, det_div, &context, query, output) + .await + }) + .map(|_| ()) + .map_err(anyhow::Error::from) + })?; - per_query_latencies.sort_unstable(); - let len = per_query_latencies.len(); - let p90_idx = ((len as f64 * 0.90).ceil() as usize) - .saturating_sub(1) - .min(len.saturating_sub(1)); - let p99_idx = ((len as f64 * 0.99).ceil() as usize) - .saturating_sub(1) - .min(len.saturating_sub(1)); - - let mean = if len > 0 { - per_query_latencies.iter().sum::() as f64 / len as f64 - } else { - 0.0 - }; - - search_latencies.push(elapsed); - mean_latencies.push(mean); - p90_latencies.push(MicroSeconds::new(*per_query_latencies.get(p90_idx).unwrap_or(&0))); - p99_latencies.push(MicroSeconds::new(*per_query_latencies.get(p99_idx).unwrap_or(&0))); - } - - let avg_recall = all_recalls.iter().sum::() as f32 - / (queries.nrows() * run.recall_k * topk.reps.get()) as f32; - - all_results.push(SearchResults { - num_tasks: threads.get(), - search_n: run.search_n, - search_l: *search_l, - qps, - search_latencies, - mean_latencies, - p90_latencies, - p99_latencies, - recall: utils::recall::RecallMetrics { - recall_k: run.recall_k, - recall_n: run.search_n, - num_queries: queries.nrows(), - average: avg_recall as f64, - minimum: *all_recalls.iter().min().unwrap_or(&0), - maximum: *all_recalls.iter().max().unwrap_or(&0), - }, - mean_cmps: 0.0, - mean_hops: 0.0, - }); - } - } - } - - Ok(AggregatedSearchResults::Topk(all_results)) + Ok(AggregatedSearchResults::Topk(results)) } } diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs index 0b82b2dcb..29219776e 100644 --- a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -6,15 +6,9 @@ use std::future::Future; use diskann::graph::search_output_buffer::SearchOutputBuffer; -use diskann::{ - error::ANNError, - graph::glue, - neighbor::Neighbor, - provider::Accessor, -}; +use diskann::{error::ANNError, graph::glue, neighbor::Neighbor, provider::Accessor}; use diskann_providers::model::graph::provider::async_::{ - determinant_diversity_post_process, - inmem, + determinant_diversity_post_process, inmem, }; use diskann_utils::future::AsyncFriendly; diff --git a/diskann-benchmark/src/inputs/post_processor.rs b/diskann-benchmark/src/inputs/post_processor.rs index a4b7bf96d..5ff739321 100644 --- a/diskann-benchmark/src/inputs/post_processor.rs +++ b/diskann-benchmark/src/inputs/post_processor.rs @@ -9,10 +9,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "kebab-case")] pub(crate) enum TopkPostProcessor { - DeterminantDiversity { - power: f32, - eta: f32, - }, + DeterminantDiversity { power: f32, eta: f32 }, } impl CheckDeserialization for TopkPostProcessor { @@ -20,16 +17,10 @@ impl CheckDeserialization for TopkPostProcessor { match self { TopkPostProcessor::DeterminantDiversity { power, eta } => { if *power <= 0.0 { - anyhow::bail!( - "determinant-diversity power must be > 0.0, got: {}", - power - ); + anyhow::bail!("determinant-diversity power must be > 0.0, got: {}", power); } if *eta < 0.0 { - anyhow::bail!( - "determinant-diversity eta must be >= 0.0, got: {}", - eta - ); + anyhow::bail!("determinant-diversity eta must be >= 0.0, got: {}", eta); } Ok(()) } diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index ba1ab9603..4c4bc032d 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -3,10 +3,7 @@ * Licensed under the MIT license. */ -use diskann_vector::{ - MathematicalValue, PureDistanceFunction, - distance::InnerProduct, -}; +use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; pub fn determinant_diversity_post_process( candidates: Vec<(Id, f32, Vec)>, @@ -85,8 +82,8 @@ fn post_process_with_eta_f32( let mut norms_sq = Vec::with_capacity(n); for (_, distance_to_query, v) in &candidates { - let scale = distance_to_similarity(*distance_to_query, distance_range).powf(power) - * inv_sqrt_eta; + let scale = + distance_to_similarity(*distance_to_query, distance_range).powf(power) * inv_sqrt_eta; let residual: Vec = v.iter().map(|&x| x * scale).collect(); let norm_sq = dot_product(&residual, &residual); residuals.push(residual); From 8acbba254012c65f465ed22fd596cb35519a2a9d Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 6 May 2026 20:06:42 +0530 Subject: [PATCH 22/38] imrpove code coverage --- .../determinant_diversity_post_process.rs | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index 4c4bc032d..d9a88a422 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -257,3 +257,123 @@ fn dot_product(a: &[f32], b: &[f32]) -> f32 { >>::evaluate(a, b) .into_inner() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_candidates() { + let result = determinant_diversity_post_process::(Vec::new(), &[1.0, 2.0], 5, 0.5, 1.0); + assert_eq!(result.len(), 0); + } + + #[test] + fn test_empty_query() { + let candidates = vec![(0u32, 0.5, vec![1.0, 2.0])]; + let result = determinant_diversity_post_process(candidates, &[], 5, 0.5, 1.0); + assert_eq!(result.len(), 0); + } + + #[test] + fn test_mismatched_dimensions() { + let candidates = vec![ + (0u32, 0.5, vec![1.0, 2.0]), + (1u32, 0.3, vec![1.0]), // Wrong dimension + ]; + let query = &[1.0, 2.0, 3.0]; + let result = determinant_diversity_post_process(candidates, query, 5, 0.5, 1.0); + assert_eq!(result.len(), 0); // All candidates filtered due to dimension mismatch + } + + #[test] + fn test_single_candidate() { + let candidates = vec![(0u32, 0.5, vec![1.0, 2.0])]; + let query = &[1.0, 2.0]; + let result = determinant_diversity_post_process(candidates, query, 5, 0.5, 1.0); + assert_eq!(result.len(), 1); + assert_eq!(result[0].0, 0); + } + + #[test] + fn test_k_larger_than_candidates() { + let candidates = vec![ + (0u32, 0.5, vec![1.0, 0.0]), + (1u32, 0.3, vec![0.0, 1.0]), + ]; + let query = &[1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 10, 0.5, 1.0); + assert_eq!(result.len(), 2); // Should return min(k, candidates.len()) + } + + #[test] + fn test_with_eta_diversity() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.9, 0.1]), + (2u32, 0.3, vec![0.8, 0.2]), + ]; + let query = &[1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 1.0, 1.0); + + assert_eq!(result.len(), 2); + // Should select based on diversity metric with eta > 0 + assert!(result.iter().all(|(id, _)| *id < 3)); + } + + #[test] + fn test_without_eta_greedy() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.9, 0.1]), + (2u32, 0.3, vec![0.8, 0.2]), + ]; + let query = &[1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + + assert_eq!(result.len(), 2); + // Should select based on greedy orthogonalization (eta == 0) + assert!(result.iter().all(|(id, _)| *id < 3)); + } + + #[test] + fn test_power_parameter() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.0, 1.0]), + ]; + let query = &[1.0, 1.0]; + + // Test with different power values - should still work without panicking + let result1 = determinant_diversity_post_process( + candidates.clone(), + query, + 2, + 0.0, + 1.0, + ); + let result2 = determinant_diversity_post_process( + candidates, + query, + 2, + 0.0, + 2.0, + ); + + assert_eq!(result1.len(), 2); + assert_eq!(result2.len(), 2); + } + + #[test] + fn test_distances_preserved() { + let candidates = vec![ + (0u32, 0.5, vec![1.0, 0.0]), + (1u32, 0.3, vec![0.0, 1.0]), + ]; + let query = &[1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + + // Verify that distances are preserved from input + assert!(result.iter().all(|(_, dist)| *dist == 0.5 || *dist == 0.3)); + } +} From a48e2555dcf104b0281d68fee3f37b9977acf48b Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 6 May 2026 20:12:31 +0530 Subject: [PATCH 23/38] minor fix Co-authored-by: Copilot --- .../post_processor/determinant_diversity.rs | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs index 29219776e..aa915f0fd 100644 --- a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -13,7 +13,6 @@ use diskann_providers::model::graph::provider::async_::{ use diskann_utils::future::AsyncFriendly; pub(crate) trait FullPrecisionVectorAccessor: Accessor + Send { - #[allow(clippy::manual_async_fn)] fn get_full_precision_vector( &mut self, id: Self::Id, @@ -26,17 +25,11 @@ where D: AsyncFriendly, Ctx: diskann::provider::ExecutionContext, { - #[allow(clippy::manual_async_fn)] - fn get_full_precision_vector( - &mut self, - id: Self::Id, - ) -> impl Future, ANNError>> + Send { - async move { - self.get_element(id) - .await - .map(|vector| vector.to_vec()) - .map_err(Into::into) - } + async fn get_full_precision_vector(&mut self, id: Self::Id) -> Result, ANNError> { + self.get_element(id) + .await + .map(|vector| vector.to_vec()) + .map_err(Into::into) } } From 468b5d2522d51235965c4bdacb57753d40b71d28 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 6 May 2026 20:39:25 +0530 Subject: [PATCH 24/38] cargo fmt --- .../determinant_diversity_post_process.rs | 44 +++++-------------- 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index d9a88a422..d1793bad6 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -264,7 +264,8 @@ mod tests { #[test] fn test_empty_candidates() { - let result = determinant_diversity_post_process::(Vec::new(), &[1.0, 2.0], 5, 0.5, 1.0); + let result = + determinant_diversity_post_process::(Vec::new(), &[1.0, 2.0], 5, 0.5, 1.0); assert_eq!(result.len(), 0); } @@ -297,10 +298,7 @@ mod tests { #[test] fn test_k_larger_than_candidates() { - let candidates = vec![ - (0u32, 0.5, vec![1.0, 0.0]), - (1u32, 0.3, vec![0.0, 1.0]), - ]; + let candidates = vec![(0u32, 0.5, vec![1.0, 0.0]), (1u32, 0.3, vec![0.0, 1.0])]; let query = &[1.0, 1.0]; let result = determinant_diversity_post_process(candidates, query, 10, 0.5, 1.0); assert_eq!(result.len(), 2); // Should return min(k, candidates.len()) @@ -315,7 +313,7 @@ mod tests { ]; let query = &[1.0, 1.0]; let result = determinant_diversity_post_process(candidates, query, 2, 1.0, 1.0); - + assert_eq!(result.len(), 2); // Should select based on diversity metric with eta > 0 assert!(result.iter().all(|(id, _)| *id < 3)); @@ -330,7 +328,7 @@ mod tests { ]; let query = &[1.0, 1.0]; let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); - + assert_eq!(result.len(), 2); // Should select based on greedy orthogonalization (eta == 0) assert!(result.iter().all(|(id, _)| *id < 3)); @@ -338,41 +336,23 @@ mod tests { #[test] fn test_power_parameter() { - let candidates = vec![ - (0u32, 0.1, vec![1.0, 0.0]), - (1u32, 0.2, vec![0.0, 1.0]), - ]; + let candidates = vec![(0u32, 0.1, vec![1.0, 0.0]), (1u32, 0.2, vec![0.0, 1.0])]; let query = &[1.0, 1.0]; - + // Test with different power values - should still work without panicking - let result1 = determinant_diversity_post_process( - candidates.clone(), - query, - 2, - 0.0, - 1.0, - ); - let result2 = determinant_diversity_post_process( - candidates, - query, - 2, - 0.0, - 2.0, - ); - + let result1 = determinant_diversity_post_process(candidates.clone(), query, 2, 0.0, 1.0); + let result2 = determinant_diversity_post_process(candidates, query, 2, 0.0, 2.0); + assert_eq!(result1.len(), 2); assert_eq!(result2.len(), 2); } #[test] fn test_distances_preserved() { - let candidates = vec![ - (0u32, 0.5, vec![1.0, 0.0]), - (1u32, 0.3, vec![0.0, 1.0]), - ]; + let candidates = vec![(0u32, 0.5, vec![1.0, 0.0]), (1u32, 0.3, vec![0.0, 1.0])]; let query = &[1.0, 1.0]; let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); - + // Verify that distances are preserved from input assert!(result.iter().all(|(_, dist)| *dist == 0.5 || *dist == 0.3)); } From d9e66ba0eee6560613488e303b3a4d32bf11c187 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 12:58:29 +0530 Subject: [PATCH 25/38] WIP: Benchmarks refactoring - threading fix, rich struct params, as_str() methods --- .../src/backend/index/benchmarks.rs | 68 +++++++++++-------- .../src/backend/index/search/plugins.rs | 60 +++++++++++++++- diskann-benchmark/src/utils/tokio.rs | 7 ++ 3 files changed, 105 insertions(+), 30 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 6491f7ec7..75e14e4a4 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -47,7 +47,7 @@ use crate::{ streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, inputs::graph_index::{ - DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase, SearchPhaseKind, + DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase, }, utils::{ self, @@ -453,14 +453,18 @@ impl Strategy { /// Execute a topk search with a custom per-query search function, collecting timing and recall. /// -/// Encapsulates the outer thread/run/L loops, the per-rep timing harness, per-query latency +/// This function properly creates tokio runtimes with the specified thread counts to accurately +/// measure performance under different concurrency levels. Each thread count is benchmarked +/// independently with its own runtime to avoid misleading single-threaded measurements. +/// +/// Encapsulates the thread/run/L loops, the per-rep timing harness, per-query latency /// collection, and [`SearchResults`] construction. The caller supplies only the actual /// per-query search as a closure `search_fn(knn_params, query, output) -> Result<()>`. fn run_topk_timed( topk: &crate::inputs::graph_index::TopkSearchPhase, queries: &Matrix, groundtruth: &Matrix, - mut search_fn: impl FnMut( + search_fn: impl Fn( diskann::graph::search::Knn, &[f32], &mut Vec>, @@ -469,6 +473,11 @@ fn run_topk_timed( let mut all_results = Vec::new(); for threads in &topk.num_threads { + // Create a tokio runtime with the specified number of threads. + // This ensures that searches are actually executed with the desired concurrency level, + // not just recorded with that thread count while running serially. + let rt = utils::tokio::runtime(threads.get())?; + for run in &topk.runs { for search_l in &run.search_l { let knn_params = @@ -485,26 +494,29 @@ fn run_topk_timed( let search_start = Instant::now(); let mut per_query_latencies = Vec::with_capacity(queries.nrows()); - for query_idx in 0..queries.nrows() { - let query = queries.row(query_idx); - let mut output: Vec> = Vec::new(); - - let query_start = Instant::now(); - search_fn(knn_params, query, &mut output)?; - per_query_latencies.push(query_start.elapsed().as_micros() as u64); - - let gt = groundtruth.row(query_idx); - let mut matches = 0; - for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { - if i >= gt.len() { - break; - } - if gt.contains(&neighbor.id) { - matches += 1; + rt.block_on(async { + for query_idx in 0..queries.nrows() { + let query = queries.row(query_idx); + let mut output: Vec> = Vec::new(); + + let query_start = Instant::now(); + search_fn(knn_params, query, &mut output)?; + per_query_latencies.push(query_start.elapsed().as_micros() as u64); + + let gt = groundtruth.row(query_idx); + let mut matches = 0; + for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { + if i >= gt.len() { + break; + } + if gt.contains(&neighbor.id) { + matches += 1; + } } + all_recalls.push(matches); } - all_recalls.push(matches); - } + Ok::<(), anyhow::Error>(()) + })?; let elapsed: MicroSeconds = search_start.elapsed().into(); let elapsed_secs = elapsed.as_seconds(); @@ -582,7 +594,7 @@ where } fn kind(&self) -> &'static str { - "topk + determinant-diversity" + plugins::DeterminantDiversity::as_str() } fn run( @@ -591,11 +603,11 @@ where phase: &SearchPhase, _strategy: &Strategy, ) -> anyhow::Result { - let (topk, power, eta) = plugins::DeterminantDiversity::get(phase)?; + let (topk, params) = plugins::DeterminantDiversity::get(phase)?; let strategy = common::FullPrecision; let context = DefaultContext; - let det_div = post_processor::DeterminantDiversity::new(power, eta); + let det_div = post_processor::DeterminantDiversity::new(params.power, params.eta); let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile( &topk.queries, @@ -626,7 +638,7 @@ where } fn kind(&self) -> &'static str { - SearchPhaseKind::Topk.as_str() + plugins::Topk::as_str() } fn run( @@ -669,7 +681,7 @@ where } fn kind(&self) -> &'static str { - SearchPhaseKind::Range.as_str() + plugins::Range::as_str() } fn run( @@ -713,7 +725,7 @@ where } fn kind(&self) -> &'static str { - SearchPhaseKind::TopkBetaFilter.as_str() + plugins::TopkBetaFilter::as_str() } fn run( @@ -772,7 +784,7 @@ where } fn kind(&self) -> &'static str { - SearchPhaseKind::TopkMultihopFilter.as_str() + plugins::TopkMultihopFilter::as_str() } fn run( diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 66ea4e38a..30dbeda91 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -155,12 +155,51 @@ impl Topk { .ok() .is_some_and(|topk| topk.post_processor.is_none()) } + + pub(crate) const fn as_str() -> &'static str { + "topk" + } } /// A search plugin for determinant-diversity top-k post-processing. #[derive(Debug, Clone, Copy)] pub(crate) struct DeterminantDiversity; +/// Parameters for determinant-diversity post-processing. +/// +/// This struct encapsulates the validated parameters needed for determinant-diversity +/// reranking. The parameters are validated at parse time to ensure correct values. +#[derive(Debug, Clone, Copy)] +pub(crate) struct DeterminantDiversityParams { + /// The power parameter controlling relevance vs. diversity trade-off. + /// + /// Higher values prefer relevance over diversity. + /// Must be > 0.0. + pub power: f32, + /// The ridge regularization parameter for numerical stability. + /// + /// Higher values provide more numerical robustness but bias toward relevance. + /// Must be >= 0.0. + pub eta: f32, +} + +impl DeterminantDiversityParams { + /// Create new determinant-diversity parameters with validation. + /// + /// # Errors + /// + /// Returns an error if `power <= 0.0` or `eta < 0.0`. + pub(crate) fn new(power: f32, eta: f32) -> anyhow::Result { + if power <= 0.0 { + anyhow::bail!("determinant-diversity power must be > 0.0, got: {}", power); + } + if eta < 0.0 { + anyhow::bail!("determinant-diversity eta must be >= 0.0, got: {}", eta); + } + Ok(Self { power, eta }) + } +} + impl DeterminantDiversity { pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase @@ -170,11 +209,16 @@ impl DeterminantDiversity { .is_some_and(|pp| matches!(pp, TopkPostProcessor::DeterminantDiversity { .. })) } - pub(crate) fn get(phase: &SearchPhase) -> anyhow::Result<(&TopkSearchPhase, f32, f32)> { + pub(crate) const fn as_str() -> &'static str { + "topk + determinant-diversity" + } + + pub(crate) fn get(phase: &SearchPhase) -> anyhow::Result<(&TopkSearchPhase, DeterminantDiversityParams)> { let topk = phase.as_topk()?; match topk.post_processor.as_ref() { Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => { - Ok((topk, *power, *eta)) + let params = DeterminantDiversityParams::new(*power, *eta)?; + Ok((topk, params)) } _ => Err(anyhow::anyhow!( "determinant-diversity plugin selected for non determinant-diversity input", @@ -191,6 +235,10 @@ impl Range { pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase.as_range().is_ok() } + + pub(crate) const fn as_str() -> &'static str { + "range" + } } /// A search plugin for beta-filtered search. @@ -201,6 +249,10 @@ impl TopkBetaFilter { pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase.as_topk_beta_filter().is_ok() } + + pub(crate) const fn as_str() -> &'static str { + "topk + beta filter" + } } /// A search plugin for multi-hop filtered search. @@ -211,4 +263,8 @@ impl TopkMultihopFilter { pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase.as_topk_multihop_filter().is_ok() } + + pub(crate) const fn as_str() -> &'static str { + "topk + multihop filter" + } } diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index 72dbeb918..cd07e593a 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -11,3 +11,10 @@ pub(crate) fn block_on(future: F) -> F::Output { .expect("current thread runtime initialization failed") .block_on(future) } + +/// Create a multi-threaded runtime with the specified number of threads. +pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { + Ok(tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_threads) + .build()?) +} From 9cda6e022708235677d3962ebbcf1edae49920ae Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 13:04:59 +0530 Subject: [PATCH 26/38] Add post-processor generic parameter to KNN struct in benchmark-core - Add 4th type parameter PP to KNN with default of () - Add with_postprocessor() constructor for post-processor support - Keep new() for standard KNN without post-processing - Add accessor methods for index, queries, strategy, post_processor - Update Search impl to work with generic KNN - Maintains backward compatibility - all existing code continues to work - Provides foundation for DeterminantDiversity plugin refactoring --- .../src/search/graph/knn.rs | 66 +++++++++++++++++-- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 2149842a8..19983431d 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -22,7 +22,7 @@ use crate::{ }; /// A built-in helper for benchmarking the K-nearest neighbors method -/// [`graph::DiskANNIndex::search`]. +/// [`graph::DiskANNIndex::search`] with optional post-processing support. /// /// This is intended to be used in conjunction with [`search::search`] or /// [`search::search_all`] and provides some basic additional metrics for @@ -31,21 +31,29 @@ use crate::{ /// /// The provided implementation of [`Search`] accepts [`graph::search::Knn`] /// and returns [`Metrics`] as additional output. +/// +/// # Type Parameters +/// +/// - `DP`: The data provider type +/// - `T`: The query element type +/// - `S`: The search strategy type +/// - `PP`: Optional post-processor type (defaults to `()` for no post-processing) #[derive(Debug)] -pub struct KNN +pub struct KNN where DP: provider::DataProvider, { index: Arc>, queries: Arc>, strategy: Strategy, + post_processor: Option, } -impl KNN +impl KNN where DP: provider::DataProvider, { - /// Construct a new [`KNN`] searcher. + /// Construct a new [`KNN`] searcher without post-processing. /// /// If `strategy` is one of the container variants of [`Strategy`], its length /// must match the number of rows in `queries`. If this is the case, then the @@ -67,8 +75,56 @@ where index, queries, strategy, + post_processor: None, + })) + } +} + +impl KNN +where + DP: provider::DataProvider, +{ + /// Construct a new [`KNN`] searcher with post-processing. + /// + /// # Errors + /// + /// Returns an error if the number of elements in `strategy` is not compatible with + /// the number of rows in `queries`. + pub fn with_postprocessor( + index: Arc>, + queries: Arc>, + strategy: Strategy, + post_processor: PP, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + + Ok(Arc::new(Self { + index, + queries, + strategy, + post_processor: Some(post_processor), })) } + + /// Access the index. + pub fn index(&self) -> &Arc> { + &self.index + } + + /// Access the queries. + pub fn queries(&self) -> &Arc> { + &self.queries + } + + /// Access the strategy. + pub fn strategy(&self) -> &Strategy { + &self.strategy + } + + /// Access the post-processor, if present. + pub fn post_processor(&self) -> &Option { + &self.post_processor + } } /// Additional metrics collected during [`KNN`] search. @@ -85,7 +141,7 @@ pub struct Metrics { pub hops: u32, } -impl Search for KNN +impl Search for KNN where DP: provider::DataProvider, S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, From 8b076767b1857bf2b07b90bcc2c9643cd600fbc3 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 13:12:05 +0530 Subject: [PATCH 27/38] Task 5: Create unified validation struct for DeterminantDiversity in diskann-providers - Add diskann-providers/src/post_processor.rs with centralized DeterminantDiversityParams - Define DeterminantDiversityError for precise error reporting - Provide unified validation: power > 0.0, eta >= 0.0 - Add accessor methods and Display impl - Import tests included - Update diskann-benchmark to use shared type from diskann-providers - Remove duplicate local definition in plugins.rs - Maintains backward compatibility with anyhow::Error conversion - Foundation for sharing post-processor logic across subsystems (benchmark, disk-index) --- .../src/backend/index/search/plugins.rs | 39 +---- diskann-providers/src/lib.rs | 2 + diskann-providers/src/post_processor.rs | 138 ++++++++++++++++++ 3 files changed, 143 insertions(+), 36 deletions(-) create mode 100644 diskann-providers/src/post_processor.rs diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 30dbeda91..2b2d50a47 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -36,6 +36,7 @@ use std::sync::Arc; use diskann::{graph::DiskANNIndex, provider::DataProvider}; use diskann_benchmark_runner::utils::fmt::{Delimit, Quote}; +use diskann_providers::post_processor::DeterminantDiversityParams; use crate::{ backend::index::result::AggregatedSearchResults, @@ -165,41 +166,6 @@ impl Topk { #[derive(Debug, Clone, Copy)] pub(crate) struct DeterminantDiversity; -/// Parameters for determinant-diversity post-processing. -/// -/// This struct encapsulates the validated parameters needed for determinant-diversity -/// reranking. The parameters are validated at parse time to ensure correct values. -#[derive(Debug, Clone, Copy)] -pub(crate) struct DeterminantDiversityParams { - /// The power parameter controlling relevance vs. diversity trade-off. - /// - /// Higher values prefer relevance over diversity. - /// Must be > 0.0. - pub power: f32, - /// The ridge regularization parameter for numerical stability. - /// - /// Higher values provide more numerical robustness but bias toward relevance. - /// Must be >= 0.0. - pub eta: f32, -} - -impl DeterminantDiversityParams { - /// Create new determinant-diversity parameters with validation. - /// - /// # Errors - /// - /// Returns an error if `power <= 0.0` or `eta < 0.0`. - pub(crate) fn new(power: f32, eta: f32) -> anyhow::Result { - if power <= 0.0 { - anyhow::bail!("determinant-diversity power must be > 0.0, got: {}", power); - } - if eta < 0.0 { - anyhow::bail!("determinant-diversity eta must be >= 0.0, got: {}", eta); - } - Ok(Self { power, eta }) - } -} - impl DeterminantDiversity { pub(crate) fn is_match(phase: &SearchPhase) -> bool { phase @@ -217,7 +183,8 @@ impl DeterminantDiversity { let topk = phase.as_topk()?; match topk.post_processor.as_ref() { Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => { - let params = DeterminantDiversityParams::new(*power, *eta)?; + let params = DeterminantDiversityParams::new(*power, *eta) + .map_err(|e| anyhow::anyhow!("{}", e))?; Ok((topk, params)) } _ => Err(anyhow::anyhow!( diff --git a/diskann-providers/src/lib.rs b/diskann-providers/src/lib.rs index 0edeb2625..8b0aa43ba 100644 --- a/diskann-providers/src/lib.rs +++ b/diskann-providers/src/lib.rs @@ -14,6 +14,8 @@ pub mod model; pub mod common; +pub mod post_processor; + pub mod index; pub mod storage; diff --git a/diskann-providers/src/post_processor.rs b/diskann-providers/src/post_processor.rs new file mode 100644 index 000000000..db593c87e --- /dev/null +++ b/diskann-providers/src/post_processor.rs @@ -0,0 +1,138 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Unified post-processor parameter types with validation. +//! +//! This module provides centralized definitions and validation for post-processor +//! parameters like Determinant-Diversity, ensuring consistent validation across +//! different search contexts (in-memory, disk, benchmarking). + +use std::fmt; + +/// Parameters for Determinant-Diversity post-processor with validation. +/// +/// Determinant-Diversity is a diversity-promoting reranking algorithm that takes +/// relevance-ranked neighbors and reorders them to maximize geometric diversity +/// while maintaining relevance. +/// +/// # Parameters +/// +/// - `power`: Relevance weighting exponent. Controls the emphasis on maintaining +/// relevance scores from the original search. Must be > 0.0. +/// +/// - `eta`: Numerical stability parameter for ridge-regularization. Controls the +/// trade-off between exact determinant computation (eta=0) and numerical robustness +/// (eta>0). Must be >= 0.0. +/// +/// # Errors +/// +/// Construction fails if: +/// - `power <= 0.0` (invalid power weighting) +/// - `eta < 0.0` (negative stability parameter) +#[derive(Debug, Clone, Copy)] +pub struct DeterminantDiversityParams { + /// Relevance weighting exponent. Must be > 0.0. + pub power: f32, + /// Numerical stability parameter. Must be >= 0.0. + pub eta: f32, +} + +impl DeterminantDiversityParams { + /// Create and validate new Determinant-Diversity parameters. + /// + /// # Errors + /// + /// Returns an error if validation fails: + /// - `power <= 0.0`: invalid relevance weighting + /// - `eta < 0.0`: invalid numerical stability parameter + pub fn new(power: f32, eta: f32) -> Result { + if power <= 0.0 { + return Err(DeterminantDiversityError::InvalidPower(power)); + } + if eta < 0.0 { + return Err(DeterminantDiversityError::InvalidEta(eta)); + } + Ok(Self { power, eta }) + } + + /// Get power parameter. + #[inline] + pub fn power(&self) -> f32 { + self.power + } + + /// Get eta parameter. + #[inline] + pub fn eta(&self) -> f32 { + self.eta + } +} + +impl fmt::Display for DeterminantDiversityParams { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "DeterminantDiversity(power={}, eta={})", + self.power, self.eta + ) + } +} + +/// Validation error for Determinant-Diversity parameters. +#[derive(Debug, Clone)] +pub enum DeterminantDiversityError { + /// Power parameter <= 0.0 + InvalidPower(f32), + /// Eta parameter < 0.0 + InvalidEta(f32), +} + +impl fmt::Display for DeterminantDiversityError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidPower(p) => write!( + f, + "determinant-diversity power must be > 0.0, got: {}", + p + ), + Self::InvalidEta(e) => write!( + f, + "determinant-diversity eta must be >= 0.0, got: {}", + e + ), + } + } +} + +impl std::error::Error for DeterminantDiversityError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valid_params() { + assert!(DeterminantDiversityParams::new(1.0, 0.0).is_ok()); + assert!(DeterminantDiversityParams::new(0.5, 1.5).is_ok()); + assert!(DeterminantDiversityParams::new(2.0, 0.1).is_ok()); + } + + #[test] + fn test_invalid_power() { + assert!(DeterminantDiversityParams::new(0.0, 1.0).is_err()); + assert!(DeterminantDiversityParams::new(-1.0, 1.0).is_err()); + } + + #[test] + fn test_invalid_eta() { + assert!(DeterminantDiversityParams::new(1.0, -0.1).is_err()); + } + + #[test] + fn test_display() { + let params = DeterminantDiversityParams::new(1.5, 0.5).unwrap(); + assert_eq!(params.to_string(), "DeterminantDiversity(power=1.5, eta=0.5)"); + } +} From 8ce2130365162f0dd288c182c3d663136572b79d Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 13:13:47 +0530 Subject: [PATCH 28/38] Task 6: Add module-level documentation to determinant_diversity_post_process.rs - Document algorithm overview, parameters (power, eta), variants, time complexity - Clarify distinction between eta=0 (exact) and eta>0 (ridge-regularized) paths - Reference asymptotic complexity O(m^3) due to determinant computation --- .../determinant_diversity_post_process.rs | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index d1793bad6..3b2d8c914 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -3,6 +3,49 @@ * Licensed under the MIT license. */ +//! Determinant-Diversity post-processing for search results. +//! +//! This module implements the Determinant-Diversity algorithm for diversity-promoting +//! reranking of approximate nearest neighbor search results. The algorithm takes +//! relevance-ranked candidates and reorders them to maximize geometric diversity +//! while maintaining relevance to the original query. +//! +//! # Algorithm Overview +//! +//! Determinant-Diversity selects a diverse subset from an initial set of candidates +//! by iteratively choosing points that maximize the determinant of the distance matrix. +//! This creates a diverse set that is both relevant to the query and geometrically spread out. +//! +//! # Parameters +//! +//! - **power**: Relevance weighting exponent (must be > 0.0). Controls the emphasis on +//! maintaining relevance scores from the initial search. Higher values prefer relevance +//! over diversity. +//! +//! - **eta**: Numerical stability parameter (must be >= 0.0). Used for ridge regularization: +//! - `eta = 0`: Exact determinant computation (can be numerically unstable for some inputs) +//! - `eta > 0`: Ridge-regularized computation for improved numerical stability +//! +//! # Variants +//! +//! The module provides two implementations: +//! +//! - `post_process_with_eta_f32()`: Uses ridge regularization for numerical stability +//! - `post_process_without_eta_f32()`: Computes exact determinants (faster but less stable) +//! +//! These are selected automatically based on the eta parameter value. +//! +//! # Time Complexity +//! +//! O(m³) where m is the number of candidates, due to determinant computation. +//! In practice, m is typically small (search returns hundreds of candidates, +//! but only top-k ≪ m are selected). +//! +//! # References +//! +//! The algorithm is based on diversity-promoting ranking methods for nearest neighbor search, +//! as used in approximate nearest neighbor indices like DiskANN. + use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; pub fn determinant_diversity_post_process( From 309ecb3bd490a6fbc20b6feec64a0cf17b101234 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 13:18:01 +0530 Subject: [PATCH 29/38] Task 7: Add algorithmic tests to determinant_diversity_post_process.rs - test_diversity_selects_orthogonal_candidates: Verify orthogonal pair chosen over parallel - test_diversity_selects_orthogonal_candidates_with_eta: Same test for eta>0 variant - test_high_power_prefers_closer_candidates: Verify relevance weighting with high power - test_equal_distances: Verify stable behavior with equal-distance candidates - test_eta_zero_is_greedy_path: Confirm eta=0.0 routes to greedy orthogonalization Note: Diversity tests use equal distances to eliminate relevance weighting interference, making them pure tests of the geometric diversity property. --- .../determinant_diversity_post_process.rs | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index 3b2d8c914..3953df782 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -399,4 +399,90 @@ mod tests { // Verify that distances are preserved from input assert!(result.iter().all(|(_, dist)| *dist == 0.5 || *dist == 0.3)); } + + /// Verify that diversity is actually promoted: when candidates lie along orthogonal + /// directions, a 2-element diverse subset should choose orthogonal pairs over similar ones. + /// + /// Using equal distances ensures pure diversity drives selection without relevance weighting. + #[test] + fn test_diversity_selects_orthogonal_candidates() { + // Three candidates with equal distance: two very similar (nearly parallel) and one orthogonal. + // Equal distances remove relevance weighting, so pure diversity drives selection. + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0, 0.0]), // along x + (1u32, 0.1, vec![0.0, 1.0, 0.0]), // along y - orthogonal to 0 + (2u32, 0.1, vec![0.99, 0.01, 0.0]), // nearly parallel to 0 + ]; + let query = &[1.0, 1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + + // Should select 2 candidates + assert_eq!(result.len(), 2); + // The diverse pair is (0, 1) - orthogonal. Candidate 2 is redundant with 0. + let ids: Vec = result.iter().map(|(id, _)| *id).collect(); + assert!(ids.contains(&0), "Expected candidate 0 to be selected"); + assert!(ids.contains(&1), "Expected candidate 1 (orthogonal) to be selected, not redundant candidate 2"); + } + + /// Verify eta variant selects the same k results. + #[test] + fn test_diversity_selects_orthogonal_candidates_with_eta() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0, 0.0]), + (1u32, 0.1, vec![0.0, 1.0, 0.0]), + (2u32, 0.1, vec![0.99, 0.01, 0.0]), + ]; + let query = &[1.0, 1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.5, 1.0); + + assert_eq!(result.len(), 2); + let ids: Vec = result.iter().map(|(id, _)| *id).collect(); + assert!(ids.contains(&0), "Expected candidate 0 to be selected"); + assert!(ids.contains(&1), "Expected candidate 1 (orthogonal) to be selected"); + } + + /// Verify power=high weights nearby candidates (distance=0.1) more strongly than far ones. + #[test] + fn test_high_power_prefers_closer_candidates() { + // Two orthogonal candidates: one close, one far + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), // close to query + (1u32, 0.9, vec![0.0, 1.0]), // far from query + ]; + let query = &[1.0, 0.0]; + + // With high power, relevance is heavily weighted so the closest candidate dominates + let result = determinant_diversity_post_process(candidates.clone(), query, 1, 0.0, 10.0); + assert_eq!(result.len(), 1); + // Closest candidate should be preferred due to high power weighting + assert_eq!(result[0].0, 0, "Closest candidate should be selected with high power"); + } + + /// Verify that distance-to-similarity conversion handles equal distances gracefully. + #[test] + fn test_equal_distances() { + let candidates = vec![ + (0u32, 0.5, vec![1.0, 0.0]), + (1u32, 0.5, vec![0.0, 1.0]), // same distance as 0 + ]; + let query = &[1.0, 0.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + + // Should still return candidates without panicking + assert_eq!(result.len(), 2); + } + + /// Test eta=0 exactly matches greedy orthogonalization path. + #[test] + fn test_eta_zero_is_greedy_path() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.0, 1.0]), + (2u32, 0.3, vec![0.5, 0.5]), + ]; + let query = &[1.0, 1.0]; + // eta=0.0 must invoke greedy path, not ridge-regularized + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + assert_eq!(result.len(), 2); + } } From 3be546afd103a49951430f9d51478bb86b352638 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 13:20:10 +0530 Subject: [PATCH 30/38] Task 8: Merge similar routines in determinant_diversity_post_process.rs - Remove duplicate post_process_with_eta_f32 and post_process_greedy_orthogonalization_f32 - Unify into single greedy_orthogonal_select() with inv_sqrt_eta parameter - inv_sqrt_eta=1.0/sqrt(eta) for ridge-regularized path (eta>0) - inv_sqrt_eta=1.0 for exact greedy path (eta=0) - Eliminates ~80 lines of near-duplicate code - All 14 existing tests continue to pass --- .../determinant_diversity_post_process.rs | 133 ++++-------------- 1 file changed, 27 insertions(+), 106 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index 3953df782..50553b328 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -89,29 +89,39 @@ pub fn determinant_diversity_post_process( (min_distance, max_distance) }; - if determinant_diversity_eta > 0.0 { - post_process_with_eta_f32( - candidates, - k, - determinant_diversity_eta, - determinant_diversity_power, - distance_range, - ) + // For eta=0, the inv_sqrt_eta factor is 1.0 (greedy orthogonalization without regularization). + // For eta>0, the factor scales residuals for ridge-regularized determinant computation. + let inv_sqrt_eta = if determinant_diversity_eta > 0.0 { + 1.0 / determinant_diversity_eta.sqrt() } else { - post_process_greedy_orthogonalization_f32( - candidates, - k, - determinant_diversity_power, - distance_range, - ) - } + 1.0 + }; + + greedy_orthogonal_select( + candidates, + k, + determinant_diversity_power, + inv_sqrt_eta, + distance_range, + ) } -fn post_process_with_eta_f32( +/// Core greedy selection algorithm for Determinant-Diversity. +/// +/// Iteratively selects the candidate with the largest residual norm after projecting +/// out previously selected candidates. The `inv_sqrt_eta` parameter controls the +/// ridge-regularization scaling: +/// +/// - `inv_sqrt_eta = 1.0`: exact greedy orthogonalization (eta=0 case) +/// - `inv_sqrt_eta = 1/sqrt(eta)`: ridge-regularized variant for numerical stability +/// +/// This unified implementation replaces two nearly-identical functions that only +/// differed in whether the scale factor included the eta term. +fn greedy_orthogonal_select( candidates: Vec<(Id, f32, Vec)>, k: usize, - eta: f32, power: f32, + inv_sqrt_eta: f32, distance_range: (f32, f32), ) -> Vec<(Id, f32)> { let n = candidates.len(); @@ -120,7 +130,6 @@ fn post_process_with_eta_f32( return Vec::new(); } - let inv_sqrt_eta = 1.0 / eta.sqrt(); let mut residuals = Vec::with_capacity(n); let mut norms_sq = Vec::with_capacity(n); @@ -199,94 +208,6 @@ fn post_process_with_eta_f32( .collect() } -fn post_process_greedy_orthogonalization_f32( - candidates: Vec<(Id, f32, Vec)>, - k: usize, - power: f32, - distance_range: (f32, f32), -) -> Vec<(Id, f32)> { - let n = candidates.len(); - let k = k.min(n); - if k == 0 { - return Vec::new(); - } - - let mut residuals = Vec::with_capacity(n); - let mut norms_sq = Vec::with_capacity(n); - - for (_, distance_to_query, v) in &candidates { - let scale = distance_to_similarity(*distance_to_query, distance_range).powf(power); - let residual: Vec = v.iter().map(|&x| x * scale).collect(); - let norm_sq = dot_product(&residual, &residual); - residuals.push(residual); - norms_sq.push(norm_sq); - } - - let mut available = vec![true; n]; - let mut selected = Vec::with_capacity(k); - let mut projections = vec![0.0f32; n]; - - for _ in 0..k { - let best = available - .iter() - .enumerate() - .filter(|&(_, &avail)| avail) - .max_by(|(i, _), (j, _)| { - norms_sq[*i] - .partial_cmp(&norms_sq[*j]) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - let Some((best_index, _)) = best else { - break; - }; - - let best_norm_sq = norms_sq[best_index]; - selected.push(best_index); - available[best_index] = false; - - if selected.len() == k { - break; - } - - if best_norm_sq <= 0.0 { - continue; - } - - let inv_norm_sq_star = 1.0 / best_norm_sq; - let r_star_copy = residuals[best_index].clone(); - - for j in 0..n { - if !available[j] { - projections[j] = 0.0; - } else { - projections[j] = dot_product(&residuals[j], &r_star_copy) * inv_norm_sq_star; - } - } - - for j in 0..n { - if !available[j] { - continue; - } - - let projection = projections[j]; - for (residual, &star) in residuals[j].iter_mut().zip(r_star_copy.iter()) { - *residual -= projection * star; - } - - norms_sq[j] = (norms_sq[j] - projection * projection * best_norm_sq).max(0.0); - } - } - - selected - .iter() - .map(|&idx| { - let (id, dist, _) = &candidates[idx]; - (*id, *dist) - }) - .collect() -} - fn distance_to_similarity(distance: f32, distance_range: (f32, f32)) -> f32 { let (min_distance, max_distance) = distance_range; let span = (max_distance - min_distance).max(f32::EPSILON); From f92c3bed040d31316946e194595df2db573c75cb Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 13:23:00 +0530 Subject: [PATCH 31/38] Task 9: Replace Vec> with Matrix for residuals storage - Use contiguous Matrix allocation instead of separate Vec per candidate - Reduces heap allocations from O(n) to O(1) where n = num candidates - Improves cache locality during orthogonalization iteration - Access residuals via row(i) / row_mut(i) instead of indexing Vec of Vecs - All 14 tests continue to pass --- .../determinant_diversity_post_process.rs | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index 50553b328..8a6e2eb58 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -46,6 +46,7 @@ //! The algorithm is based on diversity-promoting ranking methods for nearest neighbor search, //! as used in approximate nearest neighbor indices like DiskANN. +use diskann_utils::views::Matrix; use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; pub fn determinant_diversity_post_process( @@ -130,15 +131,22 @@ fn greedy_orthogonal_select( return Vec::new(); } - let mut residuals = Vec::with_capacity(n); + let dim = candidates[0].2.len(); + + // Use a contiguous Matrix allocation for residuals instead of Vec>. + // This reduces the number of heap allocations from O(n) to O(1) and improves + // cache locality when accessing residuals sequentially during orthogonalization. + let mut residuals = Matrix::new(0.0f32, n, dim); let mut norms_sq = Vec::with_capacity(n); - for (_, distance_to_query, v) in &candidates { + for (i, (_, distance_to_query, v)) in candidates.iter().enumerate() { let scale = distance_to_similarity(*distance_to_query, distance_range).powf(power) * inv_sqrt_eta; - let residual: Vec = v.iter().map(|&x| x * scale).collect(); - let norm_sq = dot_product(&residual, &residual); - residuals.push(residual); + let row = residuals.row_mut(i); + for (r, &x) in row.iter_mut().zip(v.iter()) { + *r = x * scale; + } + let norm_sq = dot_product(residuals.row(i), residuals.row(i)); norms_sq.push(norm_sq); } @@ -175,13 +183,14 @@ fn greedy_orthogonal_select( } let inv_norm_sq = 1.0 / best_norm_sq; - let r_star_copy = residuals[selected_index].clone(); + // Clone selected row before mutable iteration over remaining rows. + let r_star_copy: Vec = residuals.row(selected_index).to_vec(); for i in 0..n { if !available[i] { projections[i] = 0.0; } else { - projections[i] = dot_product(&residuals[i], &r_star_copy) * inv_norm_sq; + projections[i] = dot_product(residuals.row(i), &r_star_copy) * inv_norm_sq; } } @@ -191,7 +200,7 @@ fn greedy_orthogonal_select( } let projection = projections[i]; - for (residual, &star) in residuals[i].iter_mut().zip(r_star_copy.iter()) { + for (residual, &star) in residuals.row_mut(i).iter_mut().zip(r_star_copy.iter()) { *residual -= projection * star; } From 10b0182a6dac677caee228135dd409abfffd5450 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 13:32:36 +0530 Subject: [PATCH 32/38] Task 10: Move determinant_diversity_post_process out of async_ module - Move implementation to model/graph/provider/determinant_diversity.rs - Export via provider::mod.rs as determinant_diversity_post_process - Keep backward-compatible re-export from async_::mod.rs - Remove old async_-scoped source file - Verify build and full diskann-providers tests pass --- diskann-providers/src/model/graph/provider/async_/mod.rs | 8 ++++---- ...diversity_post_process.rs => determinant_diversity.rs} | 0 diskann-providers/src/model/graph/provider/mod.rs | 7 +++++++ 3 files changed, 11 insertions(+), 4 deletions(-) rename diskann-providers/src/model/graph/provider/{async_/determinant_diversity_post_process.rs => determinant_diversity.rs} (100%) diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index a0bfb3010..25cd95ccb 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -7,12 +7,12 @@ pub mod experimental; pub mod common; pub use common::{PrefetchCacheLineLevel, StartPoints, VectorGuard}; -mod determinant_diversity_post_process; -pub(crate) mod postprocess; -pub use determinant_diversity_post_process::determinant_diversity_post_process; +pub(crate) mod postprocess; +// Re-export from parent module for backward compatibility. +// The algorithm is not async-specific and lives in provider::determinant_diversity. pub mod distances; - +pub use super::determinant_diversity_post_process; pub mod memory_vector_provider; pub use memory_vector_provider::MemoryVectorProviderAsync; diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/determinant_diversity.rs similarity index 100% rename from diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs rename to diskann-providers/src/model/graph/provider/determinant_diversity.rs diff --git a/diskann-providers/src/model/graph/provider/mod.rs b/diskann-providers/src/model/graph/provider/mod.rs index 0e045bfb5..f0ac174dd 100644 --- a/diskann-providers/src/model/graph/provider/mod.rs +++ b/diskann-providers/src/model/graph/provider/mod.rs @@ -6,3 +6,10 @@ pub mod async_; // Layers for the async index. pub mod layers; + +/// Determinant-diversity post-processing algorithm. +/// +/// This module is not async-specific and is re-exported here for clarity. +/// It provides diversity-promoting reranking for nearest neighbor search results. +pub mod determinant_diversity; +pub use determinant_diversity::determinant_diversity_post_process; From ca20a24acf20e00b78e43fd8e263a9e7f02bc16b Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 14:24:05 +0530 Subject: [PATCH 33/38] Refactor determinant-diversity benchmark path --- .../src/search/graph/knn.rs | 7 + .../src/backend/index/benchmarks.rs | 220 +++++++----------- .../src/backend/index/search/knn.rs | 42 +--- diskann-benchmark/src/utils/tokio.rs | 6 - 4 files changed, 95 insertions(+), 180 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 19983431d..089c9b12c 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -141,6 +141,13 @@ pub struct Metrics { pub hops: u32, } +impl Metrics { + /// Construct a new metrics value. + pub fn new(comparisons: u32, hops: u32) -> Self { + Self { comparisons, hops } + } +} + impl Search for KNN where DP: provider::DataProvider, diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 75e14e4a4..6a4c65ee9 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use std::{io::Write, num::NonZeroUsize, sync::Arc, time::Instant}; +use std::{io::Write, num::NonZeroUsize, sync::Arc}; use diskann::{ graph::SampleableForStart, @@ -18,7 +18,7 @@ use diskann_benchmark_core::{ use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, - utils::{datatype, MicroSeconds}, + utils::datatype, Benchmark, Checkpoint, }; use diskann_providers::{ @@ -42,7 +42,7 @@ use super::{ use crate::{ backend::index::{ post_processor, - result::{AggregatedSearchResults, BuildResult, SearchResults}, + result::{AggregatedSearchResults, BuildResult}, search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, @@ -451,133 +451,77 @@ impl Strategy { // Topk // //------// -/// Execute a topk search with a custom per-query search function, collecting timing and recall. -/// -/// This function properly creates tokio runtimes with the specified thread counts to accurately -/// measure performance under different concurrency levels. Each thread count is benchmarked -/// independently with its own runtime to avoid misleading single-threaded measurements. -/// -/// Encapsulates the thread/run/L loops, the per-rep timing harness, per-query latency -/// collection, and [`SearchResults`] construction. The caller supplies only the actual -/// per-query search as a closure `search_fn(knn_params, query, output) -> Result<()>`. -fn run_topk_timed( - topk: &crate::inputs::graph_index::TopkSearchPhase, - queries: &Matrix, - groundtruth: &Matrix, - search_fn: impl Fn( - diskann::graph::search::Knn, - &[f32], - &mut Vec>, - ) -> anyhow::Result<()>, -) -> anyhow::Result> { - let mut all_results = Vec::new(); - - for threads in &topk.num_threads { - // Create a tokio runtime with the specified number of threads. - // This ensures that searches are actually executed with the desired concurrency level, - // not just recorded with that thread count while running serially. - let rt = utils::tokio::runtime(threads.get())?; - - for run in &topk.runs { - for search_l in &run.search_l { - let knn_params = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); - - let mut all_recalls = Vec::new(); - let mut qps = Vec::with_capacity(topk.reps.get()); - let mut search_latencies = Vec::with_capacity(topk.reps.get()); - let mut mean_latencies = Vec::with_capacity(topk.reps.get()); - let mut p90_latencies = Vec::with_capacity(topk.reps.get()); - let mut p99_latencies = Vec::with_capacity(topk.reps.get()); - - for _ in 0..topk.reps.get() { - let search_start = Instant::now(); - let mut per_query_latencies = Vec::with_capacity(queries.nrows()); - - rt.block_on(async { - for query_idx in 0..queries.nrows() { - let query = queries.row(query_idx); - let mut output: Vec> = Vec::new(); - - let query_start = Instant::now(); - search_fn(knn_params, query, &mut output)?; - per_query_latencies.push(query_start.elapsed().as_micros() as u64); - - let gt = groundtruth.row(query_idx); - let mut matches = 0; - for (i, neighbor) in output.iter().take(run.recall_k).enumerate() { - if i >= gt.len() { - break; - } - if gt.contains(&neighbor.id) { - matches += 1; - } - } - all_recalls.push(matches); - } - Ok::<(), anyhow::Error>(()) - })?; - - let elapsed: MicroSeconds = search_start.elapsed().into(); - let elapsed_secs = elapsed.as_seconds(); - qps.push(if elapsed_secs > 0.0 { - queries.nrows() as f64 / elapsed_secs - } else { - 0.0 - }); - - per_query_latencies.sort_unstable(); - let len = per_query_latencies.len(); - let p90_idx = ((len as f64 * 0.90).ceil() as usize) - .saturating_sub(1) - .min(len.saturating_sub(1)); - let p99_idx = ((len as f64 * 0.99).ceil() as usize) - .saturating_sub(1) - .min(len.saturating_sub(1)); - let mean = if len > 0 { - per_query_latencies.iter().sum::() as f64 / len as f64 - } else { - 0.0 - }; - - search_latencies.push(elapsed); - mean_latencies.push(mean); - p90_latencies.push(MicroSeconds::new( - *per_query_latencies.get(p90_idx).unwrap_or(&0), - )); - p99_latencies.push(MicroSeconds::new( - *per_query_latencies.get(p99_idx).unwrap_or(&0), - )); - } +struct DeterminantDiversityKnn { + index: Arc>>, + queries: Arc>, + strategy: benchmark_core::search::graph::Strategy, + post_processor: post_processor::DeterminantDiversity, +} - let avg_recall = all_recalls.iter().sum::() as f32 - / (queries.nrows() * run.recall_k * topk.reps.get()) as f32; - - all_results.push(SearchResults { - num_tasks: threads.get(), - search_n: run.search_n, - search_l: *search_l, - qps, - search_latencies, - mean_latencies, - p90_latencies, - p99_latencies, - recall: utils::recall::RecallMetrics { - recall_k: run.recall_k, - recall_n: run.search_n, - num_queries: queries.nrows(), - average: avg_recall as f64, - minimum: *all_recalls.iter().min().unwrap_or(&0), - maximum: *all_recalls.iter().max().unwrap_or(&0), - }, - mean_cmps: 0.0, - mean_hops: 0.0, - }); - } - } +impl DeterminantDiversityKnn { + fn new( + index: Arc>>, + queries: Arc>, + strategy: benchmark_core::search::graph::Strategy, + post_processor: post_processor::DeterminantDiversity, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + Ok(Arc::new(Self { + index, + queries, + strategy, + post_processor, + })) + } +} + +impl benchmark_core::search::Search for DeterminantDiversityKnn +where + common::FullPrecision: for<'a, 'b> glue::SearchStrategy< + FullPrecisionProvider, + &'a [f32], + SearchAccessor<'b>: post_processor::determinant_diversity::FullPrecisionVectorAccessor, + >, +{ + type Id = u32; + type Parameters = diskann::graph::search::Knn; + type Output = benchmark_core::search::graph::knn::Metrics; + + fn num_queries(&self) -> usize { + self.queries.nrows() + } + + fn id_count(&self, parameters: &Self::Parameters) -> benchmark_core::search::IdCount { + benchmark_core::search::IdCount::Fixed(parameters.k_value()) } - Ok(all_results) + async fn search( + &self, + parameters: &Self::Parameters, + buffer: &mut O, + index: usize, + ) -> diskann::ANNResult + where + O: diskann::graph::SearchOutputBuffer + Send, + { + let context = DefaultContext; + let stats = self + .index + .search_with( + *parameters, + self.strategy.get(index)?, + self.post_processor, + &context, + self.queries.row(index), + buffer, + ) + .await?; + + Ok(benchmark_core::search::graph::knn::Metrics::new( + stats.cmps, + stats.hops, + )) + } } impl search::Plugin, SearchPhase, Strategy> @@ -605,24 +549,20 @@ where ) -> anyhow::Result { let (topk, params) = plugins::DeterminantDiversity::get(phase)?; - let strategy = common::FullPrecision; - let context = DefaultContext; - let det_div = post_processor::DeterminantDiversity::new(params.power, params.eta); - let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile( &topk.queries, ))?); let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; - let results = run_topk_timed(topk, &queries, &groundtruth, |params, query, output| { - utils::tokio::block_on(async { - index - .search_with(params, &strategy, det_div, &context, query, output) - .await - }) - .map(|_| ()) - .map_err(anyhow::Error::from) - })?; + let knn = DeterminantDiversityKnn::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(common::FullPrecision), + post_processor::DeterminantDiversity::new(params.power, params.eta), + )?; + + let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); + let results = search::knn::run(&knn, &groundtruth, steps)?; Ok(AggregatedSearchResults::Topk(results)) } diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index b50e69010..a4485933c 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -78,45 +78,19 @@ pub(crate) trait Knn { // Impls // /////////// -impl Knn for Arc> +impl Knn for Arc where - DP: diskann::provider::DataProvider, - core_search::graph::KNN: core_search::Search< - Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, - Output = core_search::graph::knn::Metrics, - >, + I: benchmark_core::recall::RecallCompatible, + R: core_search::Search< + Id = I, + Parameters = diskann::graph::search::Knn, + Output = core_search::graph::knn::Metrics, + > + 'static, { fn search_all( &self, parameters: Vec>, - groundtruth: &dyn benchmark_core::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> anyhow::Result> { - let results = core_search::search_all( - self.clone(), - parameters.into_iter(), - core_search::graph::knn::Aggregator::new(groundtruth, recall_k, recall_n), - )?; - - Ok(results.into_iter().map(SearchResults::new).collect()) - } -} - -impl Knn for Arc> -where - DP: diskann::provider::DataProvider, - core_search::graph::MultiHop: core_search::Search< - Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, - Output = core_search::graph::knn::Metrics, - >, -{ - fn search_all( - &self, - parameters: Vec>, - groundtruth: &dyn benchmark_core::recall::Rows, + groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, ) -> anyhow::Result> { diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index cd07e593a..f50d232af 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -12,9 +12,3 @@ pub(crate) fn block_on(future: F) -> F::Output { .block_on(future) } -/// Create a multi-threaded runtime with the specified number of threads. -pub(crate) fn runtime(num_threads: usize) -> anyhow::Result { - Ok(tokio::runtime::Builder::new_multi_thread() - .worker_threads(num_threads) - .build()?) -} From f538583a16fea65024a6dcf6efcc53a383c5b685 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Wed, 13 May 2026 15:19:44 +0530 Subject: [PATCH 34/38] cargo fmt and clippy fixes for CI --- .../src/backend/index/benchmarks.rs | 7 ++----- .../src/backend/index/search/plugins.rs | 4 +++- diskann-benchmark/src/utils/tokio.rs | 1 - .../src/model/graph/provider/async_/mod.rs | 1 - .../graph/provider/determinant_diversity.rs | 19 ++++++++++++----- diskann-providers/src/post_processor.rs | 21 +++++++++---------- 6 files changed, 29 insertions(+), 24 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 6a4c65ee9..63170e446 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -46,9 +46,7 @@ use crate::{ search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, }, - inputs::graph_index::{ - DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase, - }, + inputs::graph_index::{DynamicIndexRun, IndexBuild, IndexOperation, IndexSource, SearchPhase}, utils::{ self, datafiles::{self}, @@ -518,8 +516,7 @@ where .await?; Ok(benchmark_core::search::graph::knn::Metrics::new( - stats.cmps, - stats.hops, + stats.cmps, stats.hops, )) } } diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 2b2d50a47..bb7a9a7ae 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -179,7 +179,9 @@ impl DeterminantDiversity { "topk + determinant-diversity" } - pub(crate) fn get(phase: &SearchPhase) -> anyhow::Result<(&TopkSearchPhase, DeterminantDiversityParams)> { + pub(crate) fn get( + phase: &SearchPhase, + ) -> anyhow::Result<(&TopkSearchPhase, DeterminantDiversityParams)> { let topk = phase.as_topk()?; match topk.post_processor.as_ref() { Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => { diff --git a/diskann-benchmark/src/utils/tokio.rs b/diskann-benchmark/src/utils/tokio.rs index f50d232af..72dbeb918 100644 --- a/diskann-benchmark/src/utils/tokio.rs +++ b/diskann-benchmark/src/utils/tokio.rs @@ -11,4 +11,3 @@ pub(crate) fn block_on(future: F) -> F::Output { .expect("current thread runtime initialization failed") .block_on(future) } - diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 25cd95ccb..5ad549563 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -7,7 +7,6 @@ pub mod experimental; pub mod common; pub use common::{PrefetchCacheLineLevel, StartPoints, VectorGuard}; - pub(crate) mod postprocess; // Re-export from parent module for backward compatibility. // The algorithm is not async-specific and lives in provider::determinant_diversity. diff --git a/diskann-providers/src/model/graph/provider/determinant_diversity.rs b/diskann-providers/src/model/graph/provider/determinant_diversity.rs index 8a6e2eb58..6563bc00e 100644 --- a/diskann-providers/src/model/graph/provider/determinant_diversity.rs +++ b/diskann-providers/src/model/graph/provider/determinant_diversity.rs @@ -339,8 +339,8 @@ mod tests { // Three candidates with equal distance: two very similar (nearly parallel) and one orthogonal. // Equal distances remove relevance weighting, so pure diversity drives selection. let candidates = vec![ - (0u32, 0.1, vec![1.0, 0.0, 0.0]), // along x - (1u32, 0.1, vec![0.0, 1.0, 0.0]), // along y - orthogonal to 0 + (0u32, 0.1, vec![1.0, 0.0, 0.0]), // along x + (1u32, 0.1, vec![0.0, 1.0, 0.0]), // along y - orthogonal to 0 (2u32, 0.1, vec![0.99, 0.01, 0.0]), // nearly parallel to 0 ]; let query = &[1.0, 1.0, 1.0]; @@ -351,7 +351,10 @@ mod tests { // The diverse pair is (0, 1) - orthogonal. Candidate 2 is redundant with 0. let ids: Vec = result.iter().map(|(id, _)| *id).collect(); assert!(ids.contains(&0), "Expected candidate 0 to be selected"); - assert!(ids.contains(&1), "Expected candidate 1 (orthogonal) to be selected, not redundant candidate 2"); + assert!( + ids.contains(&1), + "Expected candidate 1 (orthogonal) to be selected, not redundant candidate 2" + ); } /// Verify eta variant selects the same k results. @@ -368,7 +371,10 @@ mod tests { assert_eq!(result.len(), 2); let ids: Vec = result.iter().map(|(id, _)| *id).collect(); assert!(ids.contains(&0), "Expected candidate 0 to be selected"); - assert!(ids.contains(&1), "Expected candidate 1 (orthogonal) to be selected"); + assert!( + ids.contains(&1), + "Expected candidate 1 (orthogonal) to be selected" + ); } /// Verify power=high weights nearby candidates (distance=0.1) more strongly than far ones. @@ -385,7 +391,10 @@ mod tests { let result = determinant_diversity_post_process(candidates.clone(), query, 1, 0.0, 10.0); assert_eq!(result.len(), 1); // Closest candidate should be preferred due to high power weighting - assert_eq!(result[0].0, 0, "Closest candidate should be selected with high power"); + assert_eq!( + result[0].0, 0, + "Closest candidate should be selected with high power" + ); } /// Verify that distance-to-similarity conversion handles equal distances gracefully. diff --git a/diskann-providers/src/post_processor.rs b/diskann-providers/src/post_processor.rs index db593c87e..07ecd9b39 100644 --- a/diskann-providers/src/post_processor.rs +++ b/diskann-providers/src/post_processor.rs @@ -92,16 +92,12 @@ pub enum DeterminantDiversityError { impl fmt::Display for DeterminantDiversityError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::InvalidPower(p) => write!( - f, - "determinant-diversity power must be > 0.0, got: {}", - p - ), - Self::InvalidEta(e) => write!( - f, - "determinant-diversity eta must be >= 0.0, got: {}", - e - ), + Self::InvalidPower(p) => { + write!(f, "determinant-diversity power must be > 0.0, got: {}", p) + } + Self::InvalidEta(e) => { + write!(f, "determinant-diversity eta must be >= 0.0, got: {}", e) + } } } } @@ -133,6 +129,9 @@ mod tests { #[test] fn test_display() { let params = DeterminantDiversityParams::new(1.5, 0.5).unwrap(); - assert_eq!(params.to_string(), "DeterminantDiversity(power=1.5, eta=0.5)"); + assert_eq!( + params.to_string(), + "DeterminantDiversity(power=1.5, eta=0.5)" + ); } } From f58789d5d47ecae6dcef968f72d0b0e536897aed Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 18 May 2026 20:21:47 +0530 Subject: [PATCH 35/38] Use shared determinant-diversity params validation --- diskann-benchmark/src/inputs/post_processor.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/diskann-benchmark/src/inputs/post_processor.rs b/diskann-benchmark/src/inputs/post_processor.rs index 5ff739321..06d813472 100644 --- a/diskann-benchmark/src/inputs/post_processor.rs +++ b/diskann-benchmark/src/inputs/post_processor.rs @@ -4,6 +4,7 @@ */ use diskann_benchmark_runner::{CheckDeserialization, Checker}; +use diskann_providers::post_processor::DeterminantDiversityParams; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -16,12 +17,8 @@ impl CheckDeserialization for TopkPostProcessor { fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { match self { TopkPostProcessor::DeterminantDiversity { power, eta } => { - if *power <= 0.0 { - anyhow::bail!("determinant-diversity power must be > 0.0, got: {}", power); - } - if *eta < 0.0 { - anyhow::bail!("determinant-diversity eta must be >= 0.0, got: {}", eta); - } + DeterminantDiversityParams::new(*power, *eta) + .map_err(|e| anyhow::anyhow!("{}", e))?; Ok(()) } } From 3c11d36864436663b44f8832b73b5ee3d35ef6f6 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 18 May 2026 20:29:47 +0530 Subject: [PATCH 36/38] code review comment, use a struct instead of a tuple --- .../graph/provider/determinant_diversity.rs | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/determinant_diversity.rs b/diskann-providers/src/model/graph/provider/determinant_diversity.rs index 6563bc00e..9daa2f518 100644 --- a/diskann-providers/src/model/graph/provider/determinant_diversity.rs +++ b/diskann-providers/src/model/graph/provider/determinant_diversity.rs @@ -49,6 +49,12 @@ use diskann_utils::views::Matrix; use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; +#[derive(Clone, Copy)] +struct DistanceRange { + min: f32, + max: f32, +} + pub fn determinant_diversity_post_process( candidates: Vec<(Id, f32, Vec)>, query: &[f32], @@ -87,7 +93,10 @@ pub fn determinant_diversity_post_process( max_distance = max_distance.max(*distance); } - (min_distance, max_distance) + DistanceRange { + min: min_distance, + max: max_distance, + } }; // For eta=0, the inv_sqrt_eta factor is 1.0 (greedy orthogonalization without regularization). @@ -115,15 +124,12 @@ pub fn determinant_diversity_post_process( /// /// - `inv_sqrt_eta = 1.0`: exact greedy orthogonalization (eta=0 case) /// - `inv_sqrt_eta = 1/sqrt(eta)`: ridge-regularized variant for numerical stability -/// -/// This unified implementation replaces two nearly-identical functions that only -/// differed in whether the scale factor included the eta term. fn greedy_orthogonal_select( candidates: Vec<(Id, f32, Vec)>, k: usize, power: f32, inv_sqrt_eta: f32, - distance_range: (f32, f32), + distance_range: DistanceRange, ) -> Vec<(Id, f32)> { let n = candidates.len(); let k = k.min(n); @@ -217,12 +223,11 @@ fn greedy_orthogonal_select( .collect() } -fn distance_to_similarity(distance: f32, distance_range: (f32, f32)) -> f32 { - let (min_distance, max_distance) = distance_range; - let span = (max_distance - min_distance).max(f32::EPSILON); +fn distance_to_similarity(distance: f32, distance_range: DistanceRange) -> f32 { + let span = (distance_range.max - distance_range.min).max(f32::EPSILON); // Distances are lower-is-better in DiskANN distance semantics. - ((max_distance - distance) / span).max(0.0) + f32::EPSILON + ((distance_range.max - distance) / span).max(0.0) + f32::EPSILON } #[inline] From b65e673db223adc5ce9ff06657d04c0321708a12 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 18 May 2026 20:38:08 +0530 Subject: [PATCH 37/38] Refine determinant-diversity invariants and range representation --- .../graph/provider/determinant_diversity.rs | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/determinant_diversity.rs b/diskann-providers/src/model/graph/provider/determinant_diversity.rs index 9daa2f518..111ec7adf 100644 --- a/diskann-providers/src/model/graph/provider/determinant_diversity.rs +++ b/diskann-providers/src/model/graph/provider/determinant_diversity.rs @@ -66,14 +66,12 @@ pub fn determinant_diversity_post_process( return Vec::new(); } - let candidates: Vec<_> = candidates - .into_iter() - .filter(|(_, _, vector)| vector.len() == query.len()) - .collect(); - - if candidates.is_empty() { - return Vec::new(); - } + assert!( + candidates + .iter() + .all(|(_, _, vector)| vector.len() == query.len()), + "all candidate vectors must have the same dimension as query" + ); let k = k.min(candidates.len()); if k == 0 { @@ -255,14 +253,14 @@ mod tests { } #[test] - fn test_mismatched_dimensions() { + #[should_panic(expected = "all candidate vectors must have the same dimension as query")] + fn test_mismatched_dimensions_panics() { let candidates = vec![ (0u32, 0.5, vec![1.0, 2.0]), (1u32, 0.3, vec![1.0]), // Wrong dimension ]; let query = &[1.0, 2.0, 3.0]; - let result = determinant_diversity_post_process(candidates, query, 5, 0.5, 1.0); - assert_eq!(result.len(), 0); // All candidates filtered due to dimension mismatch + let _ = determinant_diversity_post_process(candidates, query, 5, 0.5, 1.0); } #[test] From 85797ce4a048f9fae7d98171a2eae7359be5b410 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 18 May 2026 21:29:12 +0530 Subject: [PATCH 38/38] minor code cleanup --- diskann-benchmark-core/src/search/graph/knn.rs | 10 ---------- diskann-benchmark/src/backend/index/benchmarks.rs | 2 +- diskann-providers/src/post_processor.rs | 4 ++-- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 089c9b12c..321383cb7 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -111,16 +111,6 @@ where &self.index } - /// Access the queries. - pub fn queries(&self) -> &Arc> { - &self.queries - } - - /// Access the strategy. - pub fn strategy(&self) -> &Strategy { - &self.strategy - } - /// Access the post-processor, if present. pub fn post_processor(&self) -> &Option { &self.post_processor diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 8a74e2fee..9c2aa7d5b 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -548,7 +548,7 @@ where index, queries, benchmark_core::search::graph::Strategy::broadcast(common::FullPrecision), - post_processor::DeterminantDiversity::new(params.power, params.eta), + post_processor::DeterminantDiversity::new(params.power(), params.eta()), )?; let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); diff --git a/diskann-providers/src/post_processor.rs b/diskann-providers/src/post_processor.rs index 07ecd9b39..07263aaa5 100644 --- a/diskann-providers/src/post_processor.rs +++ b/diskann-providers/src/post_processor.rs @@ -34,9 +34,9 @@ use std::fmt; #[derive(Debug, Clone, Copy)] pub struct DeterminantDiversityParams { /// Relevance weighting exponent. Must be > 0.0. - pub power: f32, + power: f32, /// Numerical stability parameter. Must be >= 0.0. - pub eta: f32, + eta: f32, } impl DeterminantDiversityParams {