From e116635753526e9feadc7f2c33c402bc29c1aadf Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 9 Mar 2026 19:22:14 +0530 Subject: [PATCH 01/47] Refactor search post-processing API and default processors --- .../src/search/graph/knn.rs | 5 +- .../src/search/graph/multihop.rs | 5 +- .../src/search/graph/range.rs | 5 +- .../src/backend/index/benchmarks.rs | 2 +- .../src/search/provider/disk_provider.rs | 24 ++- .../inline_beta_search/inline_beta_filter.rs | 20 +- diskann-providers/src/index/diskann_async.rs | 15 +- diskann-providers/src/index/wrapped_async.rs | 2 +- .../graph/provider/async_/bf_tree/provider.rs | 61 ++++-- .../graph/provider/async_/caching/provider.rs | 29 ++- .../graph/provider/async_/debug_provider.rs | 44 ++-- .../provider/async_/inmem/full_precision.rs | 37 +++- .../graph/provider/async_/inmem/product.rs | 48 +++-- .../graph/provider/async_/inmem/scalar.rs | 35 +++- .../graph/provider/async_/inmem/spherical.rs | 31 ++- .../model/graph/provider/async_/inmem/test.rs | 12 +- .../model/graph/provider/layers/betafilter.rs | 27 ++- diskann/src/graph/glue.rs | 134 ++++++++++-- diskann/src/graph/index.rs | 16 +- diskann/src/graph/search/diverse_search.rs | 12 +- diskann/src/graph/search/knn_search.rs | 190 ++++++++++++------ diskann/src/graph/search/mod.rs | 2 +- diskann/src/graph/search/multihop_search.rs | 16 +- diskann/src/graph/search/range_search.rs | 14 +- diskann/src/graph/test/provider.rs | 14 +- 25 files changed, 570 insertions(+), 230 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 82cd0cea8..ebc0e3b4b 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -88,7 +88,10 @@ pub struct Metrics { impl Search for KNN where DP: provider::DataProvider, - S: glue::SearchStrategy + Clone + AsyncFriendly, + S: glue::SearchStrategy + + glue::HasDefaultProcessor + + Clone + + AsyncFriendly, T: AsyncFriendly + Clone, { type Id = DP::ExternalId; diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index de62d5241..584c70baf 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -86,7 +86,10 @@ where impl Search for MultiHop where DP: provider::DataProvider, - S: glue::SearchStrategy + Clone + AsyncFriendly, + S: glue::SearchStrategy + + glue::HasDefaultProcessor + + Clone + + AsyncFriendly, T: AsyncFriendly + Clone, { type Id = DP::ExternalId; diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index 9b6078da6..7d9348244 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -79,7 +79,10 @@ pub struct Metrics {} impl Search for Range where DP: provider::DataProvider, - S: glue::SearchStrategy + Clone + AsyncFriendly, + S: glue::SearchStrategy + + glue::HasDefaultProcessor + + Clone + + AsyncFriendly, T: AsyncFriendly + Clone, { type Id = DP::ExternalId; diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index f8b128c4d..a5583af94 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -349,7 +349,7 @@ where DP: DataProvider + provider::SetElement<[T]>, T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, - S: glue::SearchStrategy + Clone + AsyncFriendly, + S: glue::SearchStrategy + glue::HasDefaultProcessor + Clone + AsyncFriendly, { match &input { SearchPhase::Topk(search_phase) => { diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 25916df4f..870541f0c 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -18,7 +18,10 @@ use std::{ use diskann::{ graph::{ self, - glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy}, + glue::{ + self, ExpandBeam, HasDefaultProcessor, IdIterator, SearchExt, SearchPostProcess, + SearchStrategy, + }, search::Knn, search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, }, @@ -351,7 +354,6 @@ where type QueryComputer = DiskQueryComputer; type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>; type SearchAccessorError = ANNError; - type PostProcessor = RerankAndFilter<'this>; fn search_accessor<'a>( &'a self, @@ -366,8 +368,24 @@ where self.scratch_pool, ) } +} + +impl<'this, Data, ProviderFactory> + HasDefaultProcessor< + DiskProvider, + [Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DiskSearchStrategy<'this, Data, ProviderFactory> +where + Data: GraphDataType, + ProviderFactory: VertexProviderFactory, +{ + type Processor = RerankAndFilter<'this>; - fn post_processor(&self) -> Self::PostProcessor { + fn create_processor(&self) -> Self::Processor { RerankAndFilter::new(self.vector_filter) } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index b25b1746f..ea509b3ff 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -37,7 +37,6 @@ where Q: AsyncFriendly + Clone, { type QueryComputer = InlineBetaComputer; - type PostProcessor = FilterResults; type SearchAccessorError = ANNError; type SearchAccessor<'a> = EncodedDocumentAccessor>; @@ -60,10 +59,25 @@ where self.beta, )) } +} + +/// [`HasDefaultProcessor`] delegation for [`InlineBetaStrategy`]. The processor wraps +/// the inner strategy's default processor with [`FilterResults`]. +impl + diskann::graph::glue::HasDefaultProcessor< + DocumentProvider>, + FilteredQuery, + > for InlineBetaStrategy +where + DP: DataProvider, + Strategy: diskann::graph::glue::HasDefaultProcessor, + Q: AsyncFriendly + Clone, +{ + type Processor = FilterResults; - fn post_processor(&self) -> Self::PostProcessor { + fn create_processor(&self) -> Self::Processor { FilterResults { - inner_post_processor: self.inner.post_processor(), + inner_post_processor: self.inner.create_processor(), } } } diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index f5d129db8..2b928cff7 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -175,7 +175,10 @@ pub(crate) mod tests { graph::{ self, AdjacencyList, ConsolidateKind, InplaceDeleteMethod, StartPointStrategy, config::IntraBatchCandidates, - glue::{AsElement, InplaceDeleteStrategy, InsertStrategy, SearchStrategy, aliases}, + glue::{ + AsElement, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + SearchStrategy, aliases, + }, index::{PartitionedNeighbors, QueryLabelProvider, QueryVisitDecision}, search::{Knn, Range}, search_output_buffer, @@ -347,7 +350,7 @@ pub(crate) mod tests { mut checker: Checker, ) where DP: DataProvider, - S: SearchStrategy, + S: SearchStrategy + HasDefaultProcessor, Q: std::fmt::Debug + Sync + ?Sized, Checker: FnMut(usize, (u32, f32)) -> Result<(), Box>, { @@ -395,7 +398,7 @@ pub(crate) mod tests { filter: &dyn QueryLabelProvider, ) where DP: DataProvider, - S: SearchStrategy, + S: SearchStrategy + HasDefaultProcessor, Q: std::fmt::Debug + Sync + ?Sized, Checker: FnMut(usize, (u32, f32)) -> Result<(), Box>, { @@ -501,8 +504,8 @@ pub(crate) mod tests { quant_strategy: QS, ) where DP: DataProvider, - FS: SearchStrategy + Clone + 'static, - QS: SearchStrategy + Clone + 'static, + FS: SearchStrategy + HasDefaultProcessor + Clone + 'static, + QS: SearchStrategy + HasDefaultProcessor + Clone + 'static, T: Default + Clone + Send + Sync + std::fmt::Debug, { // Assume all vectors have the same length. @@ -925,6 +928,7 @@ pub(crate) mod tests { T: VectorRepr + GenerateSphericalData + Into, S: InsertStrategy, [T]> + SearchStrategy, [T]> + + HasDefaultProcessor, [T]> + Clone + 'static, rand::distr::StandardUniform: Distribution, @@ -1052,6 +1056,7 @@ pub(crate) mod tests { T: VectorRepr + GenerateSphericalData + Into, S: InsertStrategy, [T]> + SearchStrategy, [T]> + + HasDefaultProcessor, [T]> + Clone + 'static, rand::distr::StandardUniform: Distribution, diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index b3b07e626..bfcf448a7 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -232,7 +232,7 @@ where ) -> ANNResult where T: Sync + ?Sized, - S: SearchStrategy, + S: SearchStrategy + glue::HasDefaultProcessor, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 403dfac3f..ee63e1b00 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -16,13 +16,14 @@ use std::{ use serde::{Deserialize, Serialize}; use bf_tree::{BfTree, Config}; +use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, graph::{ AdjacencyList, DiskANNIndex, SearchOutputBuffer, glue::{ - self, ExpandBeam, FillSet, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, - SearchExt, SearchStrategy, + self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, }, }, neighbor::Neighbor, @@ -1475,7 +1476,6 @@ where type QueryComputer = T::QueryDistance; type SearchAccessor<'a> = FullAccessor<'a, T, Q, D>; type SearchAccessorError = Panics; - type PostProcessor = RemoveDeletedIdsAndCopy; fn search_accessor<'a>( &'a self, @@ -1484,10 +1484,15 @@ where ) -> Result, Self::SearchAccessorError> { Ok(FullAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> for Internal +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, +{ + delegate_default_post_process!(RemoveDeletedIdsAndCopy); } /// Perform a search entirely in the full-precision space. @@ -1502,7 +1507,6 @@ where type QueryComputer = T::QueryDistance; type SearchAccessor<'a> = FullAccessor<'a, T, Q, D>; type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -1511,10 +1515,15 @@ where ) -> Result, Self::SearchAccessorError> { Ok(FullAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, +{ + delegate_default_post_process!(glue::Pipeline); } /// An [`glue::SearchPostProcess`] implementation that reranks PQ vectors. @@ -1580,7 +1589,6 @@ where type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, T, D>; type SearchAccessorError = Panics; - type PostProcessor = Rerank; fn search_accessor<'a>( &'a self, @@ -1589,10 +1597,14 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> for Internal +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, +{ + delegate_default_post_process!(Rerank); } /// Perform a search entirely in the quantized space. @@ -1607,7 +1619,6 @@ where type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, T, D>; type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -1616,10 +1627,14 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> for Hybrid +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, +{ + delegate_default_post_process!(glue::Pipeline); } // Pruning @@ -1730,6 +1745,7 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; + type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; type SearchStrategy = Internal; fn search_strategy(&self) -> Self::SearchStrategy { Internal(Self) @@ -1739,6 +1755,10 @@ where Self } + fn search_post_processor(&self) -> Self::SearchPostProcessor { + Default::default() + } + async fn get_delete_element<'a>( &'a self, provider: &'a BfTreeProvider, @@ -1764,6 +1784,7 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; + type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; type SearchStrategy = Internal; fn search_strategy(&self) -> Self::SearchStrategy { Internal(*self) @@ -1773,6 +1794,10 @@ where *self } + fn search_post_processor(&self) -> Self::SearchPostProcessor { + Default::default() + } + async fn get_delete_element<'a>( &'a self, provider: &'a BfTreeProvider, diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index a02f663aa..5ce7220d7 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -962,7 +962,6 @@ where >>::Accessor, >; type SearchAccessorError = CachingError; - type PostProcessor = Pipeline; fn search_accessor<'a>( &'a self, @@ -979,9 +978,28 @@ where .as_cache_accessor_for(inner) .map_err(CachingError::Cache) } +} - fn post_processor(&self) -> Self::PostProcessor { - Pipeline::new(Unwrap, self.strategy.post_processor()) +/// [`HasDefaultProcessor`] delegation for [`Cached`]. The processor is composed by +/// wrapping the inner strategy's processor with [`Unwrap`] via [`Pipeline`]. +impl glue::HasDefaultProcessor, T> for Cached +where + T: ?Sized, + DP: DataProvider, + S: glue::HasDefaultProcessor + + for<'a> SearchStrategy: CacheableAccessor>, + C: for<'a> AsCacheAccessorFor< + 'a, + SearchAccessor<'a, S, DP, T>, + Accessor: NeighborCache, + Error = E, + > + AsyncFriendly, + E: StandardError, +{ + type Processor = Pipeline; + + fn create_processor(&self) -> Self::Processor { + Pipeline::new(Unwrap, self.strategy.create_processor()) } } @@ -1066,6 +1084,7 @@ where type PruneStrategy = Cached; type SearchStrategy = Cached; + type SearchPostProcessor = S::SearchPostProcessor; fn prune_strategy(&self) -> Self::PruneStrategy { Cached { @@ -1079,6 +1098,10 @@ where } } + fn search_post_processor(&self) -> Self::SearchPostProcessor { + self.strategy.search_post_processor() + } + fn get_delete_element<'a>( &'a self, provider: &'a CachingProvider, diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs index 0c0ba780e..0061f3627 100644 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs @@ -11,13 +11,15 @@ use std::{ }, }; +use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNErrorKind, ANNResult, graph::{ AdjacencyList, glue::{ - AsElement, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, - InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, + AsElement, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, + InplaceDeleteStrategy, InsertStrategy, Pipeline, PruneStrategy, SearchExt, + SearchStrategy, }, }, provider::{ @@ -888,7 +890,6 @@ impl FillSet for HybridAccessor<'_> { impl SearchStrategy for Internal { type QueryComputer = ::QueryDistance; - type PostProcessor = postprocess::RemoveDeletedIdsAndCopy; type SearchAccessorError = Panics; type SearchAccessor<'a> = FullAccessor<'a>; @@ -899,15 +900,14 @@ impl SearchStrategy for Internal { ) -> Result, Self::SearchAccessorError> { Ok(FullAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor for Internal { + delegate_default_post_process!(postprocess::RemoveDeletedIdsAndCopy); } impl SearchStrategy for FullPrecision { type QueryComputer = ::QueryDistance; - type PostProcessor = Pipeline; type SearchAccessorError = Panics; type SearchAccessor<'a> = FullAccessor<'a>; @@ -918,15 +918,14 @@ impl SearchStrategy for FullPrecision { ) -> Result, Self::SearchAccessorError> { Ok(FullAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor for FullPrecision { + delegate_default_post_process!(Pipeline); } impl SearchStrategy for Internal { type QueryComputer = pq::distance::QueryComputer>; - type PostProcessor = postprocess::RemoveDeletedIdsAndCopy; type SearchAccessorError = Panics; type SearchAccessor<'a> = QuantAccessor<'a>; @@ -937,15 +936,14 @@ impl SearchStrategy for Internal { ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor for Internal { + delegate_default_post_process!(postprocess::RemoveDeletedIdsAndCopy); } impl SearchStrategy for Quantized { type QueryComputer = pq::distance::QueryComputer>; - type PostProcessor = Pipeline; type SearchAccessorError = Panics; type SearchAccessor<'a> = QuantAccessor<'a>; @@ -956,10 +954,10 @@ impl SearchStrategy for Quantized { ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor for Quantized { + delegate_default_post_process!(Pipeline); } impl PruneStrategy for FullPrecision { @@ -1051,6 +1049,7 @@ impl InplaceDeleteStrategy for FullPrecision { type DeleteElementGuard = Vec; type DeleteElementError = Panics; type PruneStrategy = Self; + type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; type SearchStrategy = Internal; fn prune_strategy(&self) -> Self::PruneStrategy { @@ -1061,6 +1060,10 @@ impl InplaceDeleteStrategy for FullPrecision { Internal(*self) } + fn search_post_processor(&self) -> Self::SearchPostProcessor { + Default::default() + } + fn get_delete_element<'a>( &'a self, provider: &'a DebugProvider, @@ -1077,6 +1080,7 @@ impl InplaceDeleteStrategy for Quantized { type DeleteElementGuard = Vec; type DeleteElementError = Panics; type PruneStrategy = Self; + type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; type SearchStrategy = Internal; fn prune_strategy(&self) -> Self::PruneStrategy { @@ -1087,6 +1091,10 @@ impl InplaceDeleteStrategy for Quantized { Internal(*self) } + fn search_post_processor(&self) -> Self::SearchPostProcessor { + Default::default() + } + fn get_delete_element<'a>( &'a self, provider: &'a DebugProvider, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index e74419a46..0592994af 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -5,13 +5,14 @@ use std::{collections::HashMap, fmt::Debug, future::Future}; +use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, graph::{ SearchOutputBuffer, glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, + InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, }, neighbor::Neighbor, @@ -453,7 +454,6 @@ where type QueryComputer = T::QueryDistance; type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; type SearchAccessorError = Panics; - type PostProcessor = RemoveDeletedIdsAndCopy; fn search_accessor<'a>( &'a self, @@ -462,10 +462,17 @@ where ) -> Result, Self::SearchAccessorError> { Ok(FullAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> + for Internal +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + delegate_default_post_process!(RemoveDeletedIdsAndCopy); } /// Perform a search entirely in the full-precision space. @@ -481,7 +488,6 @@ where type QueryComputer = T::QueryDistance; type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -490,10 +496,16 @@ where ) -> Result, Self::SearchAccessorError> { Ok(FullAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + delegate_default_post_process!(glue::Pipeline); } // Pruning @@ -560,6 +572,7 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; + type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; type SearchStrategy = Internal; fn search_strategy(&self) -> Self::SearchStrategy { Internal(Self) @@ -569,6 +582,10 @@ where Self } + fn search_post_processor(&self) -> Self::SearchPostProcessor { + Default::default() + } + async fn get_delete_element<'a>( &'a self, provider: &'a FullPrecisionProvider, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 2886d1d20..4d176052b 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -5,11 +5,12 @@ use std::{collections::HashMap, future::Future, sync::Arc}; +use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, graph::glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InplaceDeleteStrategy, + InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, @@ -473,7 +474,6 @@ where type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; type SearchAccessorError = Panics; - type PostProcessor = Rerank; fn search_accessor<'a>( &'a self, @@ -482,10 +482,16 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> + for Internal +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + delegate_default_post_process!(Rerank); } /// Perform a search entirely in the quantized space. @@ -501,7 +507,6 @@ where type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -510,10 +515,15 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> for Hybrid +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + delegate_default_post_process!(glue::Pipeline); } impl PruneStrategy> for Hybrid @@ -578,6 +588,7 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; + type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; type SearchStrategy = Internal; fn search_strategy(&self) -> Self::SearchStrategy { Internal(*self) @@ -587,6 +598,10 @@ where *self } + fn search_post_processor(&self) -> Self::SearchPostProcessor { + Default::default() + } + async fn get_delete_element<'a>( &'a self, provider: &'a FullPrecisionProvider, @@ -613,7 +628,6 @@ where type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, NoStore, D, Ctx>; type SearchAccessorError = Panics; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -622,10 +636,16 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> + for Quantized +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + delegate_default_post_process!(glue::Pipeline); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 9d3fd9c32..0a450d8e2 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -6,11 +6,12 @@ use std::{future::Future, sync::Mutex}; use crate::storage::{StorageReadProvider, StorageWriteProvider}; +use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, graph::glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, PruneStrategy, SearchExt, - SearchStrategy, + self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, @@ -612,7 +613,6 @@ where type QueryComputer = QueryComputer; type SearchAccessor<'a> = QuantAccessor<'a, NBITS, FullPrecisionStore, D, Ctx>; type SearchAccessorError = ANNError; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -621,10 +621,18 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider, true)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl + HasDefaultProcessor, D, Ctx>, [T]> for Quantized +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, + Unsigned: Representation, + QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, +{ + delegate_default_post_process!(glue::Pipeline); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -642,7 +650,6 @@ where type QueryComputer = QueryComputer; type SearchAccessor<'a> = QuantAccessor<'a, NBITS, NoStore, D, Ctx>; type SearchAccessorError = ANNError; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -651,10 +658,18 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider, true)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl + HasDefaultProcessor, D, Ctx>, [T]> for Quantized +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, + Unsigned: Representation, + QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, +{ + delegate_default_post_process!(glue::Pipeline); } impl PruneStrategy, D, Ctx>> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index b705b43d3..552001a07 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -7,12 +7,13 @@ use std::{future::Future, sync::Mutex}; +use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, PruneStrategy, SearchExt, - SearchStrategy, + self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, @@ -561,7 +562,6 @@ where UnwrapErr; type SearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; type SearchAccessorError = ANNError; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -570,10 +570,16 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider, self.layout, self.is_search)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> + for Quantized +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + delegate_default_post_process!(glue::Pipeline); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -589,7 +595,6 @@ where UnwrapErr; type SearchAccessor<'a> = QuantAccessor<'a, NoStore, D, Ctx>; type SearchAccessorError = ANNError; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -598,10 +603,16 @@ where ) -> Result, Self::SearchAccessorError> { Ok(QuantAccessor::new(provider, self.layout, self.is_search)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor, [T]> + for Quantized +where + T: VectorRepr, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + delegate_default_post_process!(glue::Pipeline); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 9c6f908b1..ef3329a3e 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -5,12 +5,13 @@ use std::{future::Future, sync::Mutex}; +use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, error::{RankedError, ToRanked, TransientError}, graph::glue::{ - AsElement, CopyIds, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, SearchExt, - SearchStrategy, + AsElement, CopyIds, ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, }, neighbor::Neighbor, provider::{ @@ -236,7 +237,6 @@ impl SearchStrategy for Flaky { type QueryComputer = as BuildQueryComputer<[f32]>>::QueryComputer; type SearchAccessor<'a> = FlakyAccessor<'a>; type SearchAccessorError = ANNError; - type PostProcessor = CopyIds; fn search_accessor<'a>( &'a self, @@ -249,10 +249,10 @@ impl SearchStrategy for Flaky { self.fail_every, )) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl HasDefaultProcessor for Flaky { + delegate_default_post_process!(CopyIds); } impl FillSet for FlakyAccessor<'_> {} diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index edd1e3139..afad564dd 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -142,13 +142,23 @@ where beta: self.beta, }) } +} - /// Forward the post-processing error from the inner strategy. - type PostProcessor = glue::Pipeline; +/// [`HasDefaultProcessor`] delegation for [`BetaFilter`]. The processor is composed by +/// wrapping the inner strategy's processor with [`Unwrap`] via [`Pipeline`]. +impl glue::HasDefaultProcessor + for BetaFilter +where + T: ?Sized, + I: VectorId, + O: Send, + Provider: DataProvider, + Strategy: glue::HasDefaultProcessor, +{ + type Processor = glue::Pipeline; - /// Delegate post-processing to the inner strategy's post-processing routine. - fn post_processor(&self) -> Self::PostProcessor { - glue::Pipeline::new(Unwrap, self.strategy.post_processor()) + fn create_processor(&self) -> Self::Processor { + glue::Pipeline::new(Unwrap, self.strategy.create_processor()) } } @@ -538,7 +548,6 @@ mod tests { impl SearchStrategy for SimpleStrategy { type SearchAccessor<'a> = Doubler; type QueryComputer = AddingComputer; - type PostProcessor = CopyIds; type SearchAccessorError = ANNError; fn search_accessor<'a>( @@ -548,10 +557,10 @@ mod tests { ) -> Result, Self::SearchAccessorError> { Ok(Doubler::default()) } + } - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } + impl glue::HasDefaultProcessor for SimpleStrategy { + diskann::delegate_default_post_process!(CopyIds); } /// A simple `QueryLabelProvider` that matches multiples of 3. diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 411a97031..5dc1179fb 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -316,9 +316,6 @@ where + Sync + 'static; - /// The associated [`SearchPostProcess`]or for the final results. - type PostProcessor: for<'a> SearchPostProcess, T, O> + Send + Sync; - /// An error that can occur when getting a search_accessor. type SearchAccessorError: StandardError; @@ -334,10 +331,116 @@ where provider: &'a Provider, context: &'a Provider::Context, ) -> Result, Self::SearchAccessorError>; +} + +/// Strategy-level bridge connecting a [`SearchStrategy`] to a specific processor type `P`. +/// +/// This trait is the surface that the search infrastructure ([`super::search::Knn`], +/// [`super::search::KnnWith`], etc.) bounds on. +/// +/// The blanket impl covers `P = DefaultPostProcess` for any strategy implementing +/// [`HasDefaultProcessor`]. Custom processor types (e.g. `RagSearchParams`) can have +/// their own `PostProcess` impls without coherence conflicts. +pub trait PostProcess::InternalId>: + SearchStrategy +where + Provider: DataProvider, + T: ?Sized, + O: Send, + P: Send + Sync, +{ + /// Run post-processing with the given `processor` on `candidates`, writing + /// results into `output`. + fn post_process_with<'a, I, B>( + &self, + processor: &P, + accessor: &mut Self::SearchAccessor<'a>, + query: &T, + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized; +} + +/// Opt-in trait for strategies that have a default post-processor. +/// +/// Strategies implementing this trait work with [`super::search::Knn`] (no explicit +/// processor). The old `SearchStrategy::PostProcessor` associated type is replaced by +/// `HasDefaultProcessor::Processor`. +pub trait HasDefaultProcessor::InternalId>: + SearchStrategy +where + Provider: DataProvider, + T: ?Sized, + O: Send, +{ + /// The default post-processor type. + type Processor: for<'a> SearchPostProcess, T, O> + Send + Sync; + + /// Create the default post-processor. + fn create_processor(&self) -> Self::Processor; +} + +/// Convenience macro for implementing [`HasDefaultProcessor`] when the processor +/// is a [`Default`]-constructible type. +/// +/// # Example +/// +/// ```ignore +/// impl HasDefaultProcessor for MyStrategy { +/// delegate_default_post_process!(CopyIds); +/// } +/// ``` +#[macro_export] +macro_rules! delegate_default_post_process { + ($Processor:ty) => { + type Processor = $Processor; + fn create_processor(&self) -> Self::Processor { + Default::default() + } + }; +} + +/// A zero-sized marker representing "use the default post-processor". +/// +/// The blanket `PostProcess` impl covers exactly `P = DefaultPostProcess`. +/// Custom processor types are free to have their own `PostProcess` impls +/// without coherence conflicts. +#[derive(Debug, Default, Clone, Copy)] +pub struct DefaultPostProcess; - /// Construct the [`SearchPostProcess`] struct to post-process the results of search and - /// store them into the output container. - fn post_processor(&self) -> Self::PostProcessor; +impl PostProcess for S +where + S: HasDefaultProcessor, + Provider: DataProvider, + T: ?Sized + Sync, + O: Send, +{ + fn post_process_with<'a, I, B>( + &self, + _processor: &DefaultPostProcess, + accessor: &mut Self::SearchAccessor<'a>, + query: &T, + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + use crate::error::IntoANNResult; + async move { + self.create_processor() + .post_process(accessor, query, computer, candidates, output) + .send() + .await + .into_ann_result() + } + } } /// Perform post-processing on the results of search, storing the results in an output buffer. @@ -741,8 +844,13 @@ where /// The pruning strategy to use after the initial search is complete. type PruneStrategy: PruneStrategy; + /// The processor used during the delete-search phase. + type SearchPostProcessor: Send + Sync; + /// The type of the search strategy to use for graph traversal. - type SearchStrategy: for<'a> SearchStrategy>; + /// It must support [`PostProcess`] with [`Self::SearchPostProcessor`]. + type SearchStrategy: for<'a> SearchStrategy> + + for<'a> PostProcess, Self::SearchPostProcessor>; /// Construct the prune strategy object. fn prune_strategy(&self) -> Self::PruneStrategy; @@ -750,6 +858,9 @@ where /// Construct the search strategy object. fn search_strategy(&self) -> Self::SearchStrategy; + /// Construct the search post-processor for the delete-search phase. + fn search_post_processor(&self) -> Self::SearchPostProcessor; + /// Construct the accessor used to retrieve the item being deleted. fn get_delete_element<'a>( &'a self, @@ -1012,7 +1123,6 @@ mod tests { impl SearchStrategy for Strategy { type QueryComputer = QueryComputer; - type PostProcessor = CopyIds; type SearchAccessorError = ANNError; type SearchAccessor<'a> = Retriever<'a>; @@ -1027,10 +1137,10 @@ mod tests { self.errors_are_unrecoverable, )) } + } - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } + impl HasDefaultProcessor for Strategy { + delegate_default_post_process!(CopyIds); } // Use the provided implementation. @@ -1081,7 +1191,7 @@ mod tests { let mut output = vec![Neighbor::::default(); output_len]; let count = strategy - .post_processor() + .create_processor() .post_process( &mut accessor, &query, diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 649177cd5..f64046185 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -27,7 +27,7 @@ use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, glue::{ self, AsElement, ExpandBeam, FillSet, IdIterator, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, + PostProcess, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ @@ -1296,18 +1296,17 @@ where // NOTE: We rely on `post_process` to remove deleted items from the results // placed into the output. let proxy = v.async_lower(); + let post_processor = strategy.search_post_processor(); let num_results = search_strategy - .post_processor() - .post_process( + .post_process_with( + &post_processor, &mut search_accessor, &*proxy, &computer, scratch.best.iter(), &mut neighbor::BackInserter::new(output.as_mut_slice()), ) - .send() - .await - .into_ann_result()?; + .await?; let mut undeleted_ids: Vec<_> = output .iter() @@ -2198,7 +2197,8 @@ where ) -> ANNResult where T: ?Sized, - S: SearchStrategy: IdIterator>, + S: SearchStrategy: IdIterator> + + glue::HasDefaultProcessor, I: Iterator::InternalId>, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, @@ -2233,7 +2233,7 @@ where } let result_count = strategy - .post_processor() + .create_processor() .post_process( &mut accessor, query, diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index b5c05c993..b6f44677c 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -14,7 +14,7 @@ use crate::{ error::IntoANNResult, graph::{ DiverseSearchParams, - glue::{SearchExt, SearchPostProcess, SearchStrategy}, + glue::{DefaultPostProcess, PostProcess, SearchExt, SearchStrategy}, index::{DiskANNIndex, SearchStats}, search_output_buffer::SearchOutputBuffer, }, @@ -96,7 +96,7 @@ impl Search for Diverse

where DP: DataProvider, T: Sync + ?Sized, - S: SearchStrategy, + S: PostProcess, O: Send, OB: SearchOutputBuffer + Send, P: AttributeValueProvider, @@ -136,17 +136,15 @@ where diverse_scratch.best.post_process(); let result_count = strategy - .post_processor() - .post_process( + .post_process_with( + &DefaultPostProcess, &mut accessor, query, &computer, diverse_scratch.best.iter().take(self.inner.l_value().get()), output, ) - .send() - .await - .into_ann_result()?; + .await?; Ok(stats.finish(result_count as u32)) } diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index fce7d6eae..4c676309f 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -7,7 +7,7 @@ use std::{fmt::Debug, num::NonZeroUsize}; -use diskann_utils::future::{AssertSend, SendFuture}; +use diskann_utils::future::SendFuture; use thiserror::Error; use super::Search; @@ -15,7 +15,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{SearchExt, SearchPostProcess, SearchStrategy}, + glue::{DefaultPostProcess, PostProcess, SearchExt}, index::{DiskANNIndex, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::SearchOutputBuffer, @@ -143,40 +143,71 @@ impl Knn { } } +impl Knn { + /// Shared search core parameterised over the post-processor type. + async fn search_core( + &self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, + post_processor: &PP, + ) -> ANNResult + where + DP: DataProvider, + T: Sync + ?Sized, + S: PostProcess, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, + PP: Send + Sync, + { + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len()); + + let stats = index + .search_internal( + Some(self.beam_width.get()), + &start_ids, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + let result_count = strategy + .post_process_with( + post_processor, + &mut accessor, + query, + &computer, + scratch.best.iter().take(self.l_value.get().into_usize()), + output, + ) + .await?; + + Ok(stats.finish(result_count as u32)) + } +} + impl Search for Knn where DP: DataProvider, T: Sync + ?Sized, - S: SearchStrategy, + S: PostProcess, O: Send, OB: SearchOutputBuffer + Send + ?Sized, { type Output = SearchStats; - /// Execute the k-NN search on the given index. - /// - /// This method executes a search using the provided `strategy` to access and process elements. - /// It computes the similarity between the query vector and the elements in the index, traversing - /// the graph towards the nearest neighbors according to the search parameters. - /// - /// # Arguments - /// - /// * `index` - The DiskANN index to search. - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. - /// - /// # Returns - /// - /// Returns [`SearchStats`] containing: - /// - The number of distance computations performed. - /// - The number of hops (graph traversal steps). - /// - Timing information for the search operation. - /// - /// # Errors - /// - /// Returns an error if there is a failure accessing elements or computing distances. + /// Execute the k-NN search on the given index using the default post-processor. fn search( self, index: &DiskANNIndex, @@ -186,40 +217,8 @@ where output: &mut OB, ) -> impl SendFuture> { async move { - let mut accessor = strategy - .search_accessor(&index.data_provider, context) - .into_ann_result()?; - - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len()); - - let stats = index - .search_internal( - Some(self.beam_width.get()), - &start_ids, - &mut accessor, - &computer, - &mut scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - let result_count = strategy - .post_processor() - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(self.l_value.get().into_usize()), - output, - ) - .send() + self.search_core(index, strategy, context, query, output, &DefaultPostProcess) .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) } } } @@ -250,7 +249,7 @@ impl<'r, DP, S, T, O, OB, SR> Search for RecordedKnn<'r, SR> where DP: DataProvider, T: Sync + ?Sized, - S: SearchStrategy, + S: PostProcess, O: Send, OB: SearchOutputBuffer + Send + ?Sized, SR: super::record::SearchRecord + ?Sized, @@ -287,8 +286,8 @@ where .await?; let result_count = strategy - .post_processor() - .post_process( + .post_process_with( + &DefaultPostProcess, &mut accessor, query, &computer, @@ -298,15 +297,74 @@ where .take(self.inner.l_value.get().into_usize()), output, ) - .send() - .await - .into_ann_result()?; + .await?; Ok(stats.finish(result_count as u32)) } } } +///////////////////////// +// KnnWith // +///////////////////////// + +/// K-NN search with an explicit caller-supplied post-processor. +/// +/// This allows using a custom post-processor `PP` instead of the strategy's default. +/// Use [`KnnWith::new`] to wrap a base [`Knn`] with a post-processor. +#[derive(Debug, Clone)] +pub struct KnnWith { + /// Base k-NN search parameters. + pub inner: Knn, + /// The caller-supplied post-processor. + pub post_processor: PP, +} + +impl KnnWith { + /// Create new k-NN search parameters with an explicit post-processor. + pub fn new(inner: Knn, post_processor: PP) -> Self { + Self { + inner, + post_processor, + } + } +} + +impl Search for KnnWith +where + DP: DataProvider, + T: Sync + ?Sized, + S: PostProcess, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, + PP: Send + Sync, +{ + type Output = SearchStats; + + /// Execute the k-NN search with the caller-supplied post-processor. + fn search( + self, + index: &DiskANNIndex, + strategy: &S, + context: &DP::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture> { + async move { + self.inner + .search_core( + index, + strategy, + context, + query, + output, + &self.post_processor, + ) + .await + } + } +} + /////////// // Tests // /////////// diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index c3786ac8d..5186bc684 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -88,7 +88,7 @@ where } // Re-export search parameter types. -pub use knn_search::{Knn, KnnSearchError, RecordedKnn}; +pub use knn_search::{Knn, KnnSearchError, KnnWith, RecordedKnn}; pub use multihop_search::MultihopSearch; pub use range_search::{Range, RangeSearchError, RangeSearchOutput}; diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index d5e1e7a35..2f7c9e659 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -6,7 +6,7 @@ //! Label-filtered search using multi-hop expansion. use diskann_utils::Reborrow; -use diskann_utils::future::{AssertSend, SendFuture}; +use diskann_utils::future::SendFuture; use diskann_vector::PreprocessedDistanceFunction; use hashbrown::HashSet; @@ -16,8 +16,8 @@ use crate::{ error::{ErrorExt, IntoANNResult}, graph::{ glue::{ - self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, - SearchPostProcess, SearchStrategy, + self, DefaultPostProcess, ExpandBeam, HybridPredicate, PostProcess, Predicate, + PredicateMut, SearchExt, }, index::{ DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats, @@ -57,7 +57,7 @@ impl<'q, DP, S, T, O, OB> Search for MultihopSearch<'q, DP::Int where DP: DataProvider, T: Sync + ?Sized, - S: SearchStrategy, + S: PostProcess, O: Send, OB: SearchOutputBuffer + Send, { @@ -93,17 +93,15 @@ where .await?; let result_count = strategy - .post_processor() - .post_process( + .post_process_with( + &DefaultPostProcess, &mut accessor, query, &computer, scratch.best.iter().take(self.inner.l_value().get()), output, ) - .send() - .await - .into_ann_result()?; + .await?; Ok(stats.finish(result_count as u32)) } diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index d1a6ec698..4228b28d8 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -5,7 +5,7 @@ //! Range-based search within a distance radius. -use diskann_utils::future::{AssertSend, SendFuture}; +use diskann_utils::future::SendFuture; use thiserror::Error; use super::{Search, scratch::SearchScratch}; @@ -13,7 +13,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{self, ExpandBeam, SearchExt, SearchPostProcess, SearchStrategy}, + glue::{self, DefaultPostProcess, ExpandBeam, PostProcess, SearchExt}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, search_output_buffer, @@ -171,7 +171,7 @@ impl Search for Range where DP: DataProvider, T: Sync + ?Sized, - S: SearchStrategy, + S: PostProcess, O: Send + Default + Clone, { type Output = RangeSearchOutput; @@ -251,17 +251,15 @@ where ); let _ = strategy - .post_processor() - .post_process( + .post_process_with( + &DefaultPostProcess, &mut accessor, query, &computer, scratch.in_range.iter().copied(), &mut output_buffer, ) - .send() - .await - .into_ann_result()?; + .await?; // Filter by inner/outer radius let inner_cutoff = if let Some(inner_radius) = self.inner_radius() { diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index acac9ef8e..68994f7ba 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -16,7 +16,7 @@ use diskann_vector::distance::Metric; use thiserror::Error; use crate::{ - ANNError, ANNResult, + ANNError, ANNResult, delegate_default_post_process, error::{Infallible, message}, graph::{AdjacencyList, glue, test::synthetic}, internal::counter::{Counter, LocalCounter}, @@ -952,7 +952,6 @@ impl Strategy { impl glue::SearchStrategy for Strategy { type QueryComputer = ::QueryDistance; - type PostProcessor = glue::CopyIds; type SearchAccessorError = Infallible; type SearchAccessor<'a> = Accessor<'a>; @@ -963,10 +962,10 @@ impl glue::SearchStrategy for Strategy { ) -> Result, Infallible> { Ok(Accessor::new(provider)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl glue::HasDefaultProcessor for Strategy { + delegate_default_post_process!(glue::CopyIds); } impl glue::PruneStrategy for Strategy { @@ -1016,6 +1015,7 @@ impl glue::InplaceDeleteStrategy for Strategy { type DeleteElementError = AccessedInvalidId; type PruneStrategy = Self; type SearchStrategy = Self; + type SearchPostProcessor = glue::DefaultPostProcess; fn prune_strategy(&self) -> Self::PruneStrategy { *self @@ -1025,6 +1025,10 @@ impl glue::InplaceDeleteStrategy for Strategy { *self } + fn search_post_processor(&self) -> Self::SearchPostProcessor { + Default::default() + } + async fn get_delete_element<'a>( &'a self, provider: &'a Provider, From 855d67351548e1b967c68ecb82f4876900d72e37 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 9 Mar 2026 20:06:58 +0530 Subject: [PATCH 02/47] Fix nextest -Dwarnings build and cached delete postprocess bounds --- .../src/model/graph/provider/async_/caching/provider.rs | 7 ++++--- diskann/src/graph/search/diverse_search.rs | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index 5ce7220d7..915b8ca09 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -1075,7 +1075,8 @@ where DP: DataProvider, S: InplaceDeleteStrategy, Cached: PruneStrategy>, - Cached: for<'a> SearchStrategy, S::DeleteElement<'a>>, + for<'a> Cached: + glue::HasDefaultProcessor, S::DeleteElement<'a>>, C: AsyncFriendly, { type DeleteElement<'a> = S::DeleteElement<'a>; @@ -1084,7 +1085,7 @@ where type PruneStrategy = Cached; type SearchStrategy = Cached; - type SearchPostProcessor = S::SearchPostProcessor; + type SearchPostProcessor = glue::DefaultPostProcess; fn prune_strategy(&self) -> Self::PruneStrategy { Cached { @@ -1099,7 +1100,7 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - self.strategy.search_post_processor() + glue::DefaultPostProcess } fn get_delete_element<'a>( diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index b6f44677c..29cef0f8f 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -5,7 +5,7 @@ //! Diversity-aware search. -use diskann_utils::future::{AssertSend, SendFuture}; +use diskann_utils::future::SendFuture; use hashbrown::HashSet; use super::{Knn, Search, record::NoopSearchRecord, scratch::SearchScratch}; @@ -14,7 +14,7 @@ use crate::{ error::IntoANNResult, graph::{ DiverseSearchParams, - glue::{DefaultPostProcess, PostProcess, SearchExt, SearchStrategy}, + glue::{DefaultPostProcess, PostProcess, SearchExt}, index::{DiskANNIndex, SearchStats}, search_output_buffer::SearchOutputBuffer, }, From ecd81bb49fe328e016de32e63225ee745846e5b7 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 9 Mar 2026 20:57:39 +0530 Subject: [PATCH 03/47] Add determinant-diversity search mode to benchmark pipeline --- .../src/search/graph/determinant_diversity.rs | 206 ++++++++++++ .../src/search/graph/mod.rs | 1 + .../example/wikipedia_compare_detdiv.json | 60 ++++ .../src/backend/index/benchmarks.rs | 63 +++- diskann-benchmark/src/backend/index/result.rs | 36 ++ .../src/backend/index/search/knn.rs | 98 ++++++ diskann-benchmark/src/inputs/async_.rs | 36 ++ .../provider/async_/inmem/full_precision.rs | 123 ++++++- diskann/src/graph/glue.rs | 2 +- .../determinant_diversity_post_process.rs | 318 ++++++++++++++++++ diskann/src/graph/search/mod.rs | 4 + 11 files changed, 944 insertions(+), 3 deletions(-) create mode 100644 diskann-benchmark-core/src/search/graph/determinant_diversity.rs create mode 100644 diskann-benchmark/example/wikipedia_compare_detdiv.json create mode 100644 diskann/src/graph/search/determinant_diversity_post_process.rs diff --git a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs new file mode 100644 index 000000000..4720545ba --- /dev/null +++ b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs @@ -0,0 +1,206 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::sync::Arc; + +use diskann::{ + ANNResult, + graph::{self, glue}, + provider, +}; +use diskann_benchmark_runner::utils::{MicroSeconds, percentiles}; +use diskann_utils::{future::AsyncFriendly, views::Matrix}; + +use crate::{ + recall, + search::{self, Search, graph::Strategy}, + utils, +}; + +/// A built-in helper for benchmarking determinant-diversity K-nearest neighbors. +#[derive(Debug)] +pub struct KNN +where + DP: provider::DataProvider, +{ + index: Arc>, + queries: Arc>, + strategy: Strategy, +} + +impl KNN +where + DP: provider::DataProvider, +{ + pub fn new( + index: Arc>, + queries: Arc>, + strategy: Strategy, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + + Ok(Arc::new(Self { + index, + queries, + strategy, + })) + } +} + +impl Search for KNN +where + DP: provider::DataProvider, + S: glue::SearchStrategy + + glue::HasDefaultProcessor + + glue::PostProcess + + Clone + + AsyncFriendly, + T: AsyncFriendly + Clone, +{ + type Id = DP::ExternalId; + type Parameters = graph::search::KnnWith; + type Output = super::knn::Metrics; + + fn num_queries(&self) -> usize { + self.queries.nrows() + } + + fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { + search::IdCount::Fixed(parameters.inner.k_value()) + } + + async fn search( + &self, + parameters: &Self::Parameters, + buffer: &mut O, + index: usize, + ) -> ANNResult + where + O: graph::SearchOutputBuffer + Send, + { + let context = DP::Context::default(); + let stats = self + .index + .search( + parameters.clone(), + self.strategy.get(index)?, + &context, + self.queries.row(index), + buffer, + ) + .await?; + + Ok(super::knn::Metrics { + comparisons: stats.cmps, + hops: stats.hops, + }) + } +} + +/// Summary for determinant-diversity KNN runs. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct Summary { + pub setup: search::Setup, + pub parameters: graph::search::KnnWith, + pub end_to_end_latencies: Vec, + pub mean_latencies: Vec, + pub p90_latencies: Vec, + pub p99_latencies: Vec, + pub recall: recall::RecallMetrics, + pub mean_cmps: f64, + pub mean_hops: f64, +} + +pub struct Aggregator<'a, I> { + groundtruth: &'a dyn crate::recall::Rows, + recall_k: usize, + recall_n: usize, +} + +impl<'a, I> Aggregator<'a, I> { + pub fn new( + groundtruth: &'a dyn crate::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> Self { + Self { + groundtruth, + recall_k, + recall_n, + } + } +} + +impl + search::Aggregate< + graph::search::KnnWith, + I, + super::knn::Metrics, + > for Aggregator<'_, I> +where + I: crate::recall::RecallCompatible, +{ + type Output = Summary; + + fn aggregate( + &mut self, + run: search::Run>, + mut results: Vec>, + ) -> anyhow::Result

{ + let recall = match results.first() { + Some(first) => crate::recall::knn( + self.groundtruth, + None, + first.ids().as_rows(), + self.recall_k, + self.recall_n, + true, + )?, + None => anyhow::bail!("Results must be non-empty"), + }; + + let mut mean_latencies = Vec::with_capacity(results.len()); + let mut p90_latencies = Vec::with_capacity(results.len()); + let mut p99_latencies = Vec::with_capacity(results.len()); + + results.iter_mut().for_each(|r| { + match percentiles::compute_percentiles(r.latencies_mut()) { + Ok(values) => { + let percentiles::Percentiles { mean, p90, p99, .. } = values; + mean_latencies.push(mean); + p90_latencies.push(p90); + p99_latencies.push(p99); + } + Err(_) => { + let zero = MicroSeconds::new(0); + mean_latencies.push(0.0); + p90_latencies.push(zero); + p99_latencies.push(zero); + } + } + }); + + Ok(Summary { + setup: run.setup().clone(), + parameters: run.parameters().clone(), + end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(), + recall, + mean_latencies, + p90_latencies, + p99_latencies, + mean_cmps: utils::average_all( + results + .iter() + .flat_map(|r| r.output().iter().map(|o| o.comparisons)), + ), + mean_hops: utils::average_all( + results + .iter() + .flat_map(|r| r.output().iter().map(|o| o.hops)), + ), + }) + } +} diff --git a/diskann-benchmark-core/src/search/graph/mod.rs b/diskann-benchmark-core/src/search/graph/mod.rs index eddb4fbcf..cfcecb0db 100644 --- a/diskann-benchmark-core/src/search/graph/mod.rs +++ b/diskann-benchmark-core/src/search/graph/mod.rs @@ -3,6 +3,7 @@ * Licensed under the MIT license. */ +pub mod determinant_diversity; pub mod knn; pub mod multihop; pub mod range; diff --git a/diskann-benchmark/example/wikipedia_compare_detdiv.json b/diskann-benchmark/example/wikipedia_compare_detdiv.json new file mode 100644 index 000000000..3e4ffd150 --- /dev/null +++ b/diskann-benchmark/example/wikipedia_compare_detdiv.json @@ -0,0 +1,60 @@ +{ + "search_directories": [ + "C:/wikipedia_dataset" + ], + "jobs": [ + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Load", + "data_type": "float32", + "distance": "squared_l2", + "load_path": "C:/wikipedia_dataset/wikipedia_saved_index" + }, + "search_phase": { + "search-type": "topk", + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "reps": 1, + "num_threads": [8], + "runs": [ + { + "search_n": 10, + "search_l": [20, 30, 40, 50, 100, 200], + "recall_k": 10 + } + ] + } + } + }, + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Load", + "data_type": "float32", + "distance": "squared_l2", + "load_path": "C:/wikipedia_dataset/wikipedia_saved_index" + }, + "search_phase": { + "search-type": "topk", + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "reps": 1, + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 1.0, + "determinant_diversity_results_k": 10, + "num_threads": [8], + "runs": [ + { + "search_n": 10, + "search_l": [20, 30, 40, 50, 100, 200], + "recall_k": 10 + } + ] + } + } + } + ] +} diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index a5583af94..e0f8aed90 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -493,6 +493,67 @@ where } } +pub(super) fn run_search_outer_full_precision( + input: &SearchPhase, + search_strategy: S, + index: Index, + build_stats: Option, + checkpoint: Checkpoint<'_>, +) -> anyhow::Result +where + DP: DataProvider + + provider::SetElement<[T]>, + T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, + S: glue::SearchStrategy + + glue::HasDefaultProcessor + + glue::PostProcess + + Clone + + AsyncFriendly, +{ + if let SearchPhase::Topk(search_phase) = input { + if let (Some(eta), Some(power)) = ( + search_phase.determinant_diversity_eta, + search_phase.determinant_diversity_power, + ) { + let mut result = BuildResult::new_topk(build_stats); + checkpoint.checkpoint(&result)?; + + let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( + &search_phase.queries, + ))?); + + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; + + let knn = benchmark_core::search::graph::determinant_diversity::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(search_strategy), + )?; + + let steps = search::knn::SearchSteps::new( + search_phase.reps, + &search_phase.num_threads, + &search_phase.runs, + ); + + let search_results = search::knn::run_determinant_diversity( + &knn, + &groundtruth, + steps, + eta, + power, + search_phase.determinant_diversity_results_k, + )?; + + result.append(AggregatedSearchResults::Topk(search_results)); + return Ok(result); + } + } + + run_search_outer(input, search_strategy, index, build_stats, checkpoint) +} + macro_rules! impl_build { ($T:ty) => { impl<'a> BuildAndSearch<'a> for FullPrecision<'a, $T> { @@ -544,7 +605,7 @@ macro_rules! impl_build { } }; - let result = run_search_outer( + let result = run_search_outer_full_precision( &self.input.search_phase, common::FullPrecision, index, diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 1d6102f9b..1f9c2e50a 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -155,6 +155,42 @@ impl SearchResults { mean_hops: mean_hops as f32, } } + + pub fn new_determinant_diversity( + summary: benchmark_core::search::graph::determinant_diversity::Summary, + ) -> Self { + let benchmark_core::search::graph::determinant_diversity::Summary { + setup, + parameters, + end_to_end_latencies, + mean_latencies, + p90_latencies, + p99_latencies, + recall, + mean_cmps, + mean_hops, + .. + } = summary; + + let qps = end_to_end_latencies + .iter() + .map(|latency| recall.num_queries as f64 / latency.as_seconds()) + .collect(); + + Self { + num_tasks: setup.tasks.into(), + search_n: parameters.inner.k_value().get(), + search_l: parameters.inner.l_value().get(), + qps, + search_latencies: end_to_end_latencies, + mean_latencies, + p90_latencies, + p99_latencies, + recall: (&recall).into(), + mean_cmps: mean_cmps as f32, + mean_hops: mean_hops as f32, + } + } } fn format_search_results_table( diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 915b8eca6..957f9c924 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -63,6 +63,48 @@ pub(crate) fn run( Ok(all) } +pub(crate) fn run_determinant_diversity( + runner: &dyn DeterminantDiversityKnn, + groundtruth: &dyn benchmark_core::recall::Rows, + steps: SearchSteps<'_>, + eta: f64, + power: f64, + results_k: Option, +) -> anyhow::Result> { + let mut all = Vec::new(); + + for threads in steps.num_tasks.iter() { + for run in steps.runs.iter() { + let setup = core_search::Setup { + threads: *threads, + tasks: *threads, + reps: steps.reps, + }; + + let parameters: Vec<_> = run + .search_l + .iter() + .map(|search_l| { + let base = + diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); + let processor = diskann::graph::search::DeterminantDiversitySearchParams::new( + results_k.unwrap_or(run.search_n), + eta, + power, + ); + let search_params = diskann::graph::search::KnnWith::new(base, processor); + + core_search::Run::new(search_params, setup.clone()) + }) + .collect(); + + all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); + } + } + + Ok(all) +} + type Run = core_search::Run; pub(crate) trait Knn { fn search_all( @@ -74,6 +116,20 @@ pub(crate) trait Knn { ) -> anyhow::Result>; } +type DeterminantRun = core_search::Run< + diskann::graph::search::KnnWith, +>; + +pub(crate) trait DeterminantDiversityKnn { + fn search_all( + &self, + parameters: Vec, + groundtruth: &dyn benchmark_core::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> anyhow::Result>; +} + /////////// // Impls // /////////// @@ -129,3 +185,45 @@ where Ok(results.into_iter().map(SearchResults::new).collect()) } } + +impl DeterminantDiversityKnn + for Arc> +where + DP: diskann::provider::DataProvider, + core_search::graph::determinant_diversity::KNN: core_search::Search< + Id = DP::InternalId, + Parameters = diskann::graph::search::KnnWith< + diskann::graph::search::DeterminantDiversitySearchParams, + >, + Output = core_search::graph::knn::Metrics, + >, +{ + fn search_all( + &self, + parameters: Vec< + core_search::Run< + diskann::graph::search::KnnWith< + diskann::graph::search::DeterminantDiversitySearchParams, + >, + >, + >, + groundtruth: &dyn benchmark_core::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> anyhow::Result> { + let results = core_search::search_all( + self.clone(), + parameters.into_iter(), + core_search::graph::determinant_diversity::Aggregator::new( + groundtruth, + recall_k, + recall_n, + ), + )?; + + Ok(results + .into_iter() + .map(SearchResults::new_determinant_diversity) + .collect()) + } +} diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 19230977d..0f8026b58 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -123,6 +123,9 @@ pub(crate) struct TopkSearchPhase { pub(crate) queries: InputFile, pub(crate) groundtruth: InputFile, pub(crate) reps: NonZeroUsize, + pub(crate) determinant_diversity_eta: Option, + pub(crate) determinant_diversity_power: Option, + pub(crate) determinant_diversity_results_k: Option, // Enable sweeping threads pub(crate) num_threads: Vec, pub(crate) runs: Vec, @@ -139,6 +142,36 @@ impl CheckDeserialization for TopkSearchPhase { .with_context(|| format!("search run {}", i))?; } + if self.determinant_diversity_eta.is_some() != self.determinant_diversity_power.is_some() { + return Err(anyhow!( + "determinant_diversity_eta and determinant_diversity_power must either both be set or both be omitted" + )); + } + + if let Some(eta) = self.determinant_diversity_eta { + if eta < 0.0 { + return Err(anyhow!( + "determinant_diversity_eta must be >= 0.0, got {}", + eta + )); + } + } + + if let Some(power) = self.determinant_diversity_power { + if power < 0.0 { + return Err(anyhow!( + "determinant_diversity_power must be >= 0.0, got {}", + power + )); + } + } + + if let Some(k) = self.determinant_diversity_results_k { + if k == 0 { + return Err(anyhow!("determinant_diversity_results_k must be > 0")); + } + } + Ok(()) } } @@ -164,6 +197,9 @@ impl Example for TopkSearchPhase { queries: InputFile::new("path/to/queries"), groundtruth: InputFile::new("path/to/groundtruth"), reps: REPS, + determinant_diversity_eta: None, + determinant_diversity_power: None, + determinant_diversity_results_k: None, num_threads: THREAD_COUNTS.to_vec(), runs, } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 0592994af..45c76f84a 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -8,12 +8,15 @@ use std::{collections::HashMap, fmt::Debug, future::Future}; use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, + error::IntoANNResult, graph::{ SearchOutputBuffer, glue::{ self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, - InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, + InplaceDeleteStrategy, InsertStrategy, PostProcess, PruneStrategy, SearchExt, + SearchStrategy, }, + search::{DeterminantDiversitySearchParams, determinant_diversity_post_process}, }, neighbor::Neighbor, provider::{ @@ -475,6 +478,65 @@ where delegate_default_post_process!(RemoveDeletedIdsAndCopy); } +impl + PostProcess, [T], DeterminantDiversitySearchParams> + for Internal +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: &DeterminantDiversitySearchParams, + accessor: &mut Self::SearchAccessor<'a>, + query: &[T], + _computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + let query_f32 = T::as_f32(query).into_ann_result()?.to_vec(); + let mut candidates_with_vectors = Vec::new(); + + for candidate in candidates { + if accessor.provider.deleted.deletion_check(candidate.id) { + continue; + } + + let vector = accessor.get_element(candidate.id).await.into_ann_result()?; + let vector_f32 = T::as_f32(vector).into_ann_result()?; + candidates_with_vectors.push(( + candidate.id, + candidate.distance, + vector_f32.to_vec(), + )); + } + + let borrowed: Vec<(u32, f32, &[f32])> = candidates_with_vectors + .iter() + .map(|(id, distance, vector)| (*id, *distance, vector.as_slice())) + .collect(); + + let reranked = determinant_diversity_post_process( + borrowed, + &query_f32, + processor.top_k, + processor.determinant_diversity_eta, + processor.determinant_diversity_power, + ); + + Ok(output.extend(reranked)) + } + } +} + /// Perform a search entirely in the full-precision space. /// /// Starting points are not filtered out of the final results. @@ -508,6 +570,65 @@ where delegate_default_post_process!(glue::Pipeline); } +impl + PostProcess, [T], DeterminantDiversitySearchParams> + for FullPrecision +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: &DeterminantDiversitySearchParams, + accessor: &mut Self::SearchAccessor<'a>, + query: &[T], + _computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + let query_f32 = T::as_f32(query).into_ann_result()?.to_vec(); + let mut candidates_with_vectors = Vec::new(); + + for candidate in candidates { + if accessor.provider.deleted.deletion_check(candidate.id) { + continue; + } + + let vector = accessor.get_element(candidate.id).await.into_ann_result()?; + let vector_f32 = T::as_f32(vector).into_ann_result()?; + candidates_with_vectors.push(( + candidate.id, + candidate.distance, + vector_f32.to_vec(), + )); + } + + let borrowed: Vec<(u32, f32, &[f32])> = candidates_with_vectors + .iter() + .map(|(id, distance, vector)| (*id, *distance, vector.as_slice())) + .collect(); + + let reranked = determinant_diversity_post_process( + borrowed, + &query_f32, + processor.top_k, + processor.determinant_diversity_eta, + processor.determinant_diversity_power, + ); + + Ok(output.extend(reranked)) + } + } +} + // Pruning impl PruneStrategy> for FullPrecision where diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 5dc1179fb..01638a0f8 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -339,7 +339,7 @@ where /// [`super::search::KnnWith`], etc.) bounds on. /// /// The blanket impl covers `P = DefaultPostProcess` for any strategy implementing -/// [`HasDefaultProcessor`]. Custom processor types (e.g. `RagSearchParams`) can have +/// [`HasDefaultProcessor`]. Custom processor types (e.g. `DeterminantDiversitySearchParams`) can have /// their own `PostProcess` impls without coherence conflicts. pub trait PostProcess::InternalId>: SearchStrategy diff --git a/diskann/src/graph/search/determinant_diversity_post_process.rs b/diskann/src/graph/search/determinant_diversity_post_process.rs new file mode 100644 index 000000000..33e4b987d --- /dev/null +++ b/diskann/src/graph/search/determinant_diversity_post_process.rs @@ -0,0 +1,318 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Determinant-diversity search post-processing. +//! +//! This module provides post-processing functionality for determinant-diversity search, +//! which reranks search results to maximize diversity using a greedy +//! orthogonalization algorithm. + +use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; + +/// Parameters for determinant-diversity reranking. +#[derive(Debug, Clone, Copy)] +pub struct DeterminantDiversitySearchParams { + pub top_k: usize, + pub determinant_diversity_eta: f64, + pub determinant_diversity_power: f64, +} + +impl DeterminantDiversitySearchParams { + pub fn new( + top_k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, + ) -> Self { + Self { + top_k, + determinant_diversity_eta, + determinant_diversity_power, + } + } +} + +/// Post-process search results using determinant-diversity reranking. +/// +/// If `determinant_diversity_eta > 0.0`, uses a ridge-aware variant. +/// Otherwise, uses greedy orthogonalization. +pub fn determinant_diversity_post_process( + candidates: Vec<(Id, f32, &[f32])>, + query: &[f32], + k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + if candidates.is_empty() { + return Vec::new(); + } + + let k = k.min(candidates.len()); + + let candidates_f32: Vec<(Id, f32, Vec)> = candidates + .into_iter() + .map(|(id, dist, v)| (id, dist, v.to_vec())) + .collect(); + + let results = if determinant_diversity_eta > 0.0 { + post_process_with_eta_f32( + candidates_f32, + query, + k, + determinant_diversity_eta, + determinant_diversity_power, + ) + } else { + post_process_greedy_orthogonalization_f32( + candidates_f32, + query, + k, + determinant_diversity_power, + ) + }; + + debug_assert_eq!( + results.len(), + k, + "determinant-diversity post-process should return exactly k={} results, got {}", + k, + results.len() + ); + + results +} + +fn post_process_with_eta_f32( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + let eta = determinant_diversity_eta as f32; + let power = determinant_diversity_power; + + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let n = candidates.len(); + let k = k.min(n); + + if k == 0 { + return Vec::new(); + } + + let d = candidates[0].2.len(); + if d == 0 { + return Vec::new(); + } + + let inv_sqrt_eta = 1.0 / eta.sqrt(); + + let mut residuals: Vec> = Vec::with_capacity(n); + let mut norms_sq: Vec = Vec::with_capacity(n); + + for (_, _, v) in &candidates { + let similarity = dot_product(v, query); + let scale = similarity.max(0.0).powf(power as f32) * inv_sqrt_eta; + let r: Vec = v.iter().map(|&x| x * scale).collect(); + let s = dot_product(&r, &r); + residuals.push(r); + norms_sq.push(s); + } + + let mut available: Vec = vec![true; n]; + let mut selected: Vec = Vec::with_capacity(k); + + for _ in 0..k { + let best_idx = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i); + + let Some(j) = best_idx else { + break; + }; + + selected.push(j); + available[j] = false; + + if selected.len() == k { + break; + } + + let norm_factor = 1.0 / (1.0 + norms_sq[j]).sqrt(); + let q: Vec = residuals[j].iter().map(|&x| x * norm_factor).collect(); + + for i in 0..n { + if !available[i] { + continue; + } + + let alpha = dot_product(&q, &residuals[i]); + + for (r_val, &q_val) in residuals[i].iter_mut().zip(q.iter()) { + *r_val -= alpha * q_val; + } + + norms_sq[i] = (norms_sq[i] - alpha * alpha).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = candidates[idx]; + (id, dist) + }) + .collect() +} + +fn post_process_greedy_orthogonalization_f32( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + let power = determinant_diversity_power; + + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let n = candidates.len(); + let k = k.min(n); + + if k == 0 { + return Vec::new(); + } + + let mut residuals: Vec> = Vec::with_capacity(n); + let mut norms_sq: Vec = Vec::with_capacity(n); + + for (_, _, v) in &candidates { + let similarity = dot_product(v, query); + let scale = similarity.max(0.0).powf(power as f32); + let r: Vec = v.iter().map(|&x| x * scale).collect(); + let s = dot_product(&r, &r); + residuals.push(r); + norms_sq.push(s); + } + + let mut available: Vec = vec![true; n]; + let mut selected: Vec = Vec::with_capacity(k); + + for _ in 0..k { + let best = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let Some((i_star, _)) = best else { + break; + }; + + let best_norm_sq = norms_sq[i_star]; + + selected.push(i_star); + available[i_star] = false; + + if selected.len() == k { + break; + } + + if best_norm_sq <= 0.0 { + continue; + } + + let r_star = residuals[i_star].clone(); + let inv_norm_sq_star = 1.0 / best_norm_sq; + + for j in 0..n { + if !available[j] { + continue; + } + + let proj_coeff = dot_product(&residuals[j], &r_star) * inv_norm_sq_star; + + for (r_val, &rs_val) in residuals[j].iter_mut().zip(r_star.iter()) { + *r_val -= proj_coeff * rs_val; + } + + norms_sq[j] = (norms_sq[j] - proj_coeff * proj_coeff * best_norm_sq).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = candidates[idx]; + (id, dist) + }) + .collect() +} + +#[inline] +fn dot_product(a: &[f32], b: &[f32]) -> f32 { + >>::evaluate(a, b) + .into_inner() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_determinant_diversity_post_process_with_eta() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.0f32, 1.0, 0.0]; + let v3 = vec![0.0f32, 0.0, 1.0]; + let candidates = vec![ + (1u32, 0.5f32, v1.as_slice()), + (2u32, 0.3f32, v2.as_slice()), + (3u32, 0.7f32, v3.as_slice()), + ]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); + assert_eq!(result.len(), 3); + } + + #[test] + fn test_determinant_diversity_post_process_enabled_greedy() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.99f32, 0.1, 0.0]; + let v3 = vec![0.0f32, 1.0, 0.0]; + let candidates = vec![ + (1u32, 0.5f32, v1.as_slice()), + (2u32, 0.3f32, v2.as_slice()), + (3u32, 0.4f32, v3.as_slice()), + ]; + let query = vec![1.0, 1.0, 0.0]; + + let result = determinant_diversity_post_process(candidates, &query, 2, 0.0, 1.0); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_determinant_diversity_post_process_empty() { + let candidates: Vec<(u32, f32, &[f32])> = vec![]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); + assert!(result.is_empty()); + } +} diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 5186bc684..1df51d4cc 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -28,6 +28,7 @@ use diskann_utils::future::SendFuture; use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; +mod determinant_diversity_post_process; mod knn_search; mod multihop_search; mod range_search; @@ -88,6 +89,9 @@ where } // Re-export search parameter types. +pub use determinant_diversity_post_process::{ + DeterminantDiversitySearchParams, determinant_diversity_post_process, +}; pub use knn_search::{Knn, KnnSearchError, KnnWith, RecordedKnn}; pub use multihop_search::MultihopSearch; pub use range_search::{Range, RangeSearchError, RangeSearchOutput}; From f3d308a9bb75240865e59712b90a6c6f6d56d6bd Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 9 Mar 2026 21:11:04 +0530 Subject: [PATCH 04/47] Refactor search API to explicit processors and add search_with --- diskann/src/graph/index.rs | 30 +++++++++++++++++++-- diskann/src/graph/search/diverse_search.rs | 10 ++++--- diskann/src/graph/search/knn_search.rs | 19 ++++++++----- diskann/src/graph/search/mod.rs | 4 ++- diskann/src/graph/search/multihop_search.rs | 11 ++++---- diskann/src/graph/search/range_search.rs | 10 ++++--- 6 files changed, 61 insertions(+), 23 deletions(-) diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index f64046185..2b7662eae 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2152,11 +2152,37 @@ where output: &mut OB, ) -> impl SendFuture> where - P: super::search::Search, + P: super::search::Search, T: ?Sized, OB: ?Sized, { - search_params.search(self, strategy, context, query, output) + self.search_with( + search_params, + strategy, + &glue::DefaultPostProcess, + context, + query, + output, + ) + } + + /// Execute a search with an explicit post-processor parameter. + pub fn search_with( + &self, + search_params: P, + strategy: &S, + processor: &PP, + context: &DP::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture> + where + P: super::search::Search, + PP: Send + Sync, + T: ?Sized, + OB: ?Sized, + { + search_params.search(self, strategy, processor, context, query, output) } /// Performs a brute-force flat search over the points matching a provided filter function. diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index 29cef0f8f..da9aa8486 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -14,7 +14,7 @@ use crate::{ error::IntoANNResult, graph::{ DiverseSearchParams, - glue::{DefaultPostProcess, PostProcess, SearchExt}, + glue::{PostProcess, SearchExt}, index::{DiskANNIndex, SearchStats}, search_output_buffer::SearchOutputBuffer, }, @@ -92,14 +92,15 @@ where } } -impl Search for Diverse

+impl Search for Diverse

where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, + S: PostProcess, O: Send, OB: SearchOutputBuffer + Send, P: AttributeValueProvider, + PP: Send + Sync, { type Output = SearchStats; @@ -107,6 +108,7 @@ where self, index: &DiskANNIndex, strategy: &S, + processor: &PP, context: &DP::Context, query: &T, output: &mut OB, @@ -137,7 +139,7 @@ where let result_count = strategy .post_process_with( - &DefaultPostProcess, + processor, &mut accessor, query, &computer, diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 4c676309f..5e0cb4160 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -197,13 +197,14 @@ impl Knn { } } -impl Search for Knn +impl Search for Knn where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, + S: PostProcess, O: Send, OB: SearchOutputBuffer + Send + ?Sized, + PP: Send + Sync, { type Output = SearchStats; @@ -212,12 +213,13 @@ where self, index: &DiskANNIndex, strategy: &S, + processor: &PP, context: &DP::Context, query: &T, output: &mut OB, ) -> impl SendFuture> { async move { - self.search_core(index, strategy, context, query, output, &DefaultPostProcess) + self.search_core(index, strategy, context, query, output, processor) .await } } @@ -245,14 +247,15 @@ impl<'r, SR: ?Sized> RecordedKnn<'r, SR> { } } -impl<'r, DP, S, T, O, OB, SR> Search for RecordedKnn<'r, SR> +impl<'r, DP, S, T, O, OB, SR, PP> Search for RecordedKnn<'r, SR> where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, + S: PostProcess, O: Send, OB: SearchOutputBuffer + Send + ?Sized, SR: super::record::SearchRecord + ?Sized, + PP: Send + Sync, { type Output = SearchStats; @@ -260,6 +263,7 @@ where self, index: &DiskANNIndex, strategy: &S, + processor: &PP, context: &DP::Context, query: &T, output: &mut OB, @@ -287,7 +291,7 @@ where let result_count = strategy .post_process_with( - &DefaultPostProcess, + processor, &mut accessor, query, &computer, @@ -330,7 +334,7 @@ impl KnnWith { } } -impl Search for KnnWith +impl Search for KnnWith where DP: DataProvider, T: Sync + ?Sized, @@ -346,6 +350,7 @@ where self, index: &DiskANNIndex, strategy: &S, + _processor: &DefaultPostProcess, context: &DP::Context, query: &T, output: &mut OB, diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 1df51d4cc..3a1b5f83e 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -49,9 +49,10 @@ pub(crate) mod scratch; /// - [`Diverse`] - Diversity-aware search (feature-gated) /// - [`MultihopSearch`] - Label-filtered search with multi-hop expansion /// - [`RecordedKnn`] - K-NN search with path recording for debugging -pub trait Search +pub trait Search where DP: DataProvider, + PP: Send + Sync, { /// The result type returned by this search. type Output; @@ -82,6 +83,7 @@ where self, index: &DiskANNIndex, strategy: &S, + processor: &PP, context: &DP::Context, query: &T, output: &mut OB, diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 2f7c9e659..6038ef962 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -16,8 +16,7 @@ use crate::{ error::{ErrorExt, IntoANNResult}, graph::{ glue::{ - self, DefaultPostProcess, ExpandBeam, HybridPredicate, PostProcess, Predicate, - PredicateMut, SearchExt, + self, ExpandBeam, HybridPredicate, PostProcess, Predicate, PredicateMut, SearchExt, }, index::{ DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats, @@ -53,13 +52,14 @@ impl<'q, InternalId> MultihopSearch<'q, InternalId> { } } -impl<'q, DP, S, T, O, OB> Search for MultihopSearch<'q, DP::InternalId> +impl<'q, DP, S, T, O, OB, PP> Search for MultihopSearch<'q, DP::InternalId> where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, + S: PostProcess, O: Send, OB: SearchOutputBuffer + Send, + PP: Send + Sync, { type Output = SearchStats; @@ -67,6 +67,7 @@ where self, index: &DiskANNIndex, strategy: &S, + processor: &PP, context: &DP::Context, query: &T, output: &mut OB, @@ -94,7 +95,7 @@ where let result_count = strategy .post_process_with( - &DefaultPostProcess, + processor, &mut accessor, query, &computer, diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 4228b28d8..160aa932e 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -13,7 +13,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{self, DefaultPostProcess, ExpandBeam, PostProcess, SearchExt}, + glue::{self, ExpandBeam, PostProcess, SearchExt}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, search_output_buffer, @@ -167,12 +167,13 @@ impl Range { } } -impl Search for Range +impl Search for Range where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, + S: PostProcess, O: Send + Default + Clone, + PP: Send + Sync, { type Output = RangeSearchOutput; @@ -180,6 +181,7 @@ where self, index: &DiskANNIndex, strategy: &S, + processor: &PP, context: &DP::Context, query: &T, _output: &mut (), @@ -252,7 +254,7 @@ where let _ = strategy .post_process_with( - &DefaultPostProcess, + processor, &mut accessor, query, &computer, From 4a74343480a4435690b2c5ed4357569ef8ca2cb8 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 9 Mar 2026 22:01:37 +0530 Subject: [PATCH 05/47] Add disk determinant-diversity search wiring and benchmark inputs --- ...kipedia_disk_build_and_compare_detdiv.json | 82 ++++++++ .../wikipedia_disk_build_baseline.json | 37 ++++ .../wikipedia_disk_compare_detdiv.json | 52 +++++ .../wikipedia_disk_load_compare_detdiv.json | 52 +++++ .../src/backend/disk_index/build.rs | 6 +- .../src/backend/disk_index/search.rs | 39 +++- diskann-benchmark/src/inputs/disk.rs | 30 +++ .../src/search/provider/disk_provider.rs | 181 +++++++++++++++++- 8 files changed, 466 insertions(+), 13 deletions(-) create mode 100644 diskann-benchmark/example/wikipedia_disk_build_and_compare_detdiv.json create mode 100644 diskann-benchmark/example/wikipedia_disk_build_baseline.json create mode 100644 diskann-benchmark/example/wikipedia_disk_compare_detdiv.json create mode 100644 diskann-benchmark/example/wikipedia_disk_load_compare_detdiv.json diff --git a/diskann-benchmark/example/wikipedia_disk_build_and_compare_detdiv.json b/diskann-benchmark/example/wikipedia_disk_build_and_compare_detdiv.json new file mode 100644 index 000000000..166f1edc4 --- /dev/null +++ b/diskann-benchmark/example/wikipedia_disk_build_and_compare_detdiv.json @@ -0,0 +1,82 @@ +{ + "search_directories": [ + "C:/wikipedia_dataset" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "C:/wikipedia_dataset/data.bin", + "distance": "squared_l2", + "dim": 1024, + "max_degree": 64, + "l_build": 100, + "num_threads": 8, + "build_ram_limit_gb": 32.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" + }, + "search_phase": { + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "search_list": [20, 30, 40, 50, 100, 200], + "beam_width": 8, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" + }, + "search_phase": { + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "search_list": [20, 30, 40, 50, 100, 200], + "beam_width": 8, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" + }, + "search_phase": { + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "search_list": [20, 30, 40, 50, 100, 200], + "beam_width": 8, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 1.0, + "determinant_diversity_results_k": 10 + } + } + } + ] +} diff --git a/diskann-benchmark/example/wikipedia_disk_build_baseline.json b/diskann-benchmark/example/wikipedia_disk_build_baseline.json new file mode 100644 index 000000000..a3abf3eb7 --- /dev/null +++ b/diskann-benchmark/example/wikipedia_disk_build_baseline.json @@ -0,0 +1,37 @@ +{ + "search_directories": [ + "C:/wikipedia_dataset" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "C:/wikipedia_dataset/data.bin", + "distance": "squared_l2", + "dim": 1024, + "max_degree": 64, + "l_build": 100, + "num_threads": 8, + "build_ram_limit_gb": 32.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" + }, + "search_phase": { + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "search_list": [20, 30, 40, 50, 100, 200], + "beam_width": 8, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null + } + } + } + ] +} diff --git a/diskann-benchmark/example/wikipedia_disk_compare_detdiv.json b/diskann-benchmark/example/wikipedia_disk_compare_detdiv.json new file mode 100644 index 000000000..d3c430a28 --- /dev/null +++ b/diskann-benchmark/example/wikipedia_disk_compare_detdiv.json @@ -0,0 +1,52 @@ +{ + "search_directories": [ + "C:/wikipedia_dataset" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/wikipedia_dataset/wikipedia_saved_index" + }, + "search_phase": { + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "search_list": [20, 30, 40, 50, 100, 200], + "beam_width": 8, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/wikipedia_dataset/wikipedia_saved_index" + }, + "search_phase": { + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "search_list": [20, 30, 40, 50, 100, 200], + "beam_width": 8, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 1.0, + "determinant_diversity_results_k": 10 + } + } + } + ] +} diff --git a/diskann-benchmark/example/wikipedia_disk_load_compare_detdiv.json b/diskann-benchmark/example/wikipedia_disk_load_compare_detdiv.json new file mode 100644 index 000000000..a7e08c15b --- /dev/null +++ b/diskann-benchmark/example/wikipedia_disk_load_compare_detdiv.json @@ -0,0 +1,52 @@ +{ + "search_directories": [ + "C:/wikipedia_dataset" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" + }, + "search_phase": { + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "search_list": [20, 30, 40, 50, 100, 200], + "beam_width": 8, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" + }, + "search_phase": { + "queries": "C:/wikipedia_dataset/query.bin", + "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", + "search_list": [20, 30, 40, 50, 100, 200], + "beam_width": 8, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 1.0, + "determinant_diversity_results_k": 10 + } + } + } + ] +} diff --git a/diskann-benchmark/src/backend/disk_index/build.rs b/diskann-benchmark/src/backend/disk_index/build.rs index b080d443a..b6ebf3b83 100644 --- a/diskann-benchmark/src/backend/disk_index/build.rs +++ b/diskann-benchmark/src/backend/disk_index/build.rs @@ -94,13 +94,13 @@ where let build_parameters = DiskIndexBuildParameters::new( MemoryBudget::try_from_gb(params.build_ram_limit_gb)?, params.quantization_type, - NumPQChunks::new_with(params.num_pq_chunks.get(), metadata.ndims)?, + NumPQChunks::new_with(params.num_pq_chunks.get(), metadata.ndims())?, ); let index_configuration = IndexConfiguration::new( metric, - metadata.ndims, - metadata.npoints, + metadata.ndims(), + metadata.npoints(), ONE, params.num_threads, config, diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 65e5804a7..1de1d0eb0 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -9,6 +9,7 @@ use std::{collections::HashSet, fmt, sync::atomic::AtomicBool, time::Instant}; use opentelemetry::{global, trace::Span, trace::Tracer}; use opentelemetry_sdk::trace::SdkTracerProvider; +use diskann::graph::search::DeterminantDiversitySearchParams; use diskann::utils::VectorRepr; use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ @@ -269,14 +270,38 @@ where as Box bool + Send + Sync>) }; - match searcher.search( - q, - search_params.recall_at, - l, - Some(search_params.beam_width), - vector_filter, - search_params.is_flat_search, + let search_result = if let (Some(eta), Some(power)) = ( + search_params.determinant_diversity_eta, + search_params.determinant_diversity_power, ) { + let processor = DeterminantDiversitySearchParams::new( + search_params + .determinant_diversity_results_k + .unwrap_or(search_params.recall_at as usize), + eta, + power, + ); + searcher.search_determinant_diversity( + q, + search_params.recall_at, + l, + Some(search_params.beam_width), + vector_filter, + search_params.is_flat_search, + processor, + ) + } else { + searcher.search( + q, + search_params.recall_at, + l, + Some(search_params.beam_width), + vector_filter, + search_params.is_flat_search, + ) + }; + + match search_result { Ok(search_result) => { *stats = search_result.stats.query_statistics; *rc = search_result.results.len() as u32; diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index bf843d72f..22376a648 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -85,6 +85,9 @@ pub(crate) struct DiskSearchPhase { pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, pub(crate) search_io_limit: Option, + pub(crate) determinant_diversity_eta: Option, + pub(crate) determinant_diversity_power: Option, + pub(crate) determinant_diversity_results_k: Option, } ///////// @@ -234,6 +237,30 @@ impl CheckDeserialization for DiskSearchPhase { anyhow::bail!("search_io_limit must be positive if specified"); } } + + if self.determinant_diversity_eta.is_some() != self.determinant_diversity_power.is_some() { + anyhow::bail!( + "determinant_diversity_eta and determinant_diversity_power must either both be set or both omitted" + ); + } + + if let Some(eta) = self.determinant_diversity_eta { + if eta < 0.0 { + anyhow::bail!("determinant_diversity_eta must be >= 0.0"); + } + } + + if let Some(power) = self.determinant_diversity_power { + if power < 0.0 { + anyhow::bail!("determinant_diversity_power must be >= 0.0"); + } + } + + if let Some(k) = self.determinant_diversity_results_k { + if k == 0 { + anyhow::bail!("determinant_diversity_results_k must be > 0"); + } + } Ok(()) } } @@ -272,6 +299,9 @@ impl Example for DiskIndexOperation { vector_filters_file: None, num_nodes_to_cache: None, search_io_limit: None, + determinant_diversity_eta: None, + determinant_diversity_power: None, + determinant_diversity_results_k: None, }; Self { diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 870541f0c..79223b78d 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -16,13 +16,14 @@ use std::{ }; use diskann::{ + error::IntoANNResult, graph::{ self, glue::{ - self, ExpandBeam, HasDefaultProcessor, IdIterator, SearchExt, SearchPostProcess, - SearchStrategy, + self, ExpandBeam, HasDefaultProcessor, IdIterator, PostProcess, SearchExt, + SearchPostProcess, SearchStrategy, }, - search::Knn, + search::{determinant_diversity_post_process, DeterminantDiversitySearchParams, Knn}, search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, }, neighbor::Neighbor, @@ -390,6 +391,93 @@ where } } +impl<'this, Data, ProviderFactory> + PostProcess< + DiskProvider, + [Data::VectorDataType], + DeterminantDiversitySearchParams, + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DiskSearchStrategy<'this, Data, ProviderFactory> +where + Data: GraphDataType, + ProviderFactory: VertexProviderFactory, +{ + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: &DeterminantDiversitySearchParams, + accessor: &mut Self::SearchAccessor<'a>, + query: &[Data::VectorDataType], + _computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator as DataProvider>::InternalId>> + Send, + B: SearchOutputBuffer<( + as DataProvider>::InternalId, + Data::AssociatedDataType, + )> + Send + + ?Sized, + { + async move { + let provider = accessor.provider; + let query_f32 = Data::VectorDataType::as_f32(query) + .into_ann_result()? + .to_vec(); + + let filtered_ids: Vec = candidates + .map(|n| n.id) + .filter(|id| (self.vector_filter)(id)) + .collect(); + + if filtered_ids.is_empty() { + return Ok(0); + } + + ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &filtered_ids)?; + + let mut enriched: Vec<(u32, f32, Vec, Data::AssociatedDataType)> = Vec::new(); + for id in filtered_ids { + let vector = accessor.scratch.vertex_provider.get_vector(&id)?; + let vector_f32 = Data::VectorDataType::as_f32(vector) + .into_ann_result()? + .to_vec(); + let distance = provider + .distance_comparer + .evaluate_similarity(query, vector); + let assoc = *accessor.scratch.vertex_provider.get_associated_data(&id)?; + enriched.push((id, distance, vector_f32, assoc)); + } + + let borrowed: Vec<(u32, f32, &[f32])> = enriched + .iter() + .map(|(id, dist, vector, _)| (*id, *dist, vector.as_slice())) + .collect(); + + let reranked = determinant_diversity_post_process( + borrowed, + &query_f32, + processor.top_k, + processor.determinant_diversity_eta, + processor.determinant_diversity_power, + ); + + let mut pairs = Vec::with_capacity(reranked.len()); + for (id, distance) in reranked { + if let Some((_, _, _, assoc)) = enriched.iter().find(|(eid, _, _, _)| *eid == id) { + pairs.push(((id, *assoc), distance)); + } + } + + Ok(output.extend(pairs)) + } + } +} + /// The query computer for the disk provider. This is used to compute the distance between the query vector and the PQ coordinates. pub struct DiskQueryComputer { num_pq_chunks: usize, @@ -978,6 +1066,93 @@ where Ok(search_result) } + /// Perform a determinant-diversity search on the disk index. + #[allow(clippy::too_many_arguments)] + pub fn search_determinant_diversity( + &self, + query: &[Data::VectorDataType], + return_list_size: u32, + search_list_size: u32, + beam_width: Option, + vector_filter: Option>, + is_flat_search: bool, + processor: DeterminantDiversitySearchParams, + ) -> ANNResult> { + let mut query_stats = QueryStatistics::default(); + let mut indices = vec![0u32; return_list_size as usize]; + let mut distances = vec![0f32; return_list_size as usize]; + let mut associated_data = + vec![Data::AssociatedDataType::default(); return_list_size as usize]; + + let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( + &mut indices, + &mut distances, + &mut associated_data, + ); + + let vector_filter = vector_filter.unwrap_or(default_vector_filter::()); + let strategy = self.search_strategy(query, &*vector_filter); + let timer = Instant::now(); + let k = return_list_size as usize; + let l = search_list_size as usize; + + let stats = if is_flat_search { + self.runtime.block_on(self.index.flat_search( + &strategy, + &DefaultContext, + strategy.query, + strategy.vector_filter, + &Knn::new(k, l, beam_width)?, + &mut result_output_buffer, + ))? + } else { + let knn_search = Knn::new(k, l, beam_width)?; + self.runtime.block_on(self.index.search_with( + knn_search, + &strategy, + &processor, + &DefaultContext, + strategy.query, + &mut result_output_buffer, + ))? + }; + + query_stats.total_comparisons = stats.cmps; + query_stats.search_hops = stats.hops; + query_stats.total_execution_time_us = timer.elapsed().as_micros(); + query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128; + query_stats.total_io_operations = strategy.io_tracker.io_count() as u32; + query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32; + query_stats.query_pq_preprocess_time_us = + IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128; + query_stats.cpu_time_us = query_stats.total_execution_time_us + - query_stats.io_time_us + - query_stats.query_pq_preprocess_time_us; + + let mut search_result = SearchResult { + results: Vec::with_capacity(return_list_size as usize), + stats: SearchResultStats { + cmps: query_stats.total_comparisons, + result_count: stats.result_count, + query_statistics: query_stats.clone(), + }, + }; + + for ((vertex_id, distance), associated_data) in indices + .into_iter() + .zip(distances.into_iter()) + .zip(associated_data.into_iter()) + { + search_result.results.push(SearchResultItem { + vertex_id, + distance, + data: associated_data, + }); + } + + Ok(search_result) + } + /// Perform a raw search on the disk index. /// This is a lower-level API that allows more control over the search parameters and output buffers. #[allow(clippy::too_many_arguments)] From 2729ac74bed0af7c9227299caa83ef9f7138d9c1 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 17:08:20 +0530 Subject: [PATCH 06/47] Refactor processor passing by value and remove KnnWith --- .../src/search/graph/determinant_diversity.rs | 26 +++---- .../src/backend/index/search/knn.rs | 19 +++-- .../src/search/provider/disk_provider.rs | 4 +- .../provider/async_/inmem/full_precision.rs | 61 +--------------- diskann/src/graph/glue.rs | 8 +-- diskann/src/graph/index.rs | 6 +- diskann/src/graph/search/diverse_search.rs | 2 +- diskann/src/graph/search/knn_search.rs | 70 ++----------------- diskann/src/graph/search/mod.rs | 4 +- diskann/src/graph/search/multihop_search.rs | 2 +- diskann/src/graph/search/range_search.rs | 2 +- 11 files changed, 42 insertions(+), 162 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs index 4720545ba..d69d18937 100644 --- a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs +++ b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs @@ -19,6 +19,12 @@ use crate::{ utils, }; +#[derive(Debug, Clone, Copy)] +pub struct Parameters { + pub inner: graph::search::Knn, + pub processor: graph::search::DeterminantDiversitySearchParams, +} + /// A built-in helper for benchmarking determinant-diversity K-nearest neighbors. #[derive(Debug)] pub struct KNN @@ -60,7 +66,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::search::KnnWith; + type Parameters = Parameters; type Output = super::knn::Metrics; fn num_queries(&self) -> usize { @@ -83,9 +89,10 @@ where let context = DP::Context::default(); let stats = self .index - .search( - parameters.clone(), + .search_with( + parameters.inner, self.strategy.get(index)?, + parameters.processor, &context, self.queries.row(index), buffer, @@ -104,7 +111,7 @@ where #[non_exhaustive] pub struct Summary { pub setup: search::Setup, - pub parameters: graph::search::KnnWith, + pub parameters: Parameters, pub end_to_end_latencies: Vec, pub mean_latencies: Vec, pub p90_latencies: Vec, @@ -134,12 +141,7 @@ impl<'a, I> Aggregator<'a, I> { } } -impl - search::Aggregate< - graph::search::KnnWith, - I, - super::knn::Metrics, - > for Aggregator<'_, I> +impl search::Aggregate for Aggregator<'_, I> where I: crate::recall::RecallCompatible, { @@ -147,7 +149,7 @@ where fn aggregate( &mut self, - run: search::Run>, + run: search::Run, mut results: Vec>, ) -> anyhow::Result

{ let recall = match results.first() { @@ -185,7 +187,7 @@ where Ok(Summary { setup: run.setup().clone(), - parameters: run.parameters().clone(), + parameters: *run.parameters(), end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(), recall, mean_latencies, diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 957f9c924..6ad460fc5 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -92,7 +92,11 @@ pub(crate) fn run_determinant_diversity( eta, power, ); - let search_params = diskann::graph::search::KnnWith::new(base, processor); + let search_params = + diskann_benchmark_core::search::graph::determinant_diversity::Parameters { + inner: base, + processor, + }; core_search::Run::new(search_params, setup.clone()) }) @@ -116,9 +120,8 @@ pub(crate) trait Knn { ) -> anyhow::Result>; } -type DeterminantRun = core_search::Run< - diskann::graph::search::KnnWith, ->; +type DeterminantRun = + core_search::Run; pub(crate) trait DeterminantDiversityKnn { fn search_all( @@ -192,9 +195,7 @@ where DP: diskann::provider::DataProvider, core_search::graph::determinant_diversity::KNN: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::search::KnnWith< - diskann::graph::search::DeterminantDiversitySearchParams, - >, + Parameters = diskann_benchmark_core::search::graph::determinant_diversity::Parameters, Output = core_search::graph::knn::Metrics, >, { @@ -202,9 +203,7 @@ where &self, parameters: Vec< core_search::Run< - diskann::graph::search::KnnWith< - diskann::graph::search::DeterminantDiversitySearchParams, - >, + diskann_benchmark_core::search::graph::determinant_diversity::Parameters, >, >, groundtruth: &dyn benchmark_core::recall::Rows, diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 79223b78d..0caca3188 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -408,7 +408,7 @@ where #[allow(clippy::manual_async_fn)] fn post_process_with<'a, I, B>( &self, - processor: &DeterminantDiversitySearchParams, + processor: DeterminantDiversitySearchParams, accessor: &mut Self::SearchAccessor<'a>, query: &[Data::VectorDataType], _computer: &Self::QueryComputer, @@ -1110,7 +1110,7 @@ where self.runtime.block_on(self.index.search_with( knn_search, &strategy, - &processor, + processor, &DefaultContext, strategy.query, &mut result_output_buffer, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 45c76f84a..36298b5d3 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -478,65 +478,6 @@ where delegate_default_post_process!(RemoveDeletedIdsAndCopy); } -impl - PostProcess, [T], DeterminantDiversitySearchParams> - for Internal -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: &DeterminantDiversitySearchParams, - accessor: &mut Self::SearchAccessor<'a>, - query: &[T], - _computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - let query_f32 = T::as_f32(query).into_ann_result()?.to_vec(); - let mut candidates_with_vectors = Vec::new(); - - for candidate in candidates { - if accessor.provider.deleted.deletion_check(candidate.id) { - continue; - } - - let vector = accessor.get_element(candidate.id).await.into_ann_result()?; - let vector_f32 = T::as_f32(vector).into_ann_result()?; - candidates_with_vectors.push(( - candidate.id, - candidate.distance, - vector_f32.to_vec(), - )); - } - - let borrowed: Vec<(u32, f32, &[f32])> = candidates_with_vectors - .iter() - .map(|(id, distance, vector)| (*id, *distance, vector.as_slice())) - .collect(); - - let reranked = determinant_diversity_post_process( - borrowed, - &query_f32, - processor.top_k, - processor.determinant_diversity_eta, - processor.determinant_diversity_power, - ); - - Ok(output.extend(reranked)) - } - } -} - /// Perform a search entirely in the full-precision space. /// /// Starting points are not filtered out of the final results. @@ -582,7 +523,7 @@ where #[allow(clippy::manual_async_fn)] fn post_process_with<'a, I, B>( &self, - processor: &DeterminantDiversitySearchParams, + processor: DeterminantDiversitySearchParams, accessor: &mut Self::SearchAccessor<'a>, query: &[T], _computer: &Self::QueryComputer, diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 01638a0f8..e6dd0b6cd 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -335,8 +335,8 @@ where /// Strategy-level bridge connecting a [`SearchStrategy`] to a specific processor type `P`. /// -/// This trait is the surface that the search infrastructure ([`super::search::Knn`], -/// [`super::search::KnnWith`], etc.) bounds on. +/// This trait is the surface that the search infrastructure (for example, +/// [`super::search::Knn`]) bounds on. /// /// The blanket impl covers `P = DefaultPostProcess` for any strategy implementing /// [`HasDefaultProcessor`]. Custom processor types (e.g. `DeterminantDiversitySearchParams`) can have @@ -353,7 +353,7 @@ where /// results into `output`. fn post_process_with<'a, I, B>( &self, - processor: &P, + processor: P, accessor: &mut Self::SearchAccessor<'a>, query: &T, computer: &Self::QueryComputer, @@ -421,7 +421,7 @@ where { fn post_process_with<'a, I, B>( &self, - _processor: &DefaultPostProcess, + _processor: DefaultPostProcess, accessor: &mut Self::SearchAccessor<'a>, query: &T, computer: &Self::QueryComputer, diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 2b7662eae..e6c49e8a4 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -1299,7 +1299,7 @@ where let post_processor = strategy.search_post_processor(); let num_results = search_strategy .post_process_with( - &post_processor, + post_processor, &mut search_accessor, &*proxy, &computer, @@ -2159,7 +2159,7 @@ where self.search_with( search_params, strategy, - &glue::DefaultPostProcess, + glue::DefaultPostProcess, context, query, output, @@ -2171,7 +2171,7 @@ where &self, search_params: P, strategy: &S, - processor: &PP, + processor: PP, context: &DP::Context, query: &T, output: &mut OB, diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index da9aa8486..206ca7465 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -108,7 +108,7 @@ where self, index: &DiskANNIndex, strategy: &S, - processor: &PP, + processor: PP, context: &DP::Context, query: &T, output: &mut OB, diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 5e0cb4160..c790770bb 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -15,7 +15,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{DefaultPostProcess, PostProcess, SearchExt}, + glue::{PostProcess, SearchExt}, index::{DiskANNIndex, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::SearchOutputBuffer, @@ -152,7 +152,7 @@ impl Knn { context: &DP::Context, query: &T, output: &mut OB, - post_processor: &PP, + post_processor: PP, ) -> ANNResult where DP: DataProvider, @@ -213,7 +213,7 @@ where self, index: &DiskANNIndex, strategy: &S, - processor: &PP, + processor: PP, context: &DP::Context, query: &T, output: &mut OB, @@ -263,7 +263,7 @@ where self, index: &DiskANNIndex, strategy: &S, - processor: &PP, + processor: PP, context: &DP::Context, query: &T, output: &mut OB, @@ -308,68 +308,6 @@ where } } -///////////////////////// -// KnnWith // -///////////////////////// - -/// K-NN search with an explicit caller-supplied post-processor. -/// -/// This allows using a custom post-processor `PP` instead of the strategy's default. -/// Use [`KnnWith::new`] to wrap a base [`Knn`] with a post-processor. -#[derive(Debug, Clone)] -pub struct KnnWith { - /// Base k-NN search parameters. - pub inner: Knn, - /// The caller-supplied post-processor. - pub post_processor: PP, -} - -impl KnnWith { - /// Create new k-NN search parameters with an explicit post-processor. - pub fn new(inner: Knn, post_processor: PP) -> Self { - Self { - inner, - post_processor, - } - } -} - -impl Search for KnnWith -where - DP: DataProvider, - T: Sync + ?Sized, - S: PostProcess, - O: Send, - OB: SearchOutputBuffer + Send + ?Sized, - PP: Send + Sync, -{ - type Output = SearchStats; - - /// Execute the k-NN search with the caller-supplied post-processor. - fn search( - self, - index: &DiskANNIndex, - strategy: &S, - _processor: &DefaultPostProcess, - context: &DP::Context, - query: &T, - output: &mut OB, - ) -> impl SendFuture> { - async move { - self.inner - .search_core( - index, - strategy, - context, - query, - output, - &self.post_processor, - ) - .await - } - } -} - /////////// // Tests // /////////// diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 3a1b5f83e..8873ca51c 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -83,7 +83,7 @@ where self, index: &DiskANNIndex, strategy: &S, - processor: &PP, + processor: PP, context: &DP::Context, query: &T, output: &mut OB, @@ -94,7 +94,7 @@ where pub use determinant_diversity_post_process::{ DeterminantDiversitySearchParams, determinant_diversity_post_process, }; -pub use knn_search::{Knn, KnnSearchError, KnnWith, RecordedKnn}; +pub use knn_search::{Knn, KnnSearchError, RecordedKnn}; pub use multihop_search::MultihopSearch; pub use range_search::{Range, RangeSearchError, RangeSearchOutput}; diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 6038ef962..7259ef1e1 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -67,7 +67,7 @@ where self, index: &DiskANNIndex, strategy: &S, - processor: &PP, + processor: PP, context: &DP::Context, query: &T, output: &mut OB, diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 160aa932e..e01a39c9e 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -181,7 +181,7 @@ where self, index: &DiskANNIndex, strategy: &S, - processor: &PP, + processor: PP, context: &DP::Context, query: &T, _output: &mut (), From f764aa424a5d5634bdc3334f3a72b6b875a9fdbe Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 17:19:50 +0530 Subject: [PATCH 07/47] Implement major Search trait refactor and remove Internal - Simplify Search trait: move processor/output buffer to method-level generics - Remove Internal strategy split; use RemoveDeletedIdsAndCopy for delete ops - Add DefaultSearchStrategy aggregate trait combining SearchStrategy + HasDefaultProcessor - Update benchmark-core helpers to use aggregate trait (reduce recurring bounds) - Wire range search output buffer through to caller (support dynamic output handling) - Add no-op SearchOutputBuffer impl for () type to preserve compatibility --- .../src/search/graph/determinant_diversity.rs | 3 +- .../src/search/graph/knn.rs | 5 +- .../src/search/graph/multihop.rs | 5 +- .../src/search/graph/range.rs | 5 +- .../provider/async_/inmem/full_precision.rs | 76 ++++++++----------- diskann/src/graph/glue.rs | 19 +++++ diskann/src/graph/index.rs | 13 +++- diskann/src/graph/search/diverse_search.rs | 14 ++-- diskann/src/graph/search/knn_search.rs | 28 ++++--- diskann/src/graph/search/mod.rs | 12 ++- diskann/src/graph/search/multihop_search.rs | 14 ++-- diskann/src/graph/search/range_search.rs | 19 +++-- diskann/src/graph/search_output_buffer.rs | 21 +++++ 13 files changed, 137 insertions(+), 97 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs index d69d18937..681f5c42b 100644 --- a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs +++ b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs @@ -58,8 +58,7 @@ where impl Search for KNN where DP: provider::DataProvider, - S: glue::SearchStrategy - + glue::HasDefaultProcessor + S: glue::DefaultSearchStrategy + glue::PostProcess + Clone + AsyncFriendly, diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index ebc0e3b4b..6cc2c9673 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -88,10 +88,7 @@ pub struct Metrics { impl Search for KNN where DP: provider::DataProvider, - S: glue::SearchStrategy - + glue::HasDefaultProcessor - + Clone - + AsyncFriendly, + S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, T: AsyncFriendly + Clone, { type Id = DP::ExternalId; diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 584c70baf..41af5c37f 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -86,10 +86,7 @@ where impl Search for MultiHop where DP: provider::DataProvider, - S: glue::SearchStrategy - + glue::HasDefaultProcessor - + Clone - + AsyncFriendly, + S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, T: AsyncFriendly + Clone, { type Id = DP::ExternalId; diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index 7d9348244..3f9eb3ed2 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -79,10 +79,7 @@ pub struct Metrics {} impl Search for Range where DP: provider::DataProvider, - S: glue::SearchStrategy - + glue::HasDefaultProcessor - + Clone - + AsyncFriendly, + S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, T: AsyncFriendly + Clone, { type Id = DP::ExternalId; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 36298b5d3..01cdc03e3 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -32,8 +32,8 @@ use crate::model::graph::{ provider::async_::{ FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, common::{ - CreateVectorStore, FullPrecision, Internal, NoDeletes, NoStore, Panics, - PrefetchCacheLineLevel, SetElementHelper, + CreateVectorStore, FullPrecision, NoDeletes, NoStore, Panics, PrefetchCacheLineLevel, + SetElementHelper, }, inmem::DefaultProvider, postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, @@ -434,20 +434,10 @@ where // Strategies // //////////////// -// A layered approach is used for search strategies. The `Internal` version does the heavy -// lifting in terms of establishing accessors and post processing. -// -// However, during post-processing, the `Internal` versions of strategies will not filter -// out the start points. The publicly exposed types *will* filter out the start points. -// -// This layered approach allows algorithms like `InplaceDeleteStrategy` that need to adjust -// the adjacency list for the start point to reuse the `Internal` strategies. - /// Perform a search entirely in the full-precision space. /// /// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> - for Internal +impl SearchStrategy, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, @@ -467,50 +457,48 @@ where } } -impl HasDefaultProcessor, [T]> - for Internal +impl HasDefaultProcessor, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(RemoveDeletedIdsAndCopy); + delegate_default_post_process!(glue::Pipeline); } -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for FullPrecision +impl PostProcess, [T], RemoveDeletedIdsAndCopy> + for FullPrecision where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; - type SearchAccessorError = Panics; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: RemoveDeletedIdsAndCopy, + accessor: &mut Self::SearchAccessor<'a>, + query: &[T], + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + glue::SearchPostProcess::post_process( + &processor, accessor, query, computer, candidates, output, + ) + .await + .into_ann_result() + } } } -impl HasDefaultProcessor, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - delegate_default_post_process!(glue::Pipeline); -} - impl PostProcess, [T], DeterminantDiversitySearchParams> for FullPrecision @@ -634,10 +622,10 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; - type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; - type SearchStrategy = Internal; + type SearchPostProcessor = RemoveDeletedIdsAndCopy; + type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) + *self } fn prune_strategy(&self) -> Self::PruneStrategy { @@ -645,7 +633,7 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - Default::default() + RemoveDeletedIdsAndCopy } async fn get_delete_element<'a>( diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index e6dd0b6cd..c88bad27f 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -384,6 +384,25 @@ where fn create_processor(&self) -> Self::Processor; } +/// Aggregate trait for strategies that support both search access and a default post-processor. +pub trait DefaultSearchStrategy::InternalId>: + SearchStrategy + HasDefaultProcessor +where + Provider: DataProvider, + T: ?Sized, + O: Send, +{ +} + +impl DefaultSearchStrategy for S +where + S: SearchStrategy + HasDefaultProcessor, + Provider: DataProvider, + T: ?Sized, + O: Send, +{ +} + /// Convenience macro for implementing [`HasDefaultProcessor`] when the processor /// is a [`Default`]-constructible type. /// diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index e6c49e8a4..e4541d07e 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2152,9 +2152,12 @@ where output: &mut OB, ) -> impl SendFuture> where - P: super::search::Search, + P: super::search::Search, + S: glue::PostProcess, + glue::DefaultPostProcess: Send + Sync, + O: Send, + OB: super::search_output_buffer::SearchOutputBuffer + Send + ?Sized, T: ?Sized, - OB: ?Sized, { self.search_with( search_params, @@ -2177,10 +2180,12 @@ where output: &mut OB, ) -> impl SendFuture> where - P: super::search::Search, + P: super::search::Search, + S: glue::PostProcess, PP: Send + Sync, + O: Send, + OB: super::search_output_buffer::SearchOutputBuffer + Send + ?Sized, T: ?Sized, - OB: ?Sized, { search_params.search(self, strategy, processor, context, query, output) } diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index 206ca7465..71b217a7e 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -92,19 +92,16 @@ where } } -impl Search for Diverse

+impl Search for Diverse

where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, O: Send, - OB: SearchOutputBuffer + Send, P: AttributeValueProvider, - PP: Send + Sync, { type Output = SearchStats; - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -112,7 +109,12 @@ where context: &DP::Context, query: &T, output: &mut OB, - ) -> impl SendFuture> { + ) -> impl SendFuture> + where + S: PostProcess, + PP: Send + Sync, + OB: SearchOutputBuffer + Send, + { async move { let mut accessor = strategy .search_accessor(&index.data_provider, context) diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index c790770bb..f70dc6aed 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -197,19 +197,16 @@ impl Knn { } } -impl Search for Knn +impl Search for Knn where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, O: Send, - OB: SearchOutputBuffer + Send + ?Sized, - PP: Send + Sync, { type Output = SearchStats; /// Execute the k-NN search on the given index using the default post-processor. - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -217,7 +214,12 @@ where context: &DP::Context, query: &T, output: &mut OB, - ) -> impl SendFuture> { + ) -> impl SendFuture> + where + S: PostProcess, + PP: Send + Sync, + OB: SearchOutputBuffer + Send + ?Sized, + { async move { self.search_core(index, strategy, context, query, output, processor) .await @@ -247,19 +249,16 @@ impl<'r, SR: ?Sized> RecordedKnn<'r, SR> { } } -impl<'r, DP, S, T, O, OB, SR, PP> Search for RecordedKnn<'r, SR> +impl<'r, DP, S, T, O, SR> Search for RecordedKnn<'r, SR> where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, O: Send, - OB: SearchOutputBuffer + Send + ?Sized, SR: super::record::SearchRecord + ?Sized, - PP: Send + Sync, { type Output = SearchStats; - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -267,7 +266,12 @@ where context: &DP::Context, query: &T, output: &mut OB, - ) -> impl SendFuture> { + ) -> impl SendFuture> + where + S: PostProcess, + PP: Send + Sync, + OB: SearchOutputBuffer + Send + ?Sized, + { async move { let mut accessor = strategy .search_accessor(&index.data_provider, context) diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 8873ca51c..90797711e 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -49,10 +49,10 @@ pub(crate) mod scratch; /// - [`Diverse`] - Diversity-aware search (feature-gated) /// - [`MultihopSearch`] - Label-filtered search with multi-hop expansion /// - [`RecordedKnn`] - K-NN search with path recording for debugging -pub trait Search +pub trait Search where DP: DataProvider, - PP: Send + Sync, + O: Send, { /// The result type returned by this search. type Output; @@ -79,7 +79,7 @@ where /// # Errors /// /// Returns an error if there is a failure accessing elements or computing distances. - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -87,7 +87,11 @@ where context: &DP::Context, query: &T, output: &mut OB, - ) -> impl SendFuture>; + ) -> impl SendFuture> + where + S: crate::graph::glue::PostProcess, + PP: Send + Sync, + OB: crate::graph::search_output_buffer::SearchOutputBuffer + Send + ?Sized; } // Re-export search parameter types. diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 7259ef1e1..5771cd4de 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -52,18 +52,15 @@ impl<'q, InternalId> MultihopSearch<'q, InternalId> { } } -impl<'q, DP, S, T, O, OB, PP> Search for MultihopSearch<'q, DP::InternalId> +impl<'q, DP, S, T, O> Search for MultihopSearch<'q, DP::InternalId> where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, O: Send, - OB: SearchOutputBuffer + Send, - PP: Send + Sync, { type Output = SearchStats; - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -71,7 +68,12 @@ where context: &DP::Context, query: &T, output: &mut OB, - ) -> impl SendFuture> { + ) -> impl SendFuture> + where + S: PostProcess, + PP: Send + Sync, + OB: SearchOutputBuffer + Send + ?Sized, + { async move { let mut accessor = strategy .search_accessor(&index.data_provider, context) diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index e01a39c9e..495f80de1 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -16,7 +16,7 @@ use crate::{ glue::{self, ExpandBeam, PostProcess, SearchExt}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, - search_output_buffer, + search_output_buffer::{self, SearchOutputBuffer}, }, neighbor::Neighbor, provider::{BuildQueryComputer, DataProvider}, @@ -167,25 +167,28 @@ impl Range { } } -impl Search for Range +impl Search for Range where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, O: Send + Default + Clone, - PP: Send + Sync, { type Output = RangeSearchOutput; - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, processor: PP, context: &DP::Context, query: &T, - _output: &mut (), - ) -> impl SendFuture> { + output: &mut OB, + ) -> impl SendFuture> + where + S: PostProcess, + PP: Send + Sync, + OB: SearchOutputBuffer + Send + ?Sized, + { async move { let mut accessor = strategy .search_accessor(&index.data_provider, context) @@ -286,6 +289,8 @@ where let result_count = result_ids.len(); + let _ = output.extend(result_ids.iter().cloned().zip(result_dists.iter().copied())); + Ok(RangeSearchOutput { stats: SearchStats { cmps: stats.cmps, diff --git a/diskann/src/graph/search_output_buffer.rs b/diskann/src/graph/search_output_buffer.rs index fddd7c73b..8d3469ea8 100644 --- a/diskann/src/graph/search_output_buffer.rs +++ b/diskann/src/graph/search_output_buffer.rs @@ -36,6 +36,27 @@ pub trait SearchOutputBuffer { Itr: IntoIterator; } +impl SearchOutputBuffer for () { + fn size_hint(&self) -> Option { + None + } + + fn push(&mut self, _id: I, _distance: D) -> BufferState { + BufferState::Available + } + + fn current_len(&self) -> usize { + 0 + } + + fn extend(&mut self, _itr: Itr) -> usize + where + Itr: IntoIterator, + { + 0 + } +} + /// Indicate whether future calls to [`SearchOutputBuffer::push`] will succeed or not. #[derive(Debug, Clone, Copy, PartialEq)] #[must_use = "This type indicates whether the output buffer is full or not."] From ab56c939e6a716bc6f2ac7bfa6665d0878663fbe Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 17:28:00 +0530 Subject: [PATCH 08/47] Relocate determinant_diversity_post_process from diskann to diskann-providers This commit moves the determinant_diversity_post_process module from diskann to diskann-providers, as it does not depend on diskann internals and logically belongs with other post-processing logic in the providers layer. Changes: - Move determinant_diversity_post_process.rs from diskann/src/graph/search/ to diskann-providers/src/model/graph/provider/async_/ - Update all imports across workspace to use diskann_providers location - Add diskann-providers dependency to diskann-benchmark-core (required for DeterminantDiversitySearchParams access) - Remove old module reference from diskann/src/graph/search/mod.rs - Update diskann-benchmark, diskann-disk imports to use new location Validated with: - cargo clippy --workspace --all-targets -- -D warnings - cargo fmt --all This results in cleaner architectural separation where determinant-diversity search parameters stay with the provider infrastructure that implements them. --- Cargo.lock | 1 + diskann-benchmark-core/Cargo.toml | 1 + .../src/search/graph/determinant_diversity.rs | 5 +++-- diskann-benchmark/src/backend/disk_index/search.rs | 2 +- diskann-benchmark/src/backend/index/benchmarks.rs | 3 ++- diskann-benchmark/src/backend/index/search/knn.rs | 2 +- diskann-disk/src/search/provider/disk_provider.rs | 3 ++- .../provider/async_}/determinant_diversity_post_process.rs | 0 .../src/model/graph/provider/async_/inmem/full_precision.rs | 5 ++++- diskann-providers/src/model/graph/provider/async_/mod.rs | 6 ++++++ diskann/src/graph/search/mod.rs | 5 ----- 11 files changed, 21 insertions(+), 12 deletions(-) rename {diskann/src/graph/search => diskann-providers/src/model/graph/provider/async_}/determinant_diversity_post_process.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index f0405995a..948a3717c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -469,6 +469,7 @@ dependencies = [ "anyhow", "diskann", "diskann-benchmark-runner", + "diskann-providers", "diskann-utils", "diskann-vector", "futures-util", diff --git a/diskann-benchmark-core/Cargo.toml b/diskann-benchmark-core/Cargo.toml index 90e64b9e3..b3eabed45 100644 --- a/diskann-benchmark-core/Cargo.toml +++ b/diskann-benchmark-core/Cargo.toml @@ -11,6 +11,7 @@ edition = "2024" anyhow.workspace = true diskann.workspace = true diskann-benchmark-runner = { workspace = true } +diskann-providers = { workspace = true } diskann-utils.default-features = false diskann-utils.workspace = true futures-util = { workspace = true, default-features = false } diff --git a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs index 681f5c42b..2f413a117 100644 --- a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs +++ b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs @@ -11,6 +11,7 @@ use diskann::{ provider, }; use diskann_benchmark_runner::utils::{MicroSeconds, percentiles}; +use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use diskann_utils::{future::AsyncFriendly, views::Matrix}; use crate::{ @@ -22,7 +23,7 @@ use crate::{ #[derive(Debug, Clone, Copy)] pub struct Parameters { pub inner: graph::search::Knn, - pub processor: graph::search::DeterminantDiversitySearchParams, + pub processor: DeterminantDiversitySearchParams, } /// A built-in helper for benchmarking determinant-diversity K-nearest neighbors. @@ -59,7 +60,7 @@ impl Search for KNN where DP: provider::DataProvider, S: glue::DefaultSearchStrategy - + glue::PostProcess + + glue::PostProcess + Clone + AsyncFriendly, T: AsyncFriendly + Clone, diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 1de1d0eb0..7fdac7eea 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -9,7 +9,6 @@ use std::{collections::HashSet, fmt, sync::atomic::AtomicBool, time::Instant}; use opentelemetry::{global, trace::Span, trace::Tracer}; use opentelemetry_sdk::trace::SdkTracerProvider; -use diskann::graph::search::DeterminantDiversitySearchParams; use diskann::utils::VectorRepr; use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ @@ -20,6 +19,7 @@ use diskann_disk::{ storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, }; +use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ storage::{ diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index e0f8aed90..cee9dafec 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -12,6 +12,7 @@ use diskann::{ provider::{self, DataProvider, DefaultContext}, utils::VectorRepr, }; +use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use diskann_benchmark_core::{ self as benchmark_core, streaming::{executors::bigann, Executor}, @@ -506,7 +507,7 @@ where T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, S: glue::SearchStrategy + glue::HasDefaultProcessor - + glue::PostProcess + + glue::PostProcess + Clone + AsyncFriendly, { diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 6ad460fc5..3f7f294e1 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -87,7 +87,7 @@ pub(crate) fn run_determinant_diversity( .map(|search_l| { let base = diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); - let processor = diskann::graph::search::DeterminantDiversitySearchParams::new( + let processor = diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams::new( results_k.unwrap_or(run.search_n), eta, power, diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 0caca3188..da819ba42 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -23,7 +23,7 @@ use diskann::{ self, ExpandBeam, HasDefaultProcessor, IdIterator, PostProcess, SearchExt, SearchPostProcess, SearchStrategy, }, - search::{determinant_diversity_post_process, DeterminantDiversitySearchParams, Knn}, + search::Knn, search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, }, neighbor::Neighbor, @@ -44,6 +44,7 @@ use diskann_providers::{ pq::quantizer_preprocess, PQData, PQScratch, }, storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith}, + model::graph::provider::async_::{determinant_diversity_post_process, DeterminantDiversitySearchParams}, }; use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction}; use futures_util::future; diff --git a/diskann/src/graph/search/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs similarity index 100% rename from diskann/src/graph/search/determinant_diversity_post_process.rs rename to diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 01cdc03e3..06a146004 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -16,7 +16,6 @@ use diskann::{ InplaceDeleteStrategy, InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, }, - search::{DeterminantDiversitySearchParams, determinant_diversity_post_process}, }, neighbor::Neighbor, provider::{ @@ -25,6 +24,10 @@ use diskann::{ }, utils::{IntoUsize, VectorRepr}, }; + +use super::super::determinant_diversity_post_process::{ + DeterminantDiversitySearchParams, determinant_diversity_post_process, +}; use diskann_utils::future::AsyncFriendly; use diskann_vector::{DistanceFunction, distance::Metric}; diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 3d89359e2..bdf620d38 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -43,3 +43,9 @@ pub mod caching; // Debug provider for testing. #[cfg(test)] pub mod debug_provider; + +// Determinant-diversity post-processing. +pub mod determinant_diversity_post_process; +pub use determinant_diversity_post_process::{ + DeterminantDiversitySearchParams, determinant_diversity_post_process, +}; diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 90797711e..ab32581dc 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -28,7 +28,6 @@ use diskann_utils::future::SendFuture; use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; -mod determinant_diversity_post_process; mod knn_search; mod multihop_search; mod range_search; @@ -94,10 +93,6 @@ where OB: crate::graph::search_output_buffer::SearchOutputBuffer + Send + ?Sized; } -// Re-export search parameter types. -pub use determinant_diversity_post_process::{ - DeterminantDiversitySearchParams, determinant_diversity_post_process, -}; pub use knn_search::{Knn, KnnSearchError, RecordedKnn}; pub use multihop_search::MultihopSearch; pub use range_search::{Range, RangeSearchError, RangeSearchOutput}; From 5cad020608a613c6aacf5ef7d71c729d9932d4fa Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 18:00:54 +0530 Subject: [PATCH 09/47] Algorithm cleanup: Add determinant-diversity input validation and reduce clones - Add DeterminantDiversityError enum for parameter validation - Convert DeterminantDiversitySearchParams::new() to return Result - Validate top_k > 0, eta >= 0.0 and finite, power > 0.0 and finite - Optimize post_process_with_eta_f32: precompute projections to eliminate vector clones - Optimize post_process_greedy_orthogonalization_f32: single r_star_copy before projection loop - Expand test suite from 3 to 11 tests (7 validation + 4 algorithm tests) - Update callsites in disk_index/search.rs and index/search/knn.rs for error handling - Add early validation checks in main router function --- .../src/backend/disk_index/search.rs | 26 ++- .../src/backend/index/benchmarks.rs | 2 +- .../src/backend/index/search/knn.rs | 20 +- .../src/search/provider/disk_provider.rs | 4 +- .../determinant_diversity_post_process.rs | 192 ++++++++++++++++-- 5 files changed, 209 insertions(+), 35 deletions(-) diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 7fdac7eea..70b138cc0 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -274,22 +274,26 @@ where search_params.determinant_diversity_eta, search_params.determinant_diversity_power, ) { - let processor = DeterminantDiversitySearchParams::new( + match DeterminantDiversitySearchParams::new( search_params .determinant_diversity_results_k .unwrap_or(search_params.recall_at as usize), eta, power, - ); - searcher.search_determinant_diversity( - q, - search_params.recall_at, - l, - Some(search_params.beam_width), - vector_filter, - search_params.is_flat_search, - processor, - ) + ) { + Ok(processor) => searcher.search_determinant_diversity( + q, + search_params.recall_at, + l, + Some(search_params.beam_width), + vector_filter, + search_params.is_flat_search, + processor, + ), + Err(e) => { + Err(format!("Invalid determinant-diversity parameters: {}", e).into()) + } + } } else { searcher.search( q, diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index cee9dafec..e3e3a54c9 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -12,7 +12,6 @@ use diskann::{ provider::{self, DataProvider, DefaultContext}, utils::VectorRepr, }; -use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use diskann_benchmark_core::{ self as benchmark_core, streaming::{executors::bigann, Executor}, @@ -23,6 +22,7 @@ use diskann_benchmark_runner::{ utils::datatype, Any, Checkpoint, }; +use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use diskann_providers::{ index::diskann_async, model::{configuration::IndexConfiguration, graph::provider::async_::common}, diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 3f7f294e1..1b37571f8 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -91,16 +91,18 @@ pub(crate) fn run_determinant_diversity( results_k.unwrap_or(run.search_n), eta, power, - ); - let search_params = - diskann_benchmark_core::search::graph::determinant_diversity::Parameters { - inner: base, - processor, - }; - - core_search::Run::new(search_params, setup.clone()) + ).map_err(|e| anyhow::anyhow!("Invalid determinant-diversity parameters: {}", e)); + + processor.map(|proc| { + let search_params = + diskann_benchmark_core::search::graph::determinant_diversity::Parameters { + inner: base, + processor: proc, + }; + core_search::Run::new(search_params, setup.clone()) + }) }) - .collect(); + .collect::>>()?; all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index da819ba42..7b8092d67 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -39,12 +39,14 @@ use diskann::{ }; use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ + model::graph::provider::async_::{ + determinant_diversity_post_process, DeterminantDiversitySearchParams, + }, model::{ compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType, pq::quantizer_preprocess, PQData, PQScratch, }, storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith}, - model::graph::provider::async_::{determinant_diversity_post_process, DeterminantDiversitySearchParams}, }; use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction}; use futures_util::future; diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index 33e4b987d..56b663d13 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -11,7 +11,39 @@ use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; +/// Error type for determinant-diversity parameter validation. +#[derive(Debug)] +pub enum DeterminantDiversityError { + InvalidTopK { top_k: usize }, + InvalidEta { eta: f64 }, + InvalidPower { power: f64 }, +} + +impl std::fmt::Display for DeterminantDiversityError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidTopK { top_k } => { + write!(f, "top_k must be > 0, got {}", top_k) + } + Self::InvalidEta { eta } => { + write!(f, "eta must be >= 0.0, got {}", eta) + } + Self::InvalidPower { power } => { + write!(f, "power must be > 0.0, got {}", power) + } + } + } +} + +impl std::error::Error for DeterminantDiversityError {} + /// Parameters for determinant-diversity reranking. +/// +/// # Invariants +/// +/// - `top_k > 0`: Must request at least one result +/// - `determinant_diversity_eta >= 0.0`: Ridge regularization parameter (0 = no ridge) +/// - `determinant_diversity_power > 0.0`: Exponent for diversity scaling (typically 1.0-2.0) #[derive(Debug, Clone, Copy)] pub struct DeterminantDiversitySearchParams { pub top_k: usize, @@ -20,16 +52,43 @@ pub struct DeterminantDiversitySearchParams { } impl DeterminantDiversitySearchParams { + /// Construct parameters with validation. + /// + /// # Arguments + /// + /// * `top_k` - Number of results to return (must be > 0) + /// * `determinant_diversity_eta` - Ridge regularization parameter (must be >= 0.0) + /// * `determinant_diversity_power` - Diversity exponent (must be > 0.0) + /// + /// # Errors + /// + /// Returns [`DeterminantDiversityError`] if any parameter is invalid. pub fn new( top_k: usize, determinant_diversity_eta: f64, determinant_diversity_power: f64, - ) -> Self { - Self { + ) -> Result { + if top_k == 0 { + return Err(DeterminantDiversityError::InvalidTopK { top_k }); + } + + if determinant_diversity_eta < 0.0 || !determinant_diversity_eta.is_finite() { + return Err(DeterminantDiversityError::InvalidEta { + eta: determinant_diversity_eta, + }); + } + + if determinant_diversity_power <= 0.0 || !determinant_diversity_power.is_finite() { + return Err(DeterminantDiversityError::InvalidPower { + power: determinant_diversity_power, + }); + } + + Ok(Self { top_k, determinant_diversity_eta, determinant_diversity_power, - } + }) } } @@ -44,17 +103,25 @@ pub fn determinant_diversity_post_process( determinant_diversity_eta: f64, determinant_diversity_power: f64, ) -> Vec<(Id, f32)> { - if candidates.is_empty() { + if candidates.is_empty() || query.is_empty() { return Vec::new(); } let k = k.min(candidates.len()); + if k == 0 { + return Vec::new(); + } + // Convert vectors to owned format only once let candidates_f32: Vec<(Id, f32, Vec)> = candidates .into_iter() .map(|(id, dist, v)| (id, dist, v.to_vec())) .collect(); + if candidates_f32[0].2.is_empty() { + return Vec::new(); + } + let results = if determinant_diversity_eta > 0.0 { post_process_with_eta_f32( candidates_f32, @@ -114,6 +181,7 @@ fn post_process_with_eta_f32( let mut residuals: Vec> = Vec::with_capacity(n); let mut norms_sq: Vec = Vec::with_capacity(n); + // Initialize residuals and norms (only one allocation per candidate) for (_, _, v) in &candidates { let similarity = dot_product(v, query); let scale = similarity.max(0.0).powf(power as f32) * inv_sqrt_eta; @@ -150,16 +218,27 @@ fn post_process_with_eta_f32( } let norm_factor = 1.0 / (1.0 + norms_sq[j]).sqrt(); - let q: Vec = residuals[j].iter().map(|&x| x * norm_factor).collect(); + // Compute all projections first to avoid needing to clone residuals[j] + let mut projections: Vec = Vec::with_capacity(n); for i in 0..n { if !available[i] { - continue; + projections.push(0.0); + } else { + let alpha = dot_product(&residuals[j], &residuals[i]) * norm_factor * norm_factor; + projections.push(alpha); } + } - let alpha = dot_product(&q, &residuals[i]); + // Now apply all updates using the precomputed projections + let q_scaled: Vec = residuals[j].iter().map(|&x| x * norm_factor).collect(); + for i in 0..n { + if !available[i] { + continue; + } - for (r_val, &q_val) in residuals[i].iter_mut().zip(q.iter()) { + let alpha = projections[i]; + for (r_val, &q_val) in residuals[i].iter_mut().zip(q_scaled.iter()) { *r_val -= alpha * q_val; } @@ -198,6 +277,7 @@ fn post_process_greedy_orthogonalization_f32( let mut residuals: Vec> = Vec::with_capacity(n); let mut norms_sq: Vec = Vec::with_capacity(n); + // Initialize residuals and norms (only one allocation per candidate) for (_, _, v) in &candidates { let similarity = dot_product(v, query); let scale = similarity.max(0.0).powf(power as f32); @@ -226,7 +306,6 @@ fn post_process_greedy_orthogonalization_f32( }; let best_norm_sq = norms_sq[i_star]; - selected.push(i_star); available[i_star] = false; @@ -238,17 +317,28 @@ fn post_process_greedy_orthogonalization_f32( continue; } - let r_star = residuals[i_star].clone(); let inv_norm_sq_star = 1.0 / best_norm_sq; + // Compute all projections and make a copy of r_star to avoid borrow conflicts + let r_star_copy = residuals[i_star].clone(); + let mut projections: Vec = Vec::with_capacity(n); for j in 0..n { if !available[j] { - continue; + projections.push(0.0); + } else { + let proj = dot_product(&residuals[j], &r_star_copy) * inv_norm_sq_star; + projections.push(proj); } + } - let proj_coeff = dot_product(&residuals[j], &r_star) * inv_norm_sq_star; + // Now apply all updates using the precomputed projections + for j in 0..n { + if !available[j] { + continue; + } - for (r_val, &rs_val) in residuals[j].iter_mut().zip(r_star.iter()) { + let proj_coeff = projections[j]; + for (r_val, &rs_val) in residuals[j].iter_mut().zip(r_star_copy.iter()) { *r_val -= proj_coeff * rs_val; } @@ -275,6 +365,70 @@ fn dot_product(a: &[f32], b: &[f32]) -> f32 { mod tests { use super::*; + // ===== Validation Tests ===== + + #[test] + fn test_validation_valid_params() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, 2.0); + assert!(result.is_ok()); + } + + #[test] + fn test_validation_zero_top_k() { + let result = DeterminantDiversitySearchParams::new(0, 0.01, 2.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidTopK { top_k: 0 }) + )); + } + + #[test] + fn test_validation_negative_eta() { + let result = DeterminantDiversitySearchParams::new(10, -0.01, 2.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidEta { .. }) + )); + } + + #[test] + fn test_validation_zero_power() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, 0.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidPower { .. }) + )); + } + + #[test] + fn test_validation_negative_power() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, -1.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidPower { .. }) + )); + } + + #[test] + fn test_validation_nan_eta() { + let result = DeterminantDiversitySearchParams::new(10, f64::NAN, 2.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidEta { .. }) + )); + } + + #[test] + fn test_validation_infinity_power() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, f64::INFINITY); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidPower { .. }) + )); + } + + // ===== Algorithm Tests ===== + #[test] fn test_determinant_diversity_post_process_with_eta() { let v1 = vec![1.0f32, 0.0, 0.0]; @@ -315,4 +469,16 @@ mod tests { let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); assert!(result.is_empty()); } + + #[test] + fn test_determinant_diversity_post_process_k_larger_than_candidates() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.0f32, 1.0, 0.0]; + let candidates = vec![(1u32, 0.5f32, v1.as_slice()), (2u32, 0.3f32, v2.as_slice())]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 10, 0.01, 2.0); + // Should return min(k, len(candidates)) = 2 + assert_eq!(result.len(), 2); + } } From 3d79d6d85fb63ca4ef928f58053c187b3531d022 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 18:06:47 +0530 Subject: [PATCH 10/47] Benchmark: collapse determinant-diversity helper into knn - Extract shared run-loop logic into reusable helpers - Route both knn and determinant-diversity through closure-based parameter builders - Preserve determinant-diversity parameter validation/error propagation - Reduce duplicated benchmark orchestration code --- .../src/backend/index/search/knn.rs | 120 +++++++++++------- 1 file changed, 77 insertions(+), 43 deletions(-) diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 1b37571f8..e8ea35f25 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -35,6 +35,72 @@ pub(crate) fn run( groundtruth: &dyn benchmark_core::recall::Rows, steps: SearchSteps<'_>, ) -> anyhow::Result> { + run_search(runner, groundtruth, steps, |setup, search_l, search_n| { + let search_params = + diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); + core_search::Run::new(search_params, setup) + }) +} + +pub(crate) fn run_determinant_diversity( + runner: &dyn DeterminantDiversityKnn, + groundtruth: &dyn benchmark_core::recall::Rows, + steps: SearchSteps<'_>, + eta: f64, + power: f64, + results_k: Option, +) -> anyhow::Result> { + run_search_determinant_diversity( + runner, + groundtruth, + steps, + |setup, search_l, search_n| { + let base = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); + let processor = + diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams::new( + results_k.unwrap_or(search_n), + eta, + power, + ).map_err(|e| anyhow::anyhow!("Invalid determinant-diversity parameters: {}", e))?; + + let search_params = + diskann_benchmark_core::search::graph::determinant_diversity::Parameters { + inner: base, + processor, + }; + Ok(core_search::Run::new(search_params, setup)) + }, + ) +} + +type Run = core_search::Run; +pub(crate) trait Knn { + fn search_all( + &self, + parameters: Vec, + groundtruth: &dyn benchmark_core::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> anyhow::Result>; +} + +type DeterminantRun = + core_search::Run; + +/// Generic search infrastructure that unifies `run()` and `run_determinant_diversity()`. +/// +/// This helper extracts the common loop logic (iterating over threads and runs, +/// and building a setup) leaving parameter construction to a builder closure. +/// This collapses the benchmark helper infrastructure and reduces duplication. +fn run_search( + runner: &dyn Knn, + groundtruth: &dyn benchmark_core::recall::Rows, + steps: SearchSteps<'_>, + builder: F, +) -> anyhow::Result> +where + F: Fn(core_search::Setup, usize, usize) -> core_search::Run, +{ let mut all = Vec::new(); for threads in steps.num_tasks.iter() { @@ -48,12 +114,7 @@ pub(crate) fn run( let parameters: Vec<_> = run .search_l .iter() - .map(|search_l| { - let search_params = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); - - core_search::Run::new(search_params, setup.clone()) - }) + .map(|&search_l| builder(setup.clone(), search_l, run.search_n)) .collect(); all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); @@ -63,14 +124,18 @@ pub(crate) fn run( Ok(all) } -pub(crate) fn run_determinant_diversity( +/// Generic search infrastructure for determinant-diversity searches. +/// +/// Mirrors the unified logic of `run_search()` but for the DeterminantDiversityKnn trait. +fn run_search_determinant_diversity( runner: &dyn DeterminantDiversityKnn, groundtruth: &dyn benchmark_core::recall::Rows, steps: SearchSteps<'_>, - eta: f64, - power: f64, - results_k: Option, -) -> anyhow::Result> { + builder: F, +) -> anyhow::Result> +where + F: Fn(core_search::Setup, usize, usize) -> anyhow::Result>, +{ let mut all = Vec::new(); for threads in steps.num_tasks.iter() { @@ -84,24 +149,7 @@ pub(crate) fn run_determinant_diversity( let parameters: Vec<_> = run .search_l .iter() - .map(|search_l| { - let base = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); - let processor = diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams::new( - results_k.unwrap_or(run.search_n), - eta, - power, - ).map_err(|e| anyhow::anyhow!("Invalid determinant-diversity parameters: {}", e)); - - processor.map(|proc| { - let search_params = - diskann_benchmark_core::search::graph::determinant_diversity::Parameters { - inner: base, - processor: proc, - }; - core_search::Run::new(search_params, setup.clone()) - }) - }) + .map(|&search_l| builder(setup.clone(), search_l, run.search_n)) .collect::>>()?; all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); @@ -111,20 +159,6 @@ pub(crate) fn run_determinant_diversity( Ok(all) } -type Run = core_search::Run; -pub(crate) trait Knn { - fn search_all( - &self, - parameters: Vec, - groundtruth: &dyn benchmark_core::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> anyhow::Result>; -} - -type DeterminantRun = - core_search::Run; - pub(crate) trait DeterminantDiversityKnn { fn search_all( &self, From b8c952e90d71bd0356e1c0d2c47914f951c6b2e9 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 18:15:44 +0530 Subject: [PATCH 11/47] Rename HasDefaultProcessor to DelegateDefaultPostProcessor - Promote DelegateDefaultPostProcessor as the canonical trait in glue - Remove compatibility alias layer for HasDefaultProcessor - Rename all trait bounds/impls/usages across diskann, providers, disk, benchmark, and label-filter - Keep delegate_default_post_process! macro usage aligned with trait naming --- .../src/backend/index/benchmarks.rs | 7 +++- .../src/backend/index/search/knn.rs | 40 +++++++++---------- .../src/search/provider/disk_provider.rs | 4 +- .../inline_beta_search/inline_beta_filter.rs | 6 +-- diskann-providers/src/index/diskann_async.rs | 14 +++---- diskann-providers/src/index/wrapped_async.rs | 2 +- .../graph/provider/async_/bf_tree/provider.rs | 13 +++--- .../graph/provider/async_/caching/provider.rs | 8 ++-- .../graph/provider/async_/debug_provider.rs | 10 ++--- .../provider/async_/inmem/full_precision.rs | 5 ++- .../graph/provider/async_/inmem/product.rs | 11 ++--- .../graph/provider/async_/inmem/scalar.rs | 8 ++-- .../graph/provider/async_/inmem/spherical.rs | 6 +-- .../model/graph/provider/async_/inmem/test.rs | 4 +- .../model/graph/provider/layers/betafilter.rs | 8 ++-- diskann/src/graph/glue.rs | 18 ++++----- diskann/src/graph/index.rs | 2 +- diskann/src/graph/test/provider.rs | 2 +- 18 files changed, 88 insertions(+), 80 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index e3e3a54c9..07a0b942a 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -350,7 +350,10 @@ where DP: DataProvider + provider::SetElement<[T]>, T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, - S: glue::SearchStrategy + glue::HasDefaultProcessor + Clone + AsyncFriendly, + S: glue::SearchStrategy + + glue::DelegateDefaultPostProcessor + + Clone + + AsyncFriendly, { match &input { SearchPhase::Topk(search_phase) => { @@ -506,7 +509,7 @@ where + provider::SetElement<[T]>, T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, S: glue::SearchStrategy - + glue::HasDefaultProcessor + + glue::DelegateDefaultPostProcessor + glue::PostProcess + Clone + AsyncFriendly, diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index e8ea35f25..a01318481 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -36,8 +36,7 @@ pub(crate) fn run( steps: SearchSteps<'_>, ) -> anyhow::Result> { run_search(runner, groundtruth, steps, |setup, search_l, search_n| { - let search_params = - diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); + let search_params = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); core_search::Run::new(search_params, setup) }) } @@ -50,27 +49,22 @@ pub(crate) fn run_determinant_diversity( power: f64, results_k: Option, ) -> anyhow::Result> { - run_search_determinant_diversity( - runner, - groundtruth, - steps, - |setup, search_l, search_n| { - let base = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); - let processor = + run_search_determinant_diversity(runner, groundtruth, steps, |setup, search_l, search_n| { + let base = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); + let processor = diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams::new( results_k.unwrap_or(search_n), eta, power, ).map_err(|e| anyhow::anyhow!("Invalid determinant-diversity parameters: {}", e))?; - let search_params = - diskann_benchmark_core::search::graph::determinant_diversity::Parameters { - inner: base, - processor, - }; - Ok(core_search::Run::new(search_params, setup)) - }, - ) + let search_params = + diskann_benchmark_core::search::graph::determinant_diversity::Parameters { + inner: base, + processor, + }; + Ok(core_search::Run::new(search_params, setup)) + }) } type Run = core_search::Run; @@ -88,7 +82,7 @@ type DeterminantRun = core_search::Run; /// Generic search infrastructure that unifies `run()` and `run_determinant_diversity()`. -/// +/// /// This helper extracts the common loop logic (iterating over threads and runs, /// and building a setup) leaving parameter construction to a builder closure. /// This collapses the benchmark helper infrastructure and reduces duplication. @@ -125,7 +119,7 @@ where } /// Generic search infrastructure for determinant-diversity searches. -/// +/// /// Mirrors the unified logic of `run_search()` but for the DeterminantDiversityKnn trait. fn run_search_determinant_diversity( runner: &dyn DeterminantDiversityKnn, @@ -134,7 +128,13 @@ fn run_search_determinant_diversity( builder: F, ) -> anyhow::Result> where - F: Fn(core_search::Setup, usize, usize) -> anyhow::Result>, + F: Fn( + core_search::Setup, + usize, + usize, + ) -> anyhow::Result< + core_search::Run, + >, { let mut all = Vec::new(); diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 7b8092d67..6d7a8a928 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -20,7 +20,7 @@ use diskann::{ graph::{ self, glue::{ - self, ExpandBeam, HasDefaultProcessor, IdIterator, PostProcess, SearchExt, + self, DelegateDefaultPostProcessor, ExpandBeam, IdIterator, PostProcess, SearchExt, SearchPostProcess, SearchStrategy, }, search::Knn, @@ -375,7 +375,7 @@ where } impl<'this, Data, ProviderFactory> - HasDefaultProcessor< + DelegateDefaultPostProcessor< DiskProvider, [Data::VectorDataType], ( diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index ea509b3ff..0b6a15a39 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -61,16 +61,16 @@ where } } -/// [`HasDefaultProcessor`] delegation for [`InlineBetaStrategy`]. The processor wraps +/// [`DelegateDefaultPostProcessor`] delegation for [`InlineBetaStrategy`]. The processor wraps /// the inner strategy's default processor with [`FilterResults`]. impl - diskann::graph::glue::HasDefaultProcessor< + diskann::graph::glue::DelegateDefaultPostProcessor< DocumentProvider>, FilteredQuery, > for InlineBetaStrategy where DP: DataProvider, - Strategy: diskann::graph::glue::HasDefaultProcessor, + Strategy: diskann::graph::glue::DelegateDefaultPostProcessor, Q: AsyncFriendly + Clone, { type Processor = FilterResults; diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 2b928cff7..281abdec0 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -176,7 +176,7 @@ pub(crate) mod tests { self, AdjacencyList, ConsolidateKind, InplaceDeleteMethod, StartPointStrategy, config::IntraBatchCandidates, glue::{ - AsElement, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + AsElement, DelegateDefaultPostProcessor, InplaceDeleteStrategy, InsertStrategy, SearchStrategy, aliases, }, index::{PartitionedNeighbors, QueryLabelProvider, QueryVisitDecision}, @@ -350,7 +350,7 @@ pub(crate) mod tests { mut checker: Checker, ) where DP: DataProvider, - S: SearchStrategy + HasDefaultProcessor, + S: SearchStrategy + DelegateDefaultPostProcessor, Q: std::fmt::Debug + Sync + ?Sized, Checker: FnMut(usize, (u32, f32)) -> Result<(), Box>, { @@ -398,7 +398,7 @@ pub(crate) mod tests { filter: &dyn QueryLabelProvider, ) where DP: DataProvider, - S: SearchStrategy + HasDefaultProcessor, + S: SearchStrategy + DelegateDefaultPostProcessor, Q: std::fmt::Debug + Sync + ?Sized, Checker: FnMut(usize, (u32, f32)) -> Result<(), Box>, { @@ -504,8 +504,8 @@ pub(crate) mod tests { quant_strategy: QS, ) where DP: DataProvider, - FS: SearchStrategy + HasDefaultProcessor + Clone + 'static, - QS: SearchStrategy + HasDefaultProcessor + Clone + 'static, + FS: SearchStrategy + DelegateDefaultPostProcessor + Clone + 'static, + QS: SearchStrategy + DelegateDefaultPostProcessor + Clone + 'static, T: Default + Clone + Send + Sync + std::fmt::Debug, { // Assume all vectors have the same length. @@ -928,7 +928,7 @@ pub(crate) mod tests { T: VectorRepr + GenerateSphericalData + Into, S: InsertStrategy, [T]> + SearchStrategy, [T]> - + HasDefaultProcessor, [T]> + + DelegateDefaultPostProcessor, [T]> + Clone + 'static, rand::distr::StandardUniform: Distribution, @@ -1056,7 +1056,7 @@ pub(crate) mod tests { T: VectorRepr + GenerateSphericalData + Into, S: InsertStrategy, [T]> + SearchStrategy, [T]> - + HasDefaultProcessor, [T]> + + DelegateDefaultPostProcessor, [T]> + Clone + 'static, rand::distr::StandardUniform: Distribution, diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index bfcf448a7..60ba2f632 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -232,7 +232,7 @@ where ) -> ANNResult where T: Sync + ?Sized, - S: SearchStrategy + glue::HasDefaultProcessor, + S: SearchStrategy + glue::DelegateDefaultPostProcessor, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index ee63e1b00..cd6c6510f 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -22,8 +22,8 @@ use diskann::{ graph::{ AdjacencyList, DiskANNIndex, SearchOutputBuffer, glue::{ - self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, + InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, }, neighbor::Neighbor, @@ -1486,7 +1486,7 @@ where } } -impl HasDefaultProcessor, [T]> for Internal +impl DelegateDefaultPostProcessor, [T]> for Internal where T: VectorRepr, Q: AsyncFriendly, @@ -1517,7 +1517,7 @@ where } } -impl HasDefaultProcessor, [T]> for FullPrecision +impl DelegateDefaultPostProcessor, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, @@ -1599,7 +1599,8 @@ where } } -impl HasDefaultProcessor, [T]> for Internal +impl DelegateDefaultPostProcessor, [T]> + for Internal where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -1629,7 +1630,7 @@ where } } -impl HasDefaultProcessor, [T]> for Hybrid +impl DelegateDefaultPostProcessor, [T]> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index 915b8ca09..a66214170 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -980,13 +980,13 @@ where } } -/// [`HasDefaultProcessor`] delegation for [`Cached`]. The processor is composed by +/// [`DelegateDefaultPostProcessor`] delegation for [`Cached`]. The processor is composed by /// wrapping the inner strategy's processor with [`Unwrap`] via [`Pipeline`]. -impl glue::HasDefaultProcessor, T> for Cached +impl glue::DelegateDefaultPostProcessor, T> for Cached where T: ?Sized, DP: DataProvider, - S: glue::HasDefaultProcessor + S: glue::DelegateDefaultPostProcessor + for<'a> SearchStrategy: CacheableAccessor>, C: for<'a> AsCacheAccessorFor< 'a, @@ -1076,7 +1076,7 @@ where S: InplaceDeleteStrategy, Cached: PruneStrategy>, for<'a> Cached: - glue::HasDefaultProcessor, S::DeleteElement<'a>>, + glue::DelegateDefaultPostProcessor, S::DeleteElement<'a>>, C: AsyncFriendly, { type DeleteElement<'a> = S::DeleteElement<'a>; diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs index 0061f3627..4b0983ba8 100644 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs @@ -17,7 +17,7 @@ use diskann::{ graph::{ AdjacencyList, glue::{ - AsElement, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, + AsElement, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, }, @@ -902,7 +902,7 @@ impl SearchStrategy for Internal { } } -impl HasDefaultProcessor for Internal { +impl DelegateDefaultPostProcessor for Internal { delegate_default_post_process!(postprocess::RemoveDeletedIdsAndCopy); } @@ -920,7 +920,7 @@ impl SearchStrategy for FullPrecision { } } -impl HasDefaultProcessor for FullPrecision { +impl DelegateDefaultPostProcessor for FullPrecision { delegate_default_post_process!(Pipeline); } @@ -938,7 +938,7 @@ impl SearchStrategy for Internal { } } -impl HasDefaultProcessor for Internal { +impl DelegateDefaultPostProcessor for Internal { delegate_default_post_process!(postprocess::RemoveDeletedIdsAndCopy); } @@ -956,7 +956,7 @@ impl SearchStrategy for Quantized { } } -impl HasDefaultProcessor for Quantized { +impl DelegateDefaultPostProcessor for Quantized { delegate_default_post_process!(Pipeline); } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 06a146004..ed9c451c6 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -12,7 +12,7 @@ use diskann::{ graph::{ SearchOutputBuffer, glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, }, @@ -460,7 +460,8 @@ where } } -impl HasDefaultProcessor, [T]> for FullPrecision +impl DelegateDefaultPostProcessor, [T]> + for FullPrecision where T: VectorRepr, Q: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 4d176052b..bd6ac6d3b 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -9,8 +9,8 @@ use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, graph::glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InplaceDeleteStrategy, - InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, + InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, @@ -484,7 +484,7 @@ where } } -impl HasDefaultProcessor, [T]> +impl DelegateDefaultPostProcessor, [T]> for Internal where T: VectorRepr, @@ -517,7 +517,8 @@ where } } -impl HasDefaultProcessor, [T]> for Hybrid +impl DelegateDefaultPostProcessor, [T]> + for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -638,7 +639,7 @@ where } } -impl HasDefaultProcessor, [T]> +impl DelegateDefaultPostProcessor, [T]> for Quantized where T: VectorRepr, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 0a450d8e2..e894abdb6 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -10,7 +10,7 @@ use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, graph::glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InsertStrategy, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ @@ -624,7 +624,8 @@ where } impl - HasDefaultProcessor, D, Ctx>, [T]> for Quantized + DelegateDefaultPostProcessor, D, Ctx>, [T]> + for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -661,7 +662,8 @@ where } impl - HasDefaultProcessor, D, Ctx>, [T]> for Quantized + DelegateDefaultPostProcessor, D, Ctx>, [T]> + for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 552001a07..8efe73c56 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -12,7 +12,7 @@ use diskann::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::glue::{ - self, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InsertStrategy, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ @@ -572,7 +572,7 @@ where } } -impl HasDefaultProcessor, [T]> +impl DelegateDefaultPostProcessor, [T]> for Quantized where T: VectorRepr, @@ -605,7 +605,7 @@ where } } -impl HasDefaultProcessor, [T]> +impl DelegateDefaultPostProcessor, [T]> for Quantized where T: VectorRepr, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index ef3329a3e..8ffc6acf8 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -10,7 +10,7 @@ use diskann::{ ANNError, ANNResult, error::{RankedError, ToRanked, TransientError}, graph::glue::{ - AsElement, CopyIds, ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, + AsElement, CopyIds, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, neighbor::Neighbor, @@ -251,7 +251,7 @@ impl SearchStrategy for Flaky { } } -impl HasDefaultProcessor for Flaky { +impl DelegateDefaultPostProcessor for Flaky { delegate_default_post_process!(CopyIds); } diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index afad564dd..5c09cbf8e 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -144,16 +144,16 @@ where } } -/// [`HasDefaultProcessor`] delegation for [`BetaFilter`]. The processor is composed by +/// [`DelegateDefaultPostProcessor`] delegation for [`BetaFilter`]. The processor is composed by /// wrapping the inner strategy's processor with [`Unwrap`] via [`Pipeline`]. -impl glue::HasDefaultProcessor +impl glue::DelegateDefaultPostProcessor for BetaFilter where T: ?Sized, I: VectorId, O: Send, Provider: DataProvider, - Strategy: glue::HasDefaultProcessor, + Strategy: glue::DelegateDefaultPostProcessor, { type Processor = glue::Pipeline; @@ -559,7 +559,7 @@ mod tests { } } - impl glue::HasDefaultProcessor for SimpleStrategy { + impl glue::DelegateDefaultPostProcessor for SimpleStrategy { diskann::delegate_default_post_process!(CopyIds); } diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index c88bad27f..baccb9cf5 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -339,7 +339,7 @@ where /// [`super::search::Knn`]) bounds on. /// /// The blanket impl covers `P = DefaultPostProcess` for any strategy implementing -/// [`HasDefaultProcessor`]. Custom processor types (e.g. `DeterminantDiversitySearchParams`) can have +/// [`DelegateDefaultPostProcessor`]. Custom processor types (e.g. `DeterminantDiversitySearchParams`) can have /// their own `PostProcess` impls without coherence conflicts. pub trait PostProcess::InternalId>: SearchStrategy @@ -369,8 +369,8 @@ where /// /// Strategies implementing this trait work with [`super::search::Knn`] (no explicit /// processor). The old `SearchStrategy::PostProcessor` associated type is replaced by -/// `HasDefaultProcessor::Processor`. -pub trait HasDefaultProcessor::InternalId>: +/// `DelegateDefaultPostProcessor::Processor`. +pub trait DelegateDefaultPostProcessor::InternalId>: SearchStrategy where Provider: DataProvider, @@ -386,7 +386,7 @@ where /// Aggregate trait for strategies that support both search access and a default post-processor. pub trait DefaultSearchStrategy::InternalId>: - SearchStrategy + HasDefaultProcessor + SearchStrategy + DelegateDefaultPostProcessor where Provider: DataProvider, T: ?Sized, @@ -396,20 +396,20 @@ where impl DefaultSearchStrategy for S where - S: SearchStrategy + HasDefaultProcessor, + S: SearchStrategy + DelegateDefaultPostProcessor, Provider: DataProvider, T: ?Sized, O: Send, { } -/// Convenience macro for implementing [`HasDefaultProcessor`] when the processor +/// Convenience macro for implementing [`DelegateDefaultPostProcessor`] when the processor /// is a [`Default`]-constructible type. /// /// # Example /// /// ```ignore -/// impl HasDefaultProcessor for MyStrategy { +/// impl DelegateDefaultPostProcessor for MyStrategy { /// delegate_default_post_process!(CopyIds); /// } /// ``` @@ -433,7 +433,7 @@ pub struct DefaultPostProcess; impl PostProcess for S where - S: HasDefaultProcessor, + S: DelegateDefaultPostProcessor, Provider: DataProvider, T: ?Sized + Sync, O: Send, @@ -1158,7 +1158,7 @@ mod tests { } } - impl HasDefaultProcessor for Strategy { + impl DelegateDefaultPostProcessor for Strategy { delegate_default_post_process!(CopyIds); } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index e4541d07e..c239e8671 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2229,7 +2229,7 @@ where where T: ?Sized, S: SearchStrategy: IdIterator> - + glue::HasDefaultProcessor, + + glue::DelegateDefaultPostProcessor, I: Iterator::InternalId>, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 68994f7ba..786bf35ac 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -964,7 +964,7 @@ impl glue::SearchStrategy for Strategy { } } -impl glue::HasDefaultProcessor for Strategy { +impl glue::DelegateDefaultPostProcessor for Strategy { delegate_default_post_process!(glue::CopyIds); } From b1662c1b71456d1bb7eacd2698cb61650b830716 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 19:06:29 +0530 Subject: [PATCH 12/47] Provider refactor: runtime start-point filtering; remove Internal - Add runtime filter_start_points flag to RemoveDeletedIdsAndCopy and Rerank - Route default search through runtime-configurable processors (no FilterStartPoints pipeline) - Set inplace-delete search processors to filter_start_points=false - Remove Internal strategy indirection and update async providers accordingly --- .../graph/provider/async_/bf_tree/provider.rs | 194 ++++++++++-------- .../src/model/graph/provider/async_/common.rs | 4 - .../graph/provider/async_/debug_provider.rs | 113 +++++----- .../provider/async_/inmem/full_precision.rs | 87 +++++--- .../graph/provider/async_/inmem/product.rs | 76 +++---- .../graph/provider/async_/inmem/scalar.rs | 6 +- .../graph/provider/async_/inmem/spherical.rs | 6 +- .../graph/provider/async_/postprocess.rs | 50 ++++- 8 files changed, 319 insertions(+), 217 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index cd6c6510f..1b54c9fbc 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -19,11 +19,12 @@ use bf_tree::{BfTree, Config}; use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, + error::IntoANNResult, graph::{ AdjacencyList, DiskANNIndex, SearchOutputBuffer, glue::{ self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, - InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, + InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, }, }, neighbor::Neighbor, @@ -45,7 +46,7 @@ use crate::model::{ vector_provider::VectorProvider, }, common::{ - CreateDeleteProvider, FullPrecision, Hybrid, Internal, NoDeletes, NoStore, Panics, + CreateDeleteProvider, FullPrecision, Hybrid, NoDeletes, NoStore, Panics, }, distances, postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, @@ -1467,7 +1468,7 @@ where /// Perform a search entirely in the full-precision space. /// /// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for Internal +impl SearchStrategy, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, @@ -1486,7 +1487,7 @@ where } } -impl DelegateDefaultPostProcessor, [T]> for Internal +impl DelegateDefaultPostProcessor, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, @@ -1495,48 +1496,58 @@ where delegate_default_post_process!(RemoveDeletedIdsAndCopy); } -/// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. -impl SearchStrategy, [T]> for FullPrecision +impl PostProcess, [T], RemoveDeletedIdsAndCopy> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, { - type QueryComputer = T::QueryDistance; - type SearchAccessor<'a> = FullAccessor<'a, T, Q, D>; - type SearchAccessorError = Panics; - - fn search_accessor<'a>( - &'a self, - provider: &'a BfTreeProvider, - _context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: RemoveDeletedIdsAndCopy, + accessor: &mut Self::SearchAccessor<'a>, + query: &[T], + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + glue::SearchPostProcess::post_process( + &processor, accessor, query, computer, candidates, output, + ) + .await + .into_ann_result() + } } } -impl DelegateDefaultPostProcessor, [T]> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, -{ - delegate_default_post_process!(glue::Pipeline); +/// An [`glue::SearchPostProcess`] implementation that reranks PQ vectors. +#[derive(Debug, Clone, Copy)] +pub struct Rerank { + pub filter_start_points: bool, } -/// An [`glue::SearchPostProcess`] implementation that reranks PQ vectors. -#[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; +impl Default for Rerank { + fn default() -> Self { + Self { + filter_start_points: true, + } + } +} impl<'a, T, D> glue::SearchPostProcess, [T]> for Rerank where T: VectorRepr, D: AsyncFriendly + DeletionCheck, { - type Error = Panics; + type Error = ANNError; + #[allow(clippy::manual_async_fn)] fn post_process( &self, accessor: &mut QuantAccessor<'a, T, D>, @@ -1546,42 +1557,55 @@ where output: &mut B, ) -> impl Future> + Send where - I: Iterator>, - B: SearchOutputBuffer + ?Sized, + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, { - let provider = &accessor.provider; - let checker = accessor.as_deletion_check(); - let f = T::distance(provider.metric, Some(provider.full_vectors.dim())); - - // Filter before computing the full precision distances. - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { + async move { + let provider = &accessor.provider; + let f = T::distance(provider.metric, Some(provider.full_vectors.dim())); + let is_not_start_point = if self.filter_start_points { + Some(accessor.is_not_start_point().await?) + } else { + None + }; + let checker = accessor.as_deletion_check(); + + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + return None; + } + + if let Some(filter) = is_not_start_point.as_ref() + && !filter(n.id) + { + return None; + } + #[allow(clippy::expect_used)] let vec = provider .full_vectors .get_vector_sync(n.id.into_usize()) .expect("Full vector provider failed to retrieve element"); Some((n.id, f.evaluate_similarity(query, &vec))) - } - }) - .collect(); + }) + .collect(); - // Sort the full precision distances. - reranked - .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - // Store the reranked results. - std::future::ready(Ok(output.extend(reranked))) + reranked.sort_unstable_by(|a, b| { + (a.1) + .partial_cmp(&b.1) + .unwrap_or(std::cmp::Ordering::Equal) + }); + Ok(output.extend(reranked)) + } } } /// Perform a search entirely in the quantized space. /// -/// Starting points are not filtered out of the final results but results are reranked using +/// Starting points are are filtered out of the final results and results are reranked using /// the full-precision data. -impl SearchStrategy, [T]> for Internal +impl SearchStrategy, [T]> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -1599,8 +1623,7 @@ where } } -impl DelegateDefaultPostProcessor, [T]> - for Internal +impl DelegateDefaultPostProcessor, [T]> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -1608,36 +1631,35 @@ where delegate_default_post_process!(Rerank); } -/// Perform a search entirely in the quantized space. -/// -/// Starting points are are filtered out of the final results and results are reranked using -/// the full-precision data. -impl SearchStrategy, [T]> for Hybrid +impl PostProcess, [T], Rerank> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, { - type QueryComputer = pq::distance::QueryComputer>; - type SearchAccessor<'a> = QuantAccessor<'a, T, D>; - type SearchAccessorError = Panics; - - fn search_accessor<'a>( - &'a self, - provider: &'a BfTreeProvider, - _context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: Rerank, + accessor: &mut Self::SearchAccessor<'a>, + query: &[T], + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + glue::SearchPostProcess::post_process( + &processor, accessor, query, computer, candidates, output, + ) + .await + .into_ann_result() + } } } -impl DelegateDefaultPostProcessor, [T]> for Hybrid -where - T: VectorRepr, - D: AsyncFriendly + DeletionCheck, -{ - delegate_default_post_process!(glue::Pipeline); -} - // Pruning impl PruneStrategy> for FullPrecision where @@ -1746,10 +1768,10 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; - type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; - type SearchStrategy = Internal; + type SearchPostProcessor = RemoveDeletedIdsAndCopy; + type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) + Self } fn prune_strategy(&self) -> Self::PruneStrategy { @@ -1757,7 +1779,9 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - Default::default() + RemoveDeletedIdsAndCopy { + filter_start_points: false, + } } async fn get_delete_element<'a>( @@ -1785,10 +1809,10 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; - type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; - type SearchStrategy = Internal; + type SearchPostProcessor = Rerank; + type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { - Internal(*self) + *self } fn prune_strategy(&self) -> Self::PruneStrategy { @@ -1796,7 +1820,9 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - Default::default() + Rerank { + filter_start_points: false, + } } async fn get_delete_element<'a>( diff --git a/diskann-providers/src/model/graph/provider/async_/common.rs b/diskann-providers/src/model/graph/provider/async_/common.rs index e0262a981..7da52b30a 100644 --- a/diskann-providers/src/model/graph/provider/async_/common.rs +++ b/diskann-providers/src/model/graph/provider/async_/common.rs @@ -416,10 +416,6 @@ impl Hybrid { } } -/// Internal variant of above strategy types to avoid start point filtering. -#[derive(Debug)] -pub struct Internal(pub T); - #[cfg(test)] pub struct TestCallCount { count: std::sync::atomic::AtomicUsize, diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs index 4b0983ba8..8590c7bec 100644 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs @@ -14,14 +14,15 @@ use std::{ use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNErrorKind, ANNResult, + error::IntoANNResult, graph::{ - AdjacencyList, + AdjacencyList, SearchOutputBuffer, glue::{ - AsElement, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, - InplaceDeleteStrategy, InsertStrategy, Pipeline, PruneStrategy, SearchExt, - SearchStrategy, + AsElement, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, + InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, }, }, + neighbor::Neighbor, provider::{ self, Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultAccessor, DefaultContext, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, @@ -39,7 +40,7 @@ use crate::{ FixedChunkPQTable, distance::{DistanceComputer, QueryComputer}, graph::provider::async_::{ - common::{FullPrecision, Internal, Panics, Quantized}, + common::{FullPrecision, Panics, Quantized}, distances::{self, pq::Hybrid}, postprocess, }, @@ -888,7 +889,7 @@ impl FillSet for HybridAccessor<'_> { // Strategies // //////////////// -impl SearchStrategy for Internal { +impl SearchStrategy for FullPrecision { type QueryComputer = ::QueryDistance; type SearchAccessorError = Panics; type SearchAccessor<'a> = FullAccessor<'a>; @@ -902,29 +903,36 @@ impl SearchStrategy for Internal { } } -impl DelegateDefaultPostProcessor for Internal { +impl DelegateDefaultPostProcessor for FullPrecision { delegate_default_post_process!(postprocess::RemoveDeletedIdsAndCopy); } -impl SearchStrategy for FullPrecision { - type QueryComputer = ::QueryDistance; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = FullAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) +impl PostProcess for FullPrecision { + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: postprocess::RemoveDeletedIdsAndCopy, + accessor: &mut Self::SearchAccessor<'a>, + query: &[f32], + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + diskann::graph::glue::SearchPostProcess::post_process( + &processor, accessor, query, computer, candidates, output, + ) + .await + .into_ann_result() + } } } -impl DelegateDefaultPostProcessor for FullPrecision { - delegate_default_post_process!(Pipeline); -} - -impl SearchStrategy for Internal { +impl SearchStrategy for Quantized { type QueryComputer = pq::distance::QueryComputer>; type SearchAccessorError = Panics; type SearchAccessor<'a> = QuantAccessor<'a>; @@ -938,28 +946,35 @@ impl SearchStrategy for Internal { } } -impl DelegateDefaultPostProcessor for Internal { +impl DelegateDefaultPostProcessor for Quantized { delegate_default_post_process!(postprocess::RemoveDeletedIdsAndCopy); } -impl SearchStrategy for Quantized { - type QueryComputer = pq::distance::QueryComputer>; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = QuantAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) +impl PostProcess for Quantized { + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: postprocess::RemoveDeletedIdsAndCopy, + accessor: &mut Self::SearchAccessor<'a>, + query: &[f32], + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + diskann::graph::glue::SearchPostProcess::post_process( + &processor, accessor, query, computer, candidates, output, + ) + .await + .into_ann_result() + } } } -impl DelegateDefaultPostProcessor for Quantized { - delegate_default_post_process!(Pipeline); -} - impl PruneStrategy for FullPrecision { type DistanceComputer = ::Distance; type PruneAccessor<'a> = FullAccessor<'a>; @@ -1049,19 +1064,21 @@ impl InplaceDeleteStrategy for FullPrecision { type DeleteElementGuard = Vec; type DeleteElementError = Panics; type PruneStrategy = Self; - type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; - type SearchStrategy = Internal; + type SearchPostProcessor = postprocess::RemoveDeletedIdsAndCopy; + type SearchStrategy = Self; fn prune_strategy(&self) -> Self::PruneStrategy { *self } fn search_strategy(&self) -> Self::SearchStrategy { - Internal(*self) + *self } fn search_post_processor(&self) -> Self::SearchPostProcessor { - Default::default() + postprocess::RemoveDeletedIdsAndCopy { + filter_start_points: false, + } } fn get_delete_element<'a>( @@ -1080,19 +1097,21 @@ impl InplaceDeleteStrategy for Quantized { type DeleteElementGuard = Vec; type DeleteElementError = Panics; type PruneStrategy = Self; - type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; - type SearchStrategy = Internal; + type SearchPostProcessor = postprocess::RemoveDeletedIdsAndCopy; + type SearchStrategy = Self; fn prune_strategy(&self) -> Self::PruneStrategy { *self } fn search_strategy(&self) -> Self::SearchStrategy { - Internal(*self) + *self } fn search_post_processor(&self) -> Self::SearchPostProcessor { - Default::default() + postprocess::RemoveDeletedIdsAndCopy { + filter_start_points: false, + } } fn get_delete_element<'a>( diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index ed9c451c6..9bd980ed4 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -12,7 +12,7 @@ use diskann::{ graph::{ SearchOutputBuffer, glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, }, @@ -383,16 +383,31 @@ pub trait GetFullPrecision { /// 1. Filters out deleted ids from being returned. /// 2. Reranks a candidate stream using full-precision distances. /// 3. Copies back the results to the output buffer. -#[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; +#[derive(Debug, Clone, Copy)] +pub struct Rerank { + pub filter_start_points: bool, +} + +impl Default for Rerank { + fn default() -> Self { + Self { + filter_start_points: true, + } + } +} impl glue::SearchPostProcess for Rerank where T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, + A: BuildQueryComputer<[T], Id = u32> + + GetFullPrecision + + AsDeletionCheck + + SearchExt, + ::Checker: Sync, { - type Error = Panics; + type Error = ANNError; + #[allow(clippy::manual_async_fn)] fn post_process( &self, accessor: &mut A, @@ -402,34 +417,48 @@ where output: &mut B, ) -> impl Future> + Send where - I: Iterator>, - B: SearchOutputBuffer + ?Sized, + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, { - let full = accessor.as_full_precision(); - let checker = accessor.as_deletion_check(); - let f = full.distance(); - - // Filter before computing the full precision distances. - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { + async move { + let full = accessor.as_full_precision(); + let f = full.distance(); + let is_not_start_point = if self.filter_start_points { + Some(accessor.is_not_start_point().await?) + } else { + None + }; + let checker = accessor.as_deletion_check(); + + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + return None; + } + + if let Some(filter) = is_not_start_point.as_ref() + && !filter(n.id) + { + return None; + } + Some(( n.id, f.evaluate_similarity(query, unsafe { full.get_vector_sync(n.id.into_usize()) }), )) - } - }) - .collect(); - - // Sort the full precision distances. - reranked - .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - // Store the reranked results. - std::future::ready(Ok(output.extend(reranked))) + }) + .collect(); + + reranked.sort_unstable_by(|a, b| { + (a.1) + .partial_cmp(&b.1) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + Ok(output.extend(reranked)) + } } } @@ -468,7 +497,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(glue::Pipeline); + delegate_default_post_process!(RemoveDeletedIdsAndCopy); } impl PostProcess, [T], RemoveDeletedIdsAndCopy> @@ -637,7 +666,9 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - RemoveDeletedIdsAndCopy + RemoveDeletedIdsAndCopy { + filter_start_points: false, + } } async fn get_delete_element<'a>( diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index bd6ac6d3b..f4388fcb6 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -8,10 +8,13 @@ use std::{collections::HashMap, future::Future, sync::Arc}; use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, + error::IntoANNResult, + graph::SearchOutputBuffer, graph::glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, - InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, + InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, }, + neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, HasId, @@ -27,7 +30,7 @@ use crate::model::{ FastMemoryQuantVectorProviderAsync, FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, common::{ - CreateVectorStore, Hybrid, Internal, NoStore, Panics, Quantized, SetElementHelper, + CreateVectorStore, Hybrid, NoStore, Panics, Quantized, SetElementHelper, VectorStore, }, distances, @@ -462,10 +465,9 @@ where /// Perform a search entirely in the quantized space. /// -/// Starting points are not filtered out of the final results but results are reranked using +/// Starting points are filtered out of the final results and results are reranked using /// the full-precision data. -impl SearchStrategy, [T]> - for Internal +impl SearchStrategy, [T]> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -485,7 +487,7 @@ where } impl DelegateDefaultPostProcessor, [T]> - for Internal + for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -494,37 +496,35 @@ where delegate_default_post_process!(Rerank); } -/// Perform a search entirely in the quantized space. -/// -/// Starting points are filtered out of the final results and results are reranked using -/// the full-precision data. -impl SearchStrategy, [T]> for Hybrid -where - T: VectorRepr, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type QueryComputer = pq::distance::QueryComputer>; - type SearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; - type SearchAccessorError = Panics; - - fn search_accessor<'a>( - &'a self, - provider: &'a FullPrecisionProvider, - _context: &'a Ctx, - ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) - } -} - -impl DelegateDefaultPostProcessor, [T]> +impl PostProcess, [T], Rerank> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(glue::Pipeline); + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: Rerank, + accessor: &mut Self::SearchAccessor<'a>, + query: &[T], + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + glue::SearchPostProcess::post_process( + &processor, accessor, query, computer, candidates, output, + ) + .await + .into_ann_result() + } + } } impl PruneStrategy> for Hybrid @@ -589,10 +589,10 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; - type SearchPostProcessor = diskann::graph::glue::DefaultPostProcess; - type SearchStrategy = Internal; + type SearchPostProcessor = Rerank; + type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { - Internal(*self) + *self } fn prune_strategy(&self) -> Self::PruneStrategy { @@ -600,7 +600,9 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - Default::default() + Rerank { + filter_start_points: false, + } } async fn get_delete_element<'a>( @@ -646,7 +648,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(glue::Pipeline); + delegate_default_post_process!(RemoveDeletedIdsAndCopy); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index e894abdb6..94897dc4b 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -10,7 +10,7 @@ use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, graph::glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, + DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ @@ -633,7 +633,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - delegate_default_post_process!(glue::Pipeline); + delegate_default_post_process!(Rerank); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -671,7 +671,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - delegate_default_post_process!(glue::Pipeline); + delegate_default_post_process!(RemoveDeletedIdsAndCopy); } impl PruneStrategy, D, Ctx>> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 8efe73c56..28d30c6a3 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -12,7 +12,7 @@ use diskann::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, + DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ @@ -579,7 +579,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(glue::Pipeline); + delegate_default_post_process!(Rerank); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -612,7 +612,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(glue::Pipeline); + delegate_default_post_process!(RemoveDeletedIdsAndCopy); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index bf6a47bba..ec3b250cc 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -9,6 +9,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, provider::BuildQueryComputer, + ANNError, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to @@ -34,16 +35,28 @@ pub(crate) trait DeletionCheck { /// A [`SearchPostProcess`] routine that fuses the removal of deleted elements with the /// copying of IDs into an output buffer. -#[derive(Debug, Clone, Copy, Default)] -pub struct RemoveDeletedIdsAndCopy; +#[derive(Debug, Clone, Copy)] +pub struct RemoveDeletedIdsAndCopy { + pub filter_start_points: bool, +} + +impl Default for RemoveDeletedIdsAndCopy { + fn default() -> Self { + Self { + filter_start_points: true, + } + } +} impl glue::SearchPostProcess for RemoveDeletedIdsAndCopy where - A: BuildQueryComputer + AsDeletionCheck, + A: BuildQueryComputer + AsDeletionCheck + glue::SearchExt, + ::Checker: Sync, T: ?Sized, { - type Error = std::convert::Infallible; + type Error = ANNError; + #[allow(clippy::manual_async_fn)] fn post_process( &self, accessor: &mut A, @@ -56,14 +69,29 @@ where I: Iterator> + Send, B: SearchOutputBuffer + Send + ?Sized, { - let checker = accessor.as_deletion_check(); - let count = output.extend(candidates.filter_map(|n| { - if checker.deletion_check(n.id) { + async move { + let is_not_start_point = if self.filter_start_points { + Some(accessor.is_not_start_point().await?) + } else { None + }; + + let checker = accessor.as_deletion_check(); + let filtered = candidates.filter_map(|n| { + if checker.deletion_check(n.id) { + None + } else { + Some((n.id, n.distance)) + } + }); + + let count = if let Some(filter) = is_not_start_point { + output.extend(filtered.filter(|(id, _)| filter(*id))) } else { - Some((n.id, n.distance)) - } - })); - std::future::ready(Ok(count)) + output.extend(filtered) + }; + + Ok(count) + } } } From 214938862403f6e5c816bd55159b348b381e9e69 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 19:20:53 +0530 Subject: [PATCH 13/47] Fix: Align SearchOutputBuffer bound with trait definition (+?Sized) --- diskann/src/graph/search/diverse_search.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index 71b217a7e..e1d63e4be 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -113,7 +113,7 @@ where where S: PostProcess, PP: Send + Sync, - OB: SearchOutputBuffer + Send, + OB: SearchOutputBuffer + Send + ?Sized, { async move { let mut accessor = strategy From 6a1872dad16b597f126b271dbc6f80e0e3c20827 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 21:30:36 +0530 Subject: [PATCH 14/47] merge issues fix --- .../src/backend/disk_index/search.rs | 7 +- diskann-garnet/src/provider.rs | 105 ++++++++++++------ diskann-quantization/src/bits/distances.rs | 6 +- .../spherical/quantizer_generated.rs | 5 - .../spherical_quantizer_generated.rs | 5 - .../spherical/supported_metric_generated.rs | 5 - .../transforms/double_hadamard_generated.rs | 5 - .../transforms/null_transform_generated.rs | 5 - .../transforms/padding_hadamard_generated.rs | 5 - .../transforms/random_rotation_generated.rs | 5 - .../transforms/transform_generated.rs | 5 - .../transforms/transform_kind_generated.rs | 5 - 12 files changed, 81 insertions(+), 82 deletions(-) diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 70b138cc0..8d908dd74 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -290,9 +290,10 @@ where search_params.is_flat_search, processor, ), - Err(e) => { - Err(format!("Invalid determinant-diversity parameters: {}", e).into()) - } + Err(e) => Err(diskann::ANNError::log_index_error(format!( + "Invalid determinant-diversity parameters: {}", + e + ))), } } else { searcher.search( diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 212b32660..082e04099 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -4,14 +4,17 @@ */ use dashmap::DashMap; +use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNErrorKind, ANNResult, + error::IntoANNResult, graph::{ AdjacencyList, SearchOutputBuffer, config::defaults::MAX_OCCLUSION_SIZE, glue::{ - self, ExpandBeam, FillSet, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, - SearchExt, SearchPostProcess, SearchStrategy, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, + InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchPostProcess, + SearchStrategy, }, }, neighbor::Neighbor, @@ -24,7 +27,7 @@ use diskann::{ object_pool::{AsPooled, ObjectPool, PooledRef, Undef}, }, }; -use diskann_providers::model::graph::provider::async_::common::{FullPrecision, Internal}; +use diskann_providers::model::graph::provider::async_::common::FullPrecision; use diskann_vector::{PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric}; use std::{ collections::{HashMap, hash_map::Entry}, @@ -718,25 +721,6 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> } } -impl SearchStrategy, [T]> for Internal { - type SearchAccessor<'a> = FullAccessor<'a, T>; - type SearchAccessorError = GarnetProviderError; - type QueryComputer = T::QueryDistance; - type PostProcessor = glue::CopyIds; - - fn search_accessor<'a>( - &'a self, - provider: &'a GarnetProvider, - context: &'a as DataProvider>::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, true)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - /// A [`SearchPostProcess`] base object that copies each `Neighbor` to a `(ExternalId, f32)` pair /// and writes as many as possible to the output buffer. #[derive(Debug, Default, Clone, Copy)] @@ -777,7 +761,6 @@ impl SearchStrategy, [T], GarnetId> for FullPre type SearchAccessor<'a> = FullAccessor<'a, T>; type SearchAccessorError = GarnetProviderError; type QueryComputer = T::QueryDistance; - type PostProcessor = glue::Pipeline; fn search_accessor<'a>( &'a self, @@ -786,16 +769,18 @@ impl SearchStrategy, [T], GarnetId> for FullPre ) -> Result, Self::SearchAccessorError> { Ok(FullAccessor::new(provider, context, true)) } +} - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } +impl DelegateDefaultPostProcessor, [T], GarnetId> + for FullPrecision +{ + delegate_default_post_process!(glue::Pipeline); } + impl SearchStrategy, [T], u32> for FullPrecision { type SearchAccessor<'a> = FullAccessor<'a, T>; type SearchAccessorError = GarnetProviderError; type QueryComputer = T::QueryDistance; - type PostProcessor = glue::CopyIds; fn search_accessor<'a>( &'a self, @@ -805,8 +790,61 @@ impl SearchStrategy, [T], u32> for FullPrecisio Ok(FullAccessor::new(provider, context, true)) } - fn post_processor(&self) -> Self::PostProcessor { - Default::default() +} + +impl DelegateDefaultPostProcessor, [T], u32> for FullPrecision { + delegate_default_post_process!(glue::CopyIds); +} + +impl PostProcess, [T], glue::CopyIds, u32> for FullPrecision { + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: glue::CopyIds, + accessor: &mut Self::SearchAccessor<'a>, + query: &[T], + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + diskann::graph::glue::SearchPostProcess::post_process( + &processor, accessor, query, computer, candidates, output, + ) + .await + .into_ann_result() + } + } +} + +impl PostProcess, [T], CopyExternalIds, GarnetId> + for FullPrecision +{ + #[allow(clippy::manual_async_fn)] + fn post_process_with<'a, I, B>( + &self, + processor: CopyExternalIds, + accessor: &mut Self::SearchAccessor<'a>, + query: &[T], + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + async move { + diskann::graph::glue::SearchPostProcess::post_process( + &processor, accessor, query, computer, candidates, output, + ) + .await + .into_ann_result() + } } } @@ -858,14 +896,19 @@ impl InplaceDeleteStrategy> for FullPrecision { type DeleteElementError = GarnetProviderError; type PruneStrategy = Self; - type SearchStrategy = Internal; + type SearchPostProcessor = glue::CopyIds; + type SearchStrategy = Self; fn prune_strategy(&self) -> Self::PruneStrategy { Self } fn search_strategy(&self) -> Self::SearchStrategy { - Internal(Self) + Self + } + + fn search_post_processor(&self) -> Self::SearchPostProcessor { + glue::CopyIds } fn get_delete_element<'a>( diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index 9a5d4e72f..e8e454dbf 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -3222,7 +3222,7 @@ mod tests { (dist_8bit.sample(&mut *rng), dist_mbit.sample(&mut *rng)) }) .check_with( - &lazy_format!("IP(8,{}) dim={dim}, trial={trial} -- {context}", M), + lazy_format!("IP(8,{}) dim={dim}, trial={trial} -- {context}", M), evaluate_ip, ); } @@ -3250,7 +3250,7 @@ mod tests { let dims = [127, 128, 129, 255, 256, 512, 768, 896, 3072]; for &dim in &dims { let case = HetCase::::new(dim, |_| (255, max_val)); - case.check_with(&lazy_format!("max-value {context} dim={dim}"), evaluate); + case.check_with(lazy_format!("max-value {context} dim={dim}"), evaluate); } } @@ -3320,7 +3320,7 @@ mod tests { // x > 127 sweep (vpmaddubsw unsigned treatment). for x_val in [128i64, 170, 200, 240, 255] { HetCase::::new(block_size, move |_| (x_val, y_half)) - .check_with(&lazy_format!("x > 127 (x_val={x_val})"), evaluate); + .check_with(lazy_format!("x > 127 (x_val={x_val})"), evaluate); } // Dim = block_size - 1 (no full block, all scalar). diff --git a/diskann-quantization/src/flatbuffers/spherical/quantizer_generated.rs b/diskann-quantization/src/flatbuffers/spherical/quantizer_generated.rs index 320ca651d..e99fa6767 100644 --- a/diskann-quantization/src/flatbuffers/spherical/quantizer_generated.rs +++ b/diskann-quantization/src/flatbuffers/spherical/quantizer_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/spherical/spherical_quantizer_generated.rs b/diskann-quantization/src/flatbuffers/spherical/spherical_quantizer_generated.rs index 9039d45fc..6a0a01c75 100644 --- a/diskann-quantization/src/flatbuffers/spherical/spherical_quantizer_generated.rs +++ b/diskann-quantization/src/flatbuffers/spherical/spherical_quantizer_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/spherical/supported_metric_generated.rs b/diskann-quantization/src/flatbuffers/spherical/supported_metric_generated.rs index d96ac6da6..2db2a9566 100644 --- a/diskann-quantization/src/flatbuffers/spherical/supported_metric_generated.rs +++ b/diskann-quantization/src/flatbuffers/spherical/supported_metric_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/double_hadamard_generated.rs b/diskann-quantization/src/flatbuffers/transforms/double_hadamard_generated.rs index d4694e8e4..9ac148ef9 100644 --- a/diskann-quantization/src/flatbuffers/transforms/double_hadamard_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/double_hadamard_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/null_transform_generated.rs b/diskann-quantization/src/flatbuffers/transforms/null_transform_generated.rs index 3c4f7788a..a0fc4050b 100644 --- a/diskann-quantization/src/flatbuffers/transforms/null_transform_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/null_transform_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/padding_hadamard_generated.rs b/diskann-quantization/src/flatbuffers/transforms/padding_hadamard_generated.rs index 26931d836..bc59adf2c 100644 --- a/diskann-quantization/src/flatbuffers/transforms/padding_hadamard_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/padding_hadamard_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/random_rotation_generated.rs b/diskann-quantization/src/flatbuffers/transforms/random_rotation_generated.rs index 64e518e1e..de52a3f4a 100644 --- a/diskann-quantization/src/flatbuffers/transforms/random_rotation_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/random_rotation_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/transform_generated.rs b/diskann-quantization/src/flatbuffers/transforms/transform_generated.rs index 7cc3b3afa..efeba2723 100644 --- a/diskann-quantization/src/flatbuffers/transforms/transform_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/transform_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/transform_kind_generated.rs b/diskann-quantization/src/flatbuffers/transforms/transform_kind_generated.rs index fbfb9918f..a7c7fd635 100644 --- a/diskann-quantization/src/flatbuffers/transforms/transform_kind_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/transform_kind_generated.rs @@ -1,8 +1,3 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; From 750d542d198d5c0714bdeff4fd497c2496fe8cac Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 21:35:34 +0530 Subject: [PATCH 15/47] fix merge issues, clippy, fmt --- diskann-garnet/src/provider.rs | 1 - .../graph/provider/async_/bf_tree/provider.rs | 8 ++------ .../graph/provider/async_/inmem/full_precision.rs | 14 ++++---------- .../model/graph/provider/async_/inmem/product.rs | 3 +-- .../model/graph/provider/async_/inmem/scalar.rs | 4 ++-- .../model/graph/provider/async_/inmem/spherical.rs | 4 ++-- .../src/model/graph/provider/async_/postprocess.rs | 2 +- 7 files changed, 12 insertions(+), 24 deletions(-) diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 082e04099..26cdb869c 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -789,7 +789,6 @@ impl SearchStrategy, [T], u32> for FullPrecisio ) -> Result, Self::SearchAccessorError> { Ok(FullAccessor::new(provider, context, true)) } - } impl DelegateDefaultPostProcessor, [T], u32> for FullPrecision { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 3cb34862e..cbabe64cd 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -45,9 +45,7 @@ use crate::model::{ neighbor_provider::NeighborProvider, quant_vector_provider::QuantVectorProvider, vector_provider::VectorProvider, }, - common::{ - CreateDeleteProvider, FullPrecision, Hybrid, NoDeletes, NoStore, Panics, - }, + common::{CreateDeleteProvider, FullPrecision, Hybrid, NoDeletes, NoStore, Panics}, distances, postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, }, @@ -1592,9 +1590,7 @@ where .collect(); reranked.sort_unstable_by(|a, b| { - (a.1) - .partial_cmp(&b.1) - .unwrap_or(std::cmp::Ordering::Equal) + (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) }); Ok(output.extend(reranked)) } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 9bd980ed4..f697ee12f 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -12,9 +12,8 @@ use diskann::{ graph::{ SearchOutputBuffer, glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, - InplaceDeleteStrategy, InsertStrategy, PostProcess, PruneStrategy, SearchExt, - SearchStrategy, + self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, + InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, }, }, neighbor::Neighbor, @@ -399,10 +398,7 @@ impl Default for Rerank { impl glue::SearchPostProcess for Rerank where T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> - + GetFullPrecision - + AsDeletionCheck - + SearchExt, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck + SearchExt, ::Checker: Sync, { type Error = ANNError; @@ -452,9 +448,7 @@ where .collect(); reranked.sort_unstable_by(|a, b| { - (a.1) - .partial_cmp(&b.1) - .unwrap_or(std::cmp::Ordering::Equal) + (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) }); Ok(output.extend(reranked)) diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index f4388fcb6..cfe5acc0b 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -496,8 +496,7 @@ where delegate_default_post_process!(Rerank); } -impl PostProcess, [T], Rerank> - for Hybrid +impl PostProcess, [T], Rerank> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 94897dc4b..e81c5f4bf 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -10,8 +10,8 @@ use diskann::delegate_default_post_process; use diskann::{ ANNError, ANNResult, graph::glue::{ - DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, + SearchExt, SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 28d30c6a3..64744366b 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -12,8 +12,8 @@ use diskann::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::glue::{ - DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, + SearchExt, SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index ec3b250cc..e644a3117 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -6,10 +6,10 @@ //! Shared search post-processing. use diskann::{ + ANNError, graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, provider::BuildQueryComputer, - ANNError, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to From 68e4bfd5b4b431bcbb14e87d1b2354bc0bd29a29 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 12 Mar 2026 23:07:22 +0530 Subject: [PATCH 16/47] Fix cached inplace-delete post-processor wiring - Preserve inner search_post_processor for Cached inplace-delete path - Add CachedPostProcess wrapper to avoid PostProcess impl overlap - Keep default post-processing delegation unchanged for normal search --- .../graph/provider/async_/caching/provider.rs | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index a66214170..b493a7598 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -1003,6 +1003,49 @@ where } } +#[derive(Debug, Clone, Copy)] +pub struct CachedPostProcess

(pub P); + +impl glue::PostProcess, T, CachedPostProcess

> + for Cached +where + T: ?Sized, + P: Send + Sync, + DP: DataProvider, + S: glue::PostProcess + + for<'a> SearchStrategy: CacheableAccessor>, + C: for<'a> AsCacheAccessorFor< + 'a, + SearchAccessor<'a, S, DP, T>, + Accessor: NeighborCache, + Error = E, + > + AsyncFriendly, + E: StandardError, +{ + fn post_process_with<'a, I, B>( + &self, + processor: CachedPostProcess

, + accessor: &mut Self::SearchAccessor<'a>, + query: &T, + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + self.strategy.post_process_with( + processor.0, + &mut accessor.inner, + query, + computer, + candidates, + output, + ) + } +} + /// We need `S` to be a [`PruneStrategy`] for the underlying provider. /// /// This strategy has an associated [`PruneElement`] type `E` @@ -1075,8 +1118,11 @@ where DP: DataProvider, S: InplaceDeleteStrategy, Cached: PruneStrategy>, - for<'a> Cached: - glue::DelegateDefaultPostProcessor, S::DeleteElement<'a>>, + for<'a> Cached: glue::PostProcess< + CachingProvider, + S::DeleteElement<'a>, + CachedPostProcess, + >, C: AsyncFriendly, { type DeleteElement<'a> = S::DeleteElement<'a>; @@ -1085,7 +1131,7 @@ where type PruneStrategy = Cached; type SearchStrategy = Cached; - type SearchPostProcessor = glue::DefaultPostProcess; + type SearchPostProcessor = CachedPostProcess; fn prune_strategy(&self) -> Self::PruneStrategy { Cached { @@ -1100,7 +1146,7 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - glue::DefaultPostProcess + CachedPostProcess(self.strategy.search_post_processor()) } fn get_delete_element<'a>( From 48cd833a41f1f895636242ec582108bf860b833e Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 14:53:27 -0700 Subject: [PATCH 17/47] Revert unrelated quantization changes The post-processing refactor should not touch diskann-quantization. Restores license headers on generated flatbuffer files and reverts a lazy_format! call-site change. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- diskann-quantization/src/bits/distances.rs | 6 +++--- .../src/flatbuffers/spherical/quantizer_generated.rs | 5 +++++ .../flatbuffers/spherical/spherical_quantizer_generated.rs | 5 +++++ .../src/flatbuffers/spherical/supported_metric_generated.rs | 5 +++++ .../src/flatbuffers/transforms/double_hadamard_generated.rs | 5 +++++ .../src/flatbuffers/transforms/null_transform_generated.rs | 5 +++++ .../flatbuffers/transforms/padding_hadamard_generated.rs | 5 +++++ .../src/flatbuffers/transforms/random_rotation_generated.rs | 5 +++++ .../src/flatbuffers/transforms/transform_generated.rs | 5 +++++ .../src/flatbuffers/transforms/transform_kind_generated.rs | 5 +++++ 10 files changed, 48 insertions(+), 3 deletions(-) diff --git a/diskann-quantization/src/bits/distances.rs b/diskann-quantization/src/bits/distances.rs index e8e454dbf..9a5d4e72f 100644 --- a/diskann-quantization/src/bits/distances.rs +++ b/diskann-quantization/src/bits/distances.rs @@ -3222,7 +3222,7 @@ mod tests { (dist_8bit.sample(&mut *rng), dist_mbit.sample(&mut *rng)) }) .check_with( - lazy_format!("IP(8,{}) dim={dim}, trial={trial} -- {context}", M), + &lazy_format!("IP(8,{}) dim={dim}, trial={trial} -- {context}", M), evaluate_ip, ); } @@ -3250,7 +3250,7 @@ mod tests { let dims = [127, 128, 129, 255, 256, 512, 768, 896, 3072]; for &dim in &dims { let case = HetCase::::new(dim, |_| (255, max_val)); - case.check_with(lazy_format!("max-value {context} dim={dim}"), evaluate); + case.check_with(&lazy_format!("max-value {context} dim={dim}"), evaluate); } } @@ -3320,7 +3320,7 @@ mod tests { // x > 127 sweep (vpmaddubsw unsigned treatment). for x_val in [128i64, 170, 200, 240, 255] { HetCase::::new(block_size, move |_| (x_val, y_half)) - .check_with(lazy_format!("x > 127 (x_val={x_val})"), evaluate); + .check_with(&lazy_format!("x > 127 (x_val={x_val})"), evaluate); } // Dim = block_size - 1 (no full block, all scalar). diff --git a/diskann-quantization/src/flatbuffers/spherical/quantizer_generated.rs b/diskann-quantization/src/flatbuffers/spherical/quantizer_generated.rs index e99fa6767..320ca651d 100644 --- a/diskann-quantization/src/flatbuffers/spherical/quantizer_generated.rs +++ b/diskann-quantization/src/flatbuffers/spherical/quantizer_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/spherical/spherical_quantizer_generated.rs b/diskann-quantization/src/flatbuffers/spherical/spherical_quantizer_generated.rs index 6a0a01c75..9039d45fc 100644 --- a/diskann-quantization/src/flatbuffers/spherical/spherical_quantizer_generated.rs +++ b/diskann-quantization/src/flatbuffers/spherical/spherical_quantizer_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/spherical/supported_metric_generated.rs b/diskann-quantization/src/flatbuffers/spherical/supported_metric_generated.rs index 2db2a9566..d96ac6da6 100644 --- a/diskann-quantization/src/flatbuffers/spherical/supported_metric_generated.rs +++ b/diskann-quantization/src/flatbuffers/spherical/supported_metric_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/double_hadamard_generated.rs b/diskann-quantization/src/flatbuffers/transforms/double_hadamard_generated.rs index 9ac148ef9..d4694e8e4 100644 --- a/diskann-quantization/src/flatbuffers/transforms/double_hadamard_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/double_hadamard_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/null_transform_generated.rs b/diskann-quantization/src/flatbuffers/transforms/null_transform_generated.rs index a0fc4050b..3c4f7788a 100644 --- a/diskann-quantization/src/flatbuffers/transforms/null_transform_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/null_transform_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/padding_hadamard_generated.rs b/diskann-quantization/src/flatbuffers/transforms/padding_hadamard_generated.rs index bc59adf2c..26931d836 100644 --- a/diskann-quantization/src/flatbuffers/transforms/padding_hadamard_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/padding_hadamard_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/random_rotation_generated.rs b/diskann-quantization/src/flatbuffers/transforms/random_rotation_generated.rs index de52a3f4a..64e518e1e 100644 --- a/diskann-quantization/src/flatbuffers/transforms/random_rotation_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/random_rotation_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/transform_generated.rs b/diskann-quantization/src/flatbuffers/transforms/transform_generated.rs index efeba2723..7cc3b3afa 100644 --- a/diskann-quantization/src/flatbuffers/transforms/transform_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/transform_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; diff --git a/diskann-quantization/src/flatbuffers/transforms/transform_kind_generated.rs b/diskann-quantization/src/flatbuffers/transforms/transform_kind_generated.rs index a7c7fd635..fbfb9918f 100644 --- a/diskann-quantization/src/flatbuffers/transforms/transform_kind_generated.rs +++ b/diskann-quantization/src/flatbuffers/transforms/transform_kind_generated.rs @@ -1,3 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + // automatically generated by the FlatBuffers compiler, do not modify // @generated extern crate alloc; From c0f20537c6f954bc88cbb0d0f70428f6bdd1790e Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 15:19:24 -0700 Subject: [PATCH 18/47] Remove determinant diversity from post-processing refactor Determinant diversity post-processing should be landed separately. This removes: - determinant_diversity_post_process.rs from diskann-providers - determinant_diversity.rs from diskann-benchmark-core - All diversity-related PostProcess impls (full_precision, disk_provider) - Diversity benchmark infrastructure (run_determinant_diversity, DeterminantDiversityKnn trait, search_determinant_diversity) - Diversity input parsing and validation from async_ and disk inputs - Diversity example JSON files Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- copilot.txt | 1 + .../src/search/graph/determinant_diversity.rs | 208 -------- .../src/search/graph/mod.rs | 1 - .../example/wikipedia_compare_detdiv.json | 60 --- .../src/backend/disk_index/search.rs | 32 +- .../src/backend/index/benchmarks.rs | 64 +-- diskann-benchmark/src/backend/index/result.rs | 35 -- .../src/backend/index/search/knn.rs | 122 +---- diskann-benchmark/src/inputs/async_.rs | 36 -- diskann-benchmark/src/inputs/disk.rs | 29 -- .../src/search/provider/disk_provider.rs | 177 ------- .../determinant_diversity_post_process.rs | 484 ------------------ .../provider/async_/inmem/full_precision.rs | 62 --- .../src/model/graph/provider/async_/mod.rs | 5 - example.rs | 40 ++ post_process_design_sketch.rs | 418 +++++++++++++++ 16 files changed, 467 insertions(+), 1307 deletions(-) create mode 100644 copilot.txt delete mode 100644 diskann-benchmark-core/src/search/graph/determinant_diversity.rs delete mode 100644 diskann-benchmark/example/wikipedia_compare_detdiv.json delete mode 100644 diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs create mode 100644 example.rs create mode 100644 post_process_design_sketch.rs diff --git a/copilot.txt b/copilot.txt new file mode 100644 index 000000000..722f174ed --- /dev/null +++ b/copilot.txt @@ -0,0 +1 @@ +copilot --resume=c5407796-a927-4cca-9c30-d450692d150a diff --git a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs deleted file mode 100644 index 2f413a117..000000000 --- a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::sync::Arc; - -use diskann::{ - ANNResult, - graph::{self, glue}, - provider, -}; -use diskann_benchmark_runner::utils::{MicroSeconds, percentiles}; -use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; -use diskann_utils::{future::AsyncFriendly, views::Matrix}; - -use crate::{ - recall, - search::{self, Search, graph::Strategy}, - utils, -}; - -#[derive(Debug, Clone, Copy)] -pub struct Parameters { - pub inner: graph::search::Knn, - pub processor: DeterminantDiversitySearchParams, -} - -/// A built-in helper for benchmarking determinant-diversity K-nearest neighbors. -#[derive(Debug)] -pub struct KNN -where - DP: provider::DataProvider, -{ - index: Arc>, - queries: Arc>, - strategy: Strategy, -} - -impl KNN -where - DP: provider::DataProvider, -{ - pub fn new( - index: Arc>, - queries: Arc>, - strategy: Strategy, - ) -> anyhow::Result> { - strategy.length_compatible(queries.nrows())?; - - Ok(Arc::new(Self { - index, - queries, - strategy, - })) - } -} - -impl Search for KNN -where - DP: provider::DataProvider, - S: glue::DefaultSearchStrategy - + glue::PostProcess - + Clone - + AsyncFriendly, - T: AsyncFriendly + Clone, -{ - type Id = DP::ExternalId; - type Parameters = Parameters; - type Output = super::knn::Metrics; - - fn num_queries(&self) -> usize { - self.queries.nrows() - } - - fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Fixed(parameters.inner.k_value()) - } - - async fn search( - &self, - parameters: &Self::Parameters, - buffer: &mut O, - index: usize, - ) -> ANNResult - where - O: graph::SearchOutputBuffer + Send, - { - let context = DP::Context::default(); - let stats = self - .index - .search_with( - parameters.inner, - self.strategy.get(index)?, - parameters.processor, - &context, - self.queries.row(index), - buffer, - ) - .await?; - - Ok(super::knn::Metrics { - comparisons: stats.cmps, - hops: stats.hops, - }) - } -} - -/// Summary for determinant-diversity KNN runs. -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct Summary { - pub setup: search::Setup, - pub parameters: Parameters, - pub end_to_end_latencies: Vec, - pub mean_latencies: Vec, - pub p90_latencies: Vec, - pub p99_latencies: Vec, - pub recall: recall::RecallMetrics, - pub mean_cmps: f64, - pub mean_hops: f64, -} - -pub struct Aggregator<'a, I> { - groundtruth: &'a dyn crate::recall::Rows, - recall_k: usize, - recall_n: usize, -} - -impl<'a, I> Aggregator<'a, I> { - pub fn new( - groundtruth: &'a dyn crate::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> Self { - Self { - groundtruth, - recall_k, - recall_n, - } - } -} - -impl search::Aggregate for Aggregator<'_, I> -where - I: crate::recall::RecallCompatible, -{ - type Output = Summary; - - fn aggregate( - &mut self, - run: search::Run, - mut results: Vec>, - ) -> anyhow::Result

{ - let recall = match results.first() { - Some(first) => crate::recall::knn( - self.groundtruth, - None, - first.ids().as_rows(), - self.recall_k, - self.recall_n, - true, - )?, - None => anyhow::bail!("Results must be non-empty"), - }; - - let mut mean_latencies = Vec::with_capacity(results.len()); - let mut p90_latencies = Vec::with_capacity(results.len()); - let mut p99_latencies = Vec::with_capacity(results.len()); - - results.iter_mut().for_each(|r| { - match percentiles::compute_percentiles(r.latencies_mut()) { - Ok(values) => { - let percentiles::Percentiles { mean, p90, p99, .. } = values; - mean_latencies.push(mean); - p90_latencies.push(p90); - p99_latencies.push(p99); - } - Err(_) => { - let zero = MicroSeconds::new(0); - mean_latencies.push(0.0); - p90_latencies.push(zero); - p99_latencies.push(zero); - } - } - }); - - Ok(Summary { - setup: run.setup().clone(), - parameters: *run.parameters(), - end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(), - recall, - mean_latencies, - p90_latencies, - p99_latencies, - mean_cmps: utils::average_all( - results - .iter() - .flat_map(|r| r.output().iter().map(|o| o.comparisons)), - ), - mean_hops: utils::average_all( - results - .iter() - .flat_map(|r| r.output().iter().map(|o| o.hops)), - ), - }) - } -} diff --git a/diskann-benchmark-core/src/search/graph/mod.rs b/diskann-benchmark-core/src/search/graph/mod.rs index cfcecb0db..eddb4fbcf 100644 --- a/diskann-benchmark-core/src/search/graph/mod.rs +++ b/diskann-benchmark-core/src/search/graph/mod.rs @@ -3,7 +3,6 @@ * Licensed under the MIT license. */ -pub mod determinant_diversity; pub mod knn; pub mod multihop; pub mod range; diff --git a/diskann-benchmark/example/wikipedia_compare_detdiv.json b/diskann-benchmark/example/wikipedia_compare_detdiv.json deleted file mode 100644 index 3e4ffd150..000000000 --- a/diskann-benchmark/example/wikipedia_compare_detdiv.json +++ /dev/null @@ -1,60 +0,0 @@ -{ - "search_directories": [ - "C:/wikipedia_dataset" - ], - "jobs": [ - { - "type": "async-index-build", - "content": { - "source": { - "index-source": "Load", - "data_type": "float32", - "distance": "squared_l2", - "load_path": "C:/wikipedia_dataset/wikipedia_saved_index" - }, - "search_phase": { - "search-type": "topk", - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "reps": 1, - "num_threads": [8], - "runs": [ - { - "search_n": 10, - "search_l": [20, 30, 40, 50, 100, 200], - "recall_k": 10 - } - ] - } - } - }, - { - "type": "async-index-build", - "content": { - "source": { - "index-source": "Load", - "data_type": "float32", - "distance": "squared_l2", - "load_path": "C:/wikipedia_dataset/wikipedia_saved_index" - }, - "search_phase": { - "search-type": "topk", - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "reps": 1, - "determinant_diversity_eta": 0.01, - "determinant_diversity_power": 1.0, - "determinant_diversity_results_k": 10, - "num_threads": [8], - "runs": [ - { - "search_n": 10, - "search_l": [20, 30, 40, 50, 100, 200], - "recall_k": 10 - } - ] - } - } - } - ] -} diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 8d908dd74..ce51366e7 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -19,7 +19,6 @@ use diskann_disk::{ storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, }; -use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ storage::{ @@ -270,41 +269,14 @@ where as Box bool + Send + Sync>) }; - let search_result = if let (Some(eta), Some(power)) = ( - search_params.determinant_diversity_eta, - search_params.determinant_diversity_power, - ) { - match DeterminantDiversitySearchParams::new( - search_params - .determinant_diversity_results_k - .unwrap_or(search_params.recall_at as usize), - eta, - power, - ) { - Ok(processor) => searcher.search_determinant_diversity( - q, - search_params.recall_at, - l, - Some(search_params.beam_width), - vector_filter, - search_params.is_flat_search, - processor, - ), - Err(e) => Err(diskann::ANNError::log_index_error(format!( - "Invalid determinant-diversity parameters: {}", - e - ))), - } - } else { - searcher.search( + let search_result = searcher.search( q, search_params.recall_at, l, Some(search_params.beam_width), vector_filter, search_params.is_flat_search, - ) - }; + ); match search_result { Ok(search_result) => { diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 07a0b942a..557e7594e 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -22,7 +22,6 @@ use diskann_benchmark_runner::{ utils::datatype, Any, Checkpoint, }; -use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use diskann_providers::{ index::diskann_async, model::{configuration::IndexConfiguration, graph::provider::async_::common}, @@ -497,67 +496,6 @@ where } } -pub(super) fn run_search_outer_full_precision( - input: &SearchPhase, - search_strategy: S, - index: Index, - build_stats: Option, - checkpoint: Checkpoint<'_>, -) -> anyhow::Result -where - DP: DataProvider - + provider::SetElement<[T]>, - T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, - S: glue::SearchStrategy - + glue::DelegateDefaultPostProcessor - + glue::PostProcess - + Clone - + AsyncFriendly, -{ - if let SearchPhase::Topk(search_phase) = input { - if let (Some(eta), Some(power)) = ( - search_phase.determinant_diversity_eta, - search_phase.determinant_diversity_power, - ) { - let mut result = BuildResult::new_topk(build_stats); - checkpoint.checkpoint(&result)?; - - let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile( - &search_phase.queries, - ))?); - - let groundtruth = - datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - - let knn = benchmark_core::search::graph::determinant_diversity::KNN::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - )?; - - let steps = search::knn::SearchSteps::new( - search_phase.reps, - &search_phase.num_threads, - &search_phase.runs, - ); - - let search_results = search::knn::run_determinant_diversity( - &knn, - &groundtruth, - steps, - eta, - power, - search_phase.determinant_diversity_results_k, - )?; - - result.append(AggregatedSearchResults::Topk(search_results)); - return Ok(result); - } - } - - run_search_outer(input, search_strategy, index, build_stats, checkpoint) -} - macro_rules! impl_build { ($T:ty) => { impl<'a> BuildAndSearch<'a> for FullPrecision<'a, $T> { @@ -609,7 +547,7 @@ macro_rules! impl_build { } }; - let result = run_search_outer_full_precision( + let result = run_search_outer( &self.input.search_phase, common::FullPrecision, index, diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 1f9c2e50a..bcf312832 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -156,41 +156,6 @@ impl SearchResults { } } - pub fn new_determinant_diversity( - summary: benchmark_core::search::graph::determinant_diversity::Summary, - ) -> Self { - let benchmark_core::search::graph::determinant_diversity::Summary { - setup, - parameters, - end_to_end_latencies, - mean_latencies, - p90_latencies, - p99_latencies, - recall, - mean_cmps, - mean_hops, - .. - } = summary; - - let qps = end_to_end_latencies - .iter() - .map(|latency| recall.num_queries as f64 / latency.as_seconds()) - .collect(); - - Self { - num_tasks: setup.tasks.into(), - search_n: parameters.inner.k_value().get(), - search_l: parameters.inner.l_value().get(), - qps, - search_latencies: end_to_end_latencies, - mean_latencies, - p90_latencies, - p99_latencies, - recall: (&recall).into(), - mean_cmps: mean_cmps as f32, - mean_hops: mean_hops as f32, - } - } } fn format_search_results_table( diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index a01318481..368e5f788 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -41,32 +41,6 @@ pub(crate) fn run( }) } -pub(crate) fn run_determinant_diversity( - runner: &dyn DeterminantDiversityKnn, - groundtruth: &dyn benchmark_core::recall::Rows, - steps: SearchSteps<'_>, - eta: f64, - power: f64, - results_k: Option, -) -> anyhow::Result> { - run_search_determinant_diversity(runner, groundtruth, steps, |setup, search_l, search_n| { - let base = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); - let processor = - diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams::new( - results_k.unwrap_or(search_n), - eta, - power, - ).map_err(|e| anyhow::anyhow!("Invalid determinant-diversity parameters: {}", e))?; - - let search_params = - diskann_benchmark_core::search::graph::determinant_diversity::Parameters { - inner: base, - processor, - }; - Ok(core_search::Run::new(search_params, setup)) - }) -} - type Run = core_search::Run; pub(crate) trait Knn { fn search_all( @@ -78,14 +52,15 @@ pub(crate) trait Knn { ) -> anyhow::Result>; } -type DeterminantRun = - core_search::Run; -/// Generic search infrastructure that unifies `run()` and `run_determinant_diversity()`. +/////////// +// Impls // +/////////// + +/// Generic search infrastructure. /// /// This helper extracts the common loop logic (iterating over threads and runs, /// and building a setup) leaving parameter construction to a builder closure. -/// This collapses the benchmark helper infrastructure and reduces duplication. fn run_search( runner: &dyn Knn, groundtruth: &dyn benchmark_core::recall::Rows, @@ -118,57 +93,6 @@ where Ok(all) } -/// Generic search infrastructure for determinant-diversity searches. -/// -/// Mirrors the unified logic of `run_search()` but for the DeterminantDiversityKnn trait. -fn run_search_determinant_diversity( - runner: &dyn DeterminantDiversityKnn, - groundtruth: &dyn benchmark_core::recall::Rows, - steps: SearchSteps<'_>, - builder: F, -) -> anyhow::Result> -where - F: Fn( - core_search::Setup, - usize, - usize, - ) -> anyhow::Result< - core_search::Run, - >, -{ - let mut all = Vec::new(); - - for threads in steps.num_tasks.iter() { - for run in steps.runs.iter() { - let setup = core_search::Setup { - threads: *threads, - tasks: *threads, - reps: steps.reps, - }; - - let parameters: Vec<_> = run - .search_l - .iter() - .map(|&search_l| builder(setup.clone(), search_l, run.search_n)) - .collect::>>()?; - - all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); - } - } - - Ok(all) -} - -pub(crate) trait DeterminantDiversityKnn { - fn search_all( - &self, - parameters: Vec, - groundtruth: &dyn benchmark_core::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> anyhow::Result>; -} - /////////// // Impls // /////////// @@ -225,40 +149,4 @@ where } } -impl DeterminantDiversityKnn - for Arc> -where - DP: diskann::provider::DataProvider, - core_search::graph::determinant_diversity::KNN: core_search::Search< - Id = DP::InternalId, - Parameters = diskann_benchmark_core::search::graph::determinant_diversity::Parameters, - Output = core_search::graph::knn::Metrics, - >, -{ - fn search_all( - &self, - parameters: Vec< - core_search::Run< - diskann_benchmark_core::search::graph::determinant_diversity::Parameters, - >, - >, - groundtruth: &dyn benchmark_core::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> anyhow::Result> { - let results = core_search::search_all( - self.clone(), - parameters.into_iter(), - core_search::graph::determinant_diversity::Aggregator::new( - groundtruth, - recall_k, - recall_n, - ), - )?; - Ok(results - .into_iter() - .map(SearchResults::new_determinant_diversity) - .collect()) - } -} diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 0f8026b58..19230977d 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -123,9 +123,6 @@ pub(crate) struct TopkSearchPhase { pub(crate) queries: InputFile, pub(crate) groundtruth: InputFile, pub(crate) reps: NonZeroUsize, - pub(crate) determinant_diversity_eta: Option, - pub(crate) determinant_diversity_power: Option, - pub(crate) determinant_diversity_results_k: Option, // Enable sweeping threads pub(crate) num_threads: Vec, pub(crate) runs: Vec, @@ -142,36 +139,6 @@ impl CheckDeserialization for TopkSearchPhase { .with_context(|| format!("search run {}", i))?; } - if self.determinant_diversity_eta.is_some() != self.determinant_diversity_power.is_some() { - return Err(anyhow!( - "determinant_diversity_eta and determinant_diversity_power must either both be set or both be omitted" - )); - } - - if let Some(eta) = self.determinant_diversity_eta { - if eta < 0.0 { - return Err(anyhow!( - "determinant_diversity_eta must be >= 0.0, got {}", - eta - )); - } - } - - if let Some(power) = self.determinant_diversity_power { - if power < 0.0 { - return Err(anyhow!( - "determinant_diversity_power must be >= 0.0, got {}", - power - )); - } - } - - if let Some(k) = self.determinant_diversity_results_k { - if k == 0 { - return Err(anyhow!("determinant_diversity_results_k must be > 0")); - } - } - Ok(()) } } @@ -197,9 +164,6 @@ impl Example for TopkSearchPhase { queries: InputFile::new("path/to/queries"), groundtruth: InputFile::new("path/to/groundtruth"), reps: REPS, - determinant_diversity_eta: None, - determinant_diversity_power: None, - determinant_diversity_results_k: None, num_threads: THREAD_COUNTS.to_vec(), runs, } diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 22376a648..0572f99f6 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -85,9 +85,6 @@ pub(crate) struct DiskSearchPhase { pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, pub(crate) search_io_limit: Option, - pub(crate) determinant_diversity_eta: Option, - pub(crate) determinant_diversity_power: Option, - pub(crate) determinant_diversity_results_k: Option, } ///////// @@ -238,29 +235,6 @@ impl CheckDeserialization for DiskSearchPhase { } } - if self.determinant_diversity_eta.is_some() != self.determinant_diversity_power.is_some() { - anyhow::bail!( - "determinant_diversity_eta and determinant_diversity_power must either both be set or both omitted" - ); - } - - if let Some(eta) = self.determinant_diversity_eta { - if eta < 0.0 { - anyhow::bail!("determinant_diversity_eta must be >= 0.0"); - } - } - - if let Some(power) = self.determinant_diversity_power { - if power < 0.0 { - anyhow::bail!("determinant_diversity_power must be >= 0.0"); - } - } - - if let Some(k) = self.determinant_diversity_results_k { - if k == 0 { - anyhow::bail!("determinant_diversity_results_k must be > 0"); - } - } Ok(()) } } @@ -299,9 +273,6 @@ impl Example for DiskIndexOperation { vector_filters_file: None, num_nodes_to_cache: None, search_io_limit: None, - determinant_diversity_eta: None, - determinant_diversity_power: None, - determinant_diversity_results_k: None, }; Self { diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 6d7a8a928..4fe2a7f0b 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -39,9 +39,6 @@ use diskann::{ }; use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ - model::graph::provider::async_::{ - determinant_diversity_post_process, DeterminantDiversitySearchParams, - }, model::{ compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType, pq::quantizer_preprocess, PQData, PQScratch, @@ -394,93 +391,6 @@ where } } -impl<'this, Data, ProviderFactory> - PostProcess< - DiskProvider, - [Data::VectorDataType], - DeterminantDiversitySearchParams, - ( - as DataProvider>::InternalId, - Data::AssociatedDataType, - ), - > for DiskSearchStrategy<'this, Data, ProviderFactory> -where - Data: GraphDataType, - ProviderFactory: VertexProviderFactory, -{ - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: DeterminantDiversitySearchParams, - accessor: &mut Self::SearchAccessor<'a>, - query: &[Data::VectorDataType], - _computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator as DataProvider>::InternalId>> + Send, - B: SearchOutputBuffer<( - as DataProvider>::InternalId, - Data::AssociatedDataType, - )> + Send - + ?Sized, - { - async move { - let provider = accessor.provider; - let query_f32 = Data::VectorDataType::as_f32(query) - .into_ann_result()? - .to_vec(); - - let filtered_ids: Vec = candidates - .map(|n| n.id) - .filter(|id| (self.vector_filter)(id)) - .collect(); - - if filtered_ids.is_empty() { - return Ok(0); - } - - ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &filtered_ids)?; - - let mut enriched: Vec<(u32, f32, Vec, Data::AssociatedDataType)> = Vec::new(); - for id in filtered_ids { - let vector = accessor.scratch.vertex_provider.get_vector(&id)?; - let vector_f32 = Data::VectorDataType::as_f32(vector) - .into_ann_result()? - .to_vec(); - let distance = provider - .distance_comparer - .evaluate_similarity(query, vector); - let assoc = *accessor.scratch.vertex_provider.get_associated_data(&id)?; - enriched.push((id, distance, vector_f32, assoc)); - } - - let borrowed: Vec<(u32, f32, &[f32])> = enriched - .iter() - .map(|(id, dist, vector, _)| (*id, *dist, vector.as_slice())) - .collect(); - - let reranked = determinant_diversity_post_process( - borrowed, - &query_f32, - processor.top_k, - processor.determinant_diversity_eta, - processor.determinant_diversity_power, - ); - - let mut pairs = Vec::with_capacity(reranked.len()); - for (id, distance) in reranked { - if let Some((_, _, _, assoc)) = enriched.iter().find(|(eid, _, _, _)| *eid == id) { - pairs.push(((id, *assoc), distance)); - } - } - - Ok(output.extend(pairs)) - } - } -} - /// The query computer for the disk provider. This is used to compute the distance between the query vector and the PQ coordinates. pub struct DiskQueryComputer { num_pq_chunks: usize, @@ -1069,93 +979,6 @@ where Ok(search_result) } - /// Perform a determinant-diversity search on the disk index. - #[allow(clippy::too_many_arguments)] - pub fn search_determinant_diversity( - &self, - query: &[Data::VectorDataType], - return_list_size: u32, - search_list_size: u32, - beam_width: Option, - vector_filter: Option>, - is_flat_search: bool, - processor: DeterminantDiversitySearchParams, - ) -> ANNResult> { - let mut query_stats = QueryStatistics::default(); - let mut indices = vec![0u32; return_list_size as usize]; - let mut distances = vec![0f32; return_list_size as usize]; - let mut associated_data = - vec![Data::AssociatedDataType::default(); return_list_size as usize]; - - let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( - &mut indices, - &mut distances, - &mut associated_data, - ); - - let vector_filter = vector_filter.unwrap_or(default_vector_filter::()); - let strategy = self.search_strategy(query, &*vector_filter); - let timer = Instant::now(); - let k = return_list_size as usize; - let l = search_list_size as usize; - - let stats = if is_flat_search { - self.runtime.block_on(self.index.flat_search( - &strategy, - &DefaultContext, - strategy.query, - strategy.vector_filter, - &Knn::new(k, l, beam_width)?, - &mut result_output_buffer, - ))? - } else { - let knn_search = Knn::new(k, l, beam_width)?; - self.runtime.block_on(self.index.search_with( - knn_search, - &strategy, - processor, - &DefaultContext, - strategy.query, - &mut result_output_buffer, - ))? - }; - - query_stats.total_comparisons = stats.cmps; - query_stats.search_hops = stats.hops; - query_stats.total_execution_time_us = timer.elapsed().as_micros(); - query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128; - query_stats.total_io_operations = strategy.io_tracker.io_count() as u32; - query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32; - query_stats.query_pq_preprocess_time_us = - IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128; - query_stats.cpu_time_us = query_stats.total_execution_time_us - - query_stats.io_time_us - - query_stats.query_pq_preprocess_time_us; - - let mut search_result = SearchResult { - results: Vec::with_capacity(return_list_size as usize), - stats: SearchResultStats { - cmps: query_stats.total_comparisons, - result_count: stats.result_count, - query_statistics: query_stats.clone(), - }, - }; - - for ((vertex_id, distance), associated_data) in indices - .into_iter() - .zip(distances.into_iter()) - .zip(associated_data.into_iter()) - { - search_result.results.push(SearchResultItem { - vertex_id, - distance, - data: associated_data, - }); - } - - Ok(search_result) - } - /// Perform a raw search on the disk index. /// This is a lower-level API that allows more control over the search parameters and output buffers. #[allow(clippy::too_many_arguments)] diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs deleted file mode 100644 index 56b663d13..000000000 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ /dev/null @@ -1,484 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -//! Determinant-diversity search post-processing. -//! -//! This module provides post-processing functionality for determinant-diversity search, -//! which reranks search results to maximize diversity using a greedy -//! orthogonalization algorithm. - -use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; - -/// Error type for determinant-diversity parameter validation. -#[derive(Debug)] -pub enum DeterminantDiversityError { - InvalidTopK { top_k: usize }, - InvalidEta { eta: f64 }, - InvalidPower { power: f64 }, -} - -impl std::fmt::Display for DeterminantDiversityError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::InvalidTopK { top_k } => { - write!(f, "top_k must be > 0, got {}", top_k) - } - Self::InvalidEta { eta } => { - write!(f, "eta must be >= 0.0, got {}", eta) - } - Self::InvalidPower { power } => { - write!(f, "power must be > 0.0, got {}", power) - } - } - } -} - -impl std::error::Error for DeterminantDiversityError {} - -/// Parameters for determinant-diversity reranking. -/// -/// # Invariants -/// -/// - `top_k > 0`: Must request at least one result -/// - `determinant_diversity_eta >= 0.0`: Ridge regularization parameter (0 = no ridge) -/// - `determinant_diversity_power > 0.0`: Exponent for diversity scaling (typically 1.0-2.0) -#[derive(Debug, Clone, Copy)] -pub struct DeterminantDiversitySearchParams { - pub top_k: usize, - pub determinant_diversity_eta: f64, - pub determinant_diversity_power: f64, -} - -impl DeterminantDiversitySearchParams { - /// Construct parameters with validation. - /// - /// # Arguments - /// - /// * `top_k` - Number of results to return (must be > 0) - /// * `determinant_diversity_eta` - Ridge regularization parameter (must be >= 0.0) - /// * `determinant_diversity_power` - Diversity exponent (must be > 0.0) - /// - /// # Errors - /// - /// Returns [`DeterminantDiversityError`] if any parameter is invalid. - pub fn new( - top_k: usize, - determinant_diversity_eta: f64, - determinant_diversity_power: f64, - ) -> Result { - if top_k == 0 { - return Err(DeterminantDiversityError::InvalidTopK { top_k }); - } - - if determinant_diversity_eta < 0.0 || !determinant_diversity_eta.is_finite() { - return Err(DeterminantDiversityError::InvalidEta { - eta: determinant_diversity_eta, - }); - } - - if determinant_diversity_power <= 0.0 || !determinant_diversity_power.is_finite() { - return Err(DeterminantDiversityError::InvalidPower { - power: determinant_diversity_power, - }); - } - - Ok(Self { - top_k, - determinant_diversity_eta, - determinant_diversity_power, - }) - } -} - -/// Post-process search results using determinant-diversity reranking. -/// -/// If `determinant_diversity_eta > 0.0`, uses a ridge-aware variant. -/// Otherwise, uses greedy orthogonalization. -pub fn determinant_diversity_post_process( - candidates: Vec<(Id, f32, &[f32])>, - query: &[f32], - k: usize, - determinant_diversity_eta: f64, - determinant_diversity_power: f64, -) -> Vec<(Id, f32)> { - if candidates.is_empty() || query.is_empty() { - return Vec::new(); - } - - let k = k.min(candidates.len()); - if k == 0 { - return Vec::new(); - } - - // Convert vectors to owned format only once - let candidates_f32: Vec<(Id, f32, Vec)> = candidates - .into_iter() - .map(|(id, dist, v)| (id, dist, v.to_vec())) - .collect(); - - if candidates_f32[0].2.is_empty() { - return Vec::new(); - } - - let results = if determinant_diversity_eta > 0.0 { - post_process_with_eta_f32( - candidates_f32, - query, - k, - determinant_diversity_eta, - determinant_diversity_power, - ) - } else { - post_process_greedy_orthogonalization_f32( - candidates_f32, - query, - k, - determinant_diversity_power, - ) - }; - - debug_assert_eq!( - results.len(), - k, - "determinant-diversity post-process should return exactly k={} results, got {}", - k, - results.len() - ); - - results -} - -fn post_process_with_eta_f32( - candidates: Vec<(Id, f32, Vec)>, - query: &[f32], - k: usize, - determinant_diversity_eta: f64, - determinant_diversity_power: f64, -) -> Vec<(Id, f32)> { - let eta = determinant_diversity_eta as f32; - let power = determinant_diversity_power; - - if candidates.is_empty() || query.is_empty() { - return Vec::new(); - } - - let n = candidates.len(); - let k = k.min(n); - - if k == 0 { - return Vec::new(); - } - - let d = candidates[0].2.len(); - if d == 0 { - return Vec::new(); - } - - let inv_sqrt_eta = 1.0 / eta.sqrt(); - - let mut residuals: Vec> = Vec::with_capacity(n); - let mut norms_sq: Vec = Vec::with_capacity(n); - - // Initialize residuals and norms (only one allocation per candidate) - for (_, _, v) in &candidates { - let similarity = dot_product(v, query); - let scale = similarity.max(0.0).powf(power as f32) * inv_sqrt_eta; - let r: Vec = v.iter().map(|&x| x * scale).collect(); - let s = dot_product(&r, &r); - residuals.push(r); - norms_sq.push(s); - } - - let mut available: Vec = vec![true; n]; - let mut selected: Vec = Vec::with_capacity(k); - - for _ in 0..k { - let best_idx = available - .iter() - .enumerate() - .filter(|&(_, &avail)| avail) - .max_by(|(i, _), (j, _)| { - norms_sq[*i] - .partial_cmp(&norms_sq[*j]) - .unwrap_or(std::cmp::Ordering::Equal) - }) - .map(|(i, _)| i); - - let Some(j) = best_idx else { - break; - }; - - selected.push(j); - available[j] = false; - - if selected.len() == k { - break; - } - - let norm_factor = 1.0 / (1.0 + norms_sq[j]).sqrt(); - - // Compute all projections first to avoid needing to clone residuals[j] - let mut projections: Vec = Vec::with_capacity(n); - for i in 0..n { - if !available[i] { - projections.push(0.0); - } else { - let alpha = dot_product(&residuals[j], &residuals[i]) * norm_factor * norm_factor; - projections.push(alpha); - } - } - - // Now apply all updates using the precomputed projections - let q_scaled: Vec = residuals[j].iter().map(|&x| x * norm_factor).collect(); - for i in 0..n { - if !available[i] { - continue; - } - - let alpha = projections[i]; - for (r_val, &q_val) in residuals[i].iter_mut().zip(q_scaled.iter()) { - *r_val -= alpha * q_val; - } - - norms_sq[i] = (norms_sq[i] - alpha * alpha).max(0.0); - } - } - - selected - .iter() - .map(|&idx| { - let (id, dist, _) = candidates[idx]; - (id, dist) - }) - .collect() -} - -fn post_process_greedy_orthogonalization_f32( - candidates: Vec<(Id, f32, Vec)>, - query: &[f32], - k: usize, - determinant_diversity_power: f64, -) -> Vec<(Id, f32)> { - let power = determinant_diversity_power; - - if candidates.is_empty() || query.is_empty() { - return Vec::new(); - } - - let n = candidates.len(); - let k = k.min(n); - - if k == 0 { - return Vec::new(); - } - - let mut residuals: Vec> = Vec::with_capacity(n); - let mut norms_sq: Vec = Vec::with_capacity(n); - - // Initialize residuals and norms (only one allocation per candidate) - for (_, _, v) in &candidates { - let similarity = dot_product(v, query); - let scale = similarity.max(0.0).powf(power as f32); - let r: Vec = v.iter().map(|&x| x * scale).collect(); - let s = dot_product(&r, &r); - residuals.push(r); - norms_sq.push(s); - } - - let mut available: Vec = vec![true; n]; - let mut selected: Vec = Vec::with_capacity(k); - - for _ in 0..k { - let best = available - .iter() - .enumerate() - .filter(|&(_, &avail)| avail) - .max_by(|(i, _), (j, _)| { - norms_sq[*i] - .partial_cmp(&norms_sq[*j]) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - let Some((i_star, _)) = best else { - break; - }; - - let best_norm_sq = norms_sq[i_star]; - selected.push(i_star); - available[i_star] = false; - - if selected.len() == k { - break; - } - - if best_norm_sq <= 0.0 { - continue; - } - - let inv_norm_sq_star = 1.0 / best_norm_sq; - - // Compute all projections and make a copy of r_star to avoid borrow conflicts - let r_star_copy = residuals[i_star].clone(); - let mut projections: Vec = Vec::with_capacity(n); - for j in 0..n { - if !available[j] { - projections.push(0.0); - } else { - let proj = dot_product(&residuals[j], &r_star_copy) * inv_norm_sq_star; - projections.push(proj); - } - } - - // Now apply all updates using the precomputed projections - for j in 0..n { - if !available[j] { - continue; - } - - let proj_coeff = projections[j]; - for (r_val, &rs_val) in residuals[j].iter_mut().zip(r_star_copy.iter()) { - *r_val -= proj_coeff * rs_val; - } - - norms_sq[j] = (norms_sq[j] - proj_coeff * proj_coeff * best_norm_sq).max(0.0); - } - } - - selected - .iter() - .map(|&idx| { - let (id, dist, _) = candidates[idx]; - (id, dist) - }) - .collect() -} - -#[inline] -fn dot_product(a: &[f32], b: &[f32]) -> f32 { - >>::evaluate(a, b) - .into_inner() -} - -#[cfg(test)] -mod tests { - use super::*; - - // ===== Validation Tests ===== - - #[test] - fn test_validation_valid_params() { - let result = DeterminantDiversitySearchParams::new(10, 0.01, 2.0); - assert!(result.is_ok()); - } - - #[test] - fn test_validation_zero_top_k() { - let result = DeterminantDiversitySearchParams::new(0, 0.01, 2.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidTopK { top_k: 0 }) - )); - } - - #[test] - fn test_validation_negative_eta() { - let result = DeterminantDiversitySearchParams::new(10, -0.01, 2.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidEta { .. }) - )); - } - - #[test] - fn test_validation_zero_power() { - let result = DeterminantDiversitySearchParams::new(10, 0.01, 0.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidPower { .. }) - )); - } - - #[test] - fn test_validation_negative_power() { - let result = DeterminantDiversitySearchParams::new(10, 0.01, -1.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidPower { .. }) - )); - } - - #[test] - fn test_validation_nan_eta() { - let result = DeterminantDiversitySearchParams::new(10, f64::NAN, 2.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidEta { .. }) - )); - } - - #[test] - fn test_validation_infinity_power() { - let result = DeterminantDiversitySearchParams::new(10, 0.01, f64::INFINITY); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidPower { .. }) - )); - } - - // ===== Algorithm Tests ===== - - #[test] - fn test_determinant_diversity_post_process_with_eta() { - let v1 = vec![1.0f32, 0.0, 0.0]; - let v2 = vec![0.0f32, 1.0, 0.0]; - let v3 = vec![0.0f32, 0.0, 1.0]; - let candidates = vec![ - (1u32, 0.5f32, v1.as_slice()), - (2u32, 0.3f32, v2.as_slice()), - (3u32, 0.7f32, v3.as_slice()), - ]; - let query = vec![1.0, 1.0, 1.0]; - - let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); - assert_eq!(result.len(), 3); - } - - #[test] - fn test_determinant_diversity_post_process_enabled_greedy() { - let v1 = vec![1.0f32, 0.0, 0.0]; - let v2 = vec![0.99f32, 0.1, 0.0]; - let v3 = vec![0.0f32, 1.0, 0.0]; - let candidates = vec![ - (1u32, 0.5f32, v1.as_slice()), - (2u32, 0.3f32, v2.as_slice()), - (3u32, 0.4f32, v3.as_slice()), - ]; - let query = vec![1.0, 1.0, 0.0]; - - let result = determinant_diversity_post_process(candidates, &query, 2, 0.0, 1.0); - assert_eq!(result.len(), 2); - } - - #[test] - fn test_determinant_diversity_post_process_empty() { - let candidates: Vec<(u32, f32, &[f32])> = vec![]; - let query = vec![1.0, 1.0, 1.0]; - - let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); - assert!(result.is_empty()); - } - - #[test] - fn test_determinant_diversity_post_process_k_larger_than_candidates() { - let v1 = vec![1.0f32, 0.0, 0.0]; - let v2 = vec![0.0f32, 1.0, 0.0]; - let candidates = vec![(1u32, 0.5f32, v1.as_slice()), (2u32, 0.3f32, v2.as_slice())]; - let query = vec![1.0, 1.0, 1.0]; - - let result = determinant_diversity_post_process(candidates, &query, 10, 0.01, 2.0); - // Should return min(k, len(candidates)) = 2 - assert_eq!(result.len(), 2); - } -} diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index f697ee12f..baf2a6df2 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -24,9 +24,6 @@ use diskann::{ utils::{IntoUsize, VectorRepr}, }; -use super::super::determinant_diversity_post_process::{ - DeterminantDiversitySearchParams, determinant_diversity_post_process, -}; use diskann_utils::future::AsyncFriendly; use diskann_vector::{DistanceFunction, distance::Metric}; @@ -526,65 +523,6 @@ where } } -impl - PostProcess, [T], DeterminantDiversitySearchParams> - for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: DeterminantDiversitySearchParams, - accessor: &mut Self::SearchAccessor<'a>, - query: &[T], - _computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - let query_f32 = T::as_f32(query).into_ann_result()?.to_vec(); - let mut candidates_with_vectors = Vec::new(); - - for candidate in candidates { - if accessor.provider.deleted.deletion_check(candidate.id) { - continue; - } - - let vector = accessor.get_element(candidate.id).await.into_ann_result()?; - let vector_f32 = T::as_f32(vector).into_ann_result()?; - candidates_with_vectors.push(( - candidate.id, - candidate.distance, - vector_f32.to_vec(), - )); - } - - let borrowed: Vec<(u32, f32, &[f32])> = candidates_with_vectors - .iter() - .map(|(id, distance, vector)| (*id, *distance, vector.as_slice())) - .collect(); - - let reranked = determinant_diversity_post_process( - borrowed, - &query_f32, - processor.top_k, - processor.determinant_diversity_eta, - processor.determinant_diversity_power, - ); - - Ok(output.extend(reranked)) - } - } -} - // Pruning impl PruneStrategy> for FullPrecision where diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index bdf620d38..4e703cf33 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -44,8 +44,3 @@ pub mod caching; #[cfg(test)] pub mod debug_provider; -// Determinant-diversity post-processing. -pub mod determinant_diversity_post_process; -pub use determinant_diversity_post_process::{ - DeterminantDiversitySearchParams, determinant_diversity_post_process, -}; diff --git a/example.rs b/example.rs new file mode 100644 index 000000000..dccc80441 --- /dev/null +++ b/example.rs @@ -0,0 +1,40 @@ +// Default post-process +#[derive(Debug, Clone, Copy)] +pub struct DefaultPostProcess; + +pub trait DelegatePostProcess { + type Delegate: DoesThings; +} + +impl SearchPostProcess for DefaultPostProcess +where + T: DelegatePostProcess +{ + fn post_process(args...) { + T::Delegate::post_process(args...) + } +} + +// Apply the default post-process via the normal search API. +fn search( + dispatch: T, + other_args... +) +where + DefaultPostProcess: SearchPostProcess +{ + search_with(dispatch, other_args..., DefaultPostProcess) +} + +// Second API that allows for overriding the post-processor explicitly. +fn search_with( + dispatch: T, + other_args... + post_process: P +) +where + P: SearchPostProcess +{ + // Do the thing. The `Search` trait will always take a post-processor. +} + diff --git a/post_process_design_sketch.rs b/post_process_design_sketch.rs new file mode 100644 index 000000000..cc1bceff0 --- /dev/null +++ b/post_process_design_sketch.rs @@ -0,0 +1,418 @@ +// ============================================================================= +// Post-Processing Redesign: Sketch & Rationale +// ============================================================================= +// +// Context +// ------- +// Two competing PRs attempted to refactor how SearchStrategy interacts with +// post-processing. Both had structural problems: +// +// Exhibit-A kept `type PostProcessor` on SearchStrategy and layered a new +// `PostProcess` trait on top. This created two +// parallel "what's the post-processor?" answers on the same type that could +// silently diverge. The GAT associated type became dead weight that every +// implementor still had to fill in. +// +// Exhibit-B removed `PostProcessor` from SearchStrategy (good), but replaced +// it with a `DelegatePostProcess` marker whose blanket impl covered *all* +// processor types `P` at once: +// +// impl PostProcess for S +// where S: SearchStrategy<…> + DelegatePostProcess, +// P: for<'a> SearchPostProcess, T, O> + … +// +// This makes it impossible to override `PostProcess` for a specific `P` +// without opting out of the blanket entirely (removing DelegatePostProcess), +// which then forces manual impls for every processor type — an all-or-nothing +// cliff. It also provided no `KnnWith`-style mechanism for callers to supply +// a custom processor at the search call-site. +// +// Proposed Design +// --------------- +// Flip the blanket. Instead of "strategy S gets PostProcess for all P", +// make it "the DefaultPostProcess ZST gets support for all strategies S +// that opt in via HasDefaultProcessor". +// +// The blanket is narrow (covers exactly one P = DefaultPostProcess), so custom +// PostProcess<…, RagSearchParams, …> impls are coherence-safe. Strategies +// that don't need a default can skip HasDefaultProcessor and still be used via +// KnnWith with an explicit processor. +// +// ============================================================================= +// +// How to read this file +// --------------------- +// This is pseudocode — it won't compile. Signatures use real Rust syntax where +// possible but elide lifetimes, bounds, and async machinery for clarity. +// Comments marked "NOTE" call out places where the real implementation will +// need careful attention to HRTB / GAT interactions. +// +// ============================================================================= + +// --------------------------------------------------------------------------- +// 1. SearchStrategy — clean, no post-processing knowledge +// --------------------------------------------------------------------------- +// +// This is the same as today minus `type PostProcessor` and `fn post_processor`. + +pub trait SearchStrategy::InternalId>: + Send + Sync +where + Provider: DataProvider, + T: ?Sized, + O: Send, +{ + type QueryComputer: /* PreprocessedDistanceFunction bounds */ Send + Sync + 'static; + type SearchAccessorError: StandardError; + + // NOTE: This GAT is the source of most HRTB complexity downstream. + type SearchAccessor<'a>: ExpandBeam + + SearchExt; + + fn search_accessor<'a>( + &'a self, + provider: &'a Provider, + context: &'a Provider::Context, + ) -> Result, Self::SearchAccessorError>; +} + +// --------------------------------------------------------------------------- +// 2. SearchPostProcess — unchanged from today +// --------------------------------------------------------------------------- +// +// Low-level trait, parameterized by the *accessor* (not the strategy). +// CopyIds, Rerank, Pipeline, RemoveDeletedIdsAndCopy, etc. all +// implement this directly. No changes needed here. + +pub trait SearchPostProcess::Id> +where + A: BuildQueryComputer, + T: ?Sized, +{ + type Error: StandardError; + + fn post_process( + &self, + accessor: &mut A, + query: &T, + computer: &>::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized; +} + +// Pipeline, CopyIds, FilterStartPoints, SearchPostProcessStep — all unchanged. + +// --------------------------------------------------------------------------- +// 3. PostProcess — strategy-level bridge, parameterized by processor P +// --------------------------------------------------------------------------- +// +// This trait connects a strategy to a specific processor type. It is the +// surface that the search infrastructure (Knn, KnnWith, RecordedKnn, etc.) +// bounds on. + +pub trait PostProcess::InternalId>: + SearchStrategy +where + Provider: DataProvider, + T: ?Sized, + O: Send, + P: Send + Sync, +{ + fn post_process_with<'a, I, B>( + &self, + processor: &P, + accessor: &mut Self::SearchAccessor<'a>, + query: &T, + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized; +} + +// --------------------------------------------------------------------------- +// 4. HasDefaultProcessor — opt-in "I have a default post-processor" +// --------------------------------------------------------------------------- +// +// Strategies that want to work with Knn (no explicit processor) implement this. +// It replaces the old `type PostProcessor` on SearchStrategy. +// +// NOTE: The `for<'a> SearchPostProcess, T, O>` HRTB +// bound is the same one that lived on SearchStrategy::PostProcessor today. +// It's not new complexity — it just moved here. + +pub trait HasDefaultProcessor::InternalId>: + SearchStrategy +where + Provider: DataProvider, + T: ?Sized, + O: Send, +{ + type Processor: for<'a> SearchPostProcess, T, O> + + Send + + Sync; + + fn create_processor(&self) -> Self::Processor; +} + +// Convenience macro (same idea as exhibit-B's delegate_default_post_process!). +macro_rules! delegate_default_post_process { + ($Processor:ty) => { + type Processor = $Processor; + fn create_processor(&self) -> Self::Processor { + Default::default() + } + }; +} + +// --------------------------------------------------------------------------- +// 5. DefaultPostProcess ZST + THE blanket impl +// --------------------------------------------------------------------------- +// +// KEY DESIGN POINT: The blanket covers exactly P = DefaultPostProcess. +// Custom processor types (RagSearchParams, etc.) are free to have their own +// `impl PostProcess<…, RagSearchParams, …> for MyStrategy` without any +// coherence conflict. + +#[derive(Debug, Default, Clone, Copy)] +pub struct DefaultPostProcess; + +impl PostProcess for S +where + S: HasDefaultProcessor, + Provider: DataProvider, + T: ?Sized + Sync, + O: Send, +{ + async fn post_process_with<'a, I, B>( + &self, + _processor: &DefaultPostProcess, + accessor: &mut Self::SearchAccessor<'a>, + query: &T, + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> ANNResult + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + self.create_processor() + .post_process(accessor, query, computer, candidates, output) + .await + .into_ann_result() + } +} + +// --------------------------------------------------------------------------- +// 6. Search API split: Knn vs KnnWith +// --------------------------------------------------------------------------- +// +// Knn uses the default processor. KnnWith allows an explicit override. +// Both delegate to a shared `search_core` that is parameterized over PP. + +impl Knn { + /// Shared core — the only axis of variation is the processor. + async fn search_core( + &self, + index: &DiskANNIndex, + strategy: &S, + /* … */ + post_processor: &PP, + ) -> ANNResult + where + S: PostProcess, + PP: Send + Sync, + /* … */ + { + let mut accessor = strategy.search_accessor(/* … */)?; + let computer = accessor.build_query_computer(query)?; + /* … search_internal … */ + let count = strategy + .post_process_with(post_processor, &mut accessor, query, &computer, candidates, output) + .await?; + Ok(stats.finish(count as u32)) + } +} + +// Knn: uses DefaultPostProcess +impl Search for Knn +where + S: PostProcess, + // equivalently: S: HasDefaultProcessor +{ + fn search(self, /* … */) -> impl SendFuture> { + async move { + self.search_core(/* … */, &DefaultPostProcess).await + } + } +} + +// KnnWith: uses caller-supplied processor +pub struct KnnWith { + inner: Knn, + post_processor: PP, +} + +impl Search for KnnWith +where + S: PostProcess, + PP: Send + Sync, +{ + fn search(self, /* … */) -> impl SendFuture> { + async move { + self.inner + .search_core(/* … */, &self.post_processor) + .await + } + } +} + +// --------------------------------------------------------------------------- +// 7. Example: implementing a strategy +// --------------------------------------------------------------------------- + +struct MyStrategy { /* … */ } + +impl SearchStrategy for MyStrategy { + type QueryComputer = MyComputer; + type SearchAccessorError = ANNError; + type SearchAccessor<'a> = MyAccessor<'a>; + + fn search_accessor<'a>(/* … */) -> Result, ANNError> { /* … */ } + // No PostProcessor, no post_processor() — clean. +} + +// Opt in to the default: "my default post-processor is CopyIds" +impl HasDefaultProcessor for MyStrategy { + delegate_default_post_process!(CopyIds); +} +// That's it — Knn now works with MyStrategy. + +// Opt in to RAG reranking too (no coherence conflict!): +impl PostProcess for MyStrategy { + async fn post_process_with( + &self, + processor: &RagSearchParams, + accessor: &mut MyAccessor<'_>, + /* … */ + ) -> ANNResult { + // Custom RAG logic here + } +} +// Now `KnnWith::new(knn, rag_params)` also works with MyStrategy. + +// --------------------------------------------------------------------------- +// 8. Decorator strategies (BetaFilter) +// --------------------------------------------------------------------------- +// +// BetaFilter wraps an inner strategy and delegates. The PostProcess<…, P, …> +// impl is generic over P, which is coherence-safe because it's on a concrete +// wrapper type (not a blanket over Self). + +impl PostProcess + for BetaFilter +where + Strategy: PostProcess, + P: Send + Sync, + /* … other bounds … */ +{ + async fn post_process_with( + &self, + processor: &P, + accessor: &mut Self::SearchAccessor<'_>, + /* … */ + ) -> ANNResult { + // Unwrap the layered accessor, delegate to inner strategy + self.strategy + .post_process_with(processor, &mut accessor.inner, /* … */) + .await + } +} + +impl HasDefaultProcessor + for BetaFilter +where + Strategy: HasDefaultProcessor, + /* … */ +{ + type Processor = Strategy::Processor; + fn create_processor(&self) -> Self::Processor { + self.strategy.create_processor() + } +} + +// --------------------------------------------------------------------------- +// 9. InplaceDeleteStrategy +// --------------------------------------------------------------------------- +// +// The delete-search phase needs exactly one processor type. The associated +// type pins it, and the SearchStrategy bound requires PostProcess for that +// specific type. +// +// NOTE: The double `for<'a>` bound is verbose but unavoidable given the GAT. + +pub trait InplaceDeleteStrategy: Send + Sync + 'static +where + Provider: DataProvider, +{ + type DeleteElement<'a>: Send + Sync + ?Sized; + type DeleteElementGuard: /* … AsyncLower … */ + 'static; + type DeleteElementError: StandardError; + type PruneStrategy: PruneStrategy; + + /// The processor used during the delete-search phase. + type SearchPostProcessor: Send + Sync; + + /// The search strategy, which must support PostProcess with the above processor. + type SearchStrategy: for<'a> SearchStrategy> + + for<'a> PostProcess< + Provider, + Self::DeleteElement<'a>, + Self::SearchPostProcessor, + >; + + fn prune_strategy(&self) -> Self::PruneStrategy; + fn search_strategy(&self) -> Self::SearchStrategy; + fn search_post_processor(&self) -> Self::SearchPostProcessor; + + fn get_delete_element<'a>(/* … */) -> impl Future> + Send; +} + +// --------------------------------------------------------------------------- +// 10. Known pain points for the real implementation +// --------------------------------------------------------------------------- +// +// A. HRTB on HasDefaultProcessor::Processor +// The bound `for<'a> SearchPostProcess, T, O>` +// is the same one that lived on SearchStrategy::PostProcessor before. +// It's not new — it just moved. The delegate_default_post_process! macro +// should absorb this. +// +// B. BetaFilter's generic P delegation +// `impl

PostProcess<…, P, …> for BetaFilter where S: PostProcess<…, P, …>` +// is coherence-safe (concrete wrapper, not a blanket over Self), but verify +// that rustc is happy with the HRTB interaction when SearchAccessor<'a> is +// a layered type (BetaAccessor wrapping the inner accessor). +// +// C. Disk provider (DiskSearchStrategy) +// Today it has PostProcessor = RerankAndFilter. Under the new design: +// - impl HasDefaultProcessor → Processor = RerankAndFilter +// - impl PostProcess<…, RagSearchParams, …> → custom RAG reranking +// These are independent impls with no coherence conflict. +// +// D. Caching provider (CachingAccessor) +// Uses Pipeline today. Same pattern: HasDefaultProcessor +// with Processor = Pipeline. The Pipeline type is just +// another SearchPostProcess impl. +// +// E. The .send() / IntoANNResult bridge +// The blanket impl calls `create_processor().post_process(…).await`. +// The SearchPostProcess::Error needs to be convertible to ANNError. Today +// this is handled via IntoANNResult / .send(). Same pattern applies. From 054398bb6c985b341fb4d8cf3b0bffcc8f6476c9 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 15:32:24 -0700 Subject: [PATCH 19/47] =?UTF-8?q?Eliminate=20PostProcess=20trait=20?= =?UTF-8?q?=E2=80=94=20Option=20A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the PostProcess trait indirection with HRTB bounds directly on search methods. The key insight is that SearchAccessor<'a> has no 'where Self: 'a' clause, making the HRTB 'for<'a> PP: SearchPostProcess< S::SearchAccessor<'a>, T, O>' safe even with generic S. Changes in diskann/: - Remove PostProcess trait, DelegateDefaultPostProcessor trait, DefaultPostProcess ZST, blanket impl, delegate_default_post_process! macro - Add HasDefaultProcessor trait and has_default_processor! macro - Update DefaultSearchStrategy = SearchStrategy + HasDefaultProcessor - Update InplaceDeleteStrategy: SearchPostProcessor now carries the HRTB SearchPostProcess bound directly - Search trait now requires S: SearchStrategy (needed for GAT projection) - All Search impls (Knn, RecordedKnn, Range, Diverse, Multihop) call processor.post_process() directly instead of strategy.post_process_with() - DiskANNIndex::search() uses HasDefaultProcessor::create_processor() - DiskANNIndex::search_with() takes PP with HRTB bound - Update test provider to use HasDefaultProcessor + CopyIds Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- diskann/src/graph/glue.rs | 103 ++++---------------- diskann/src/graph/index.rs | 29 +++--- diskann/src/graph/search/diverse_search.rs | 14 +-- diskann/src/graph/search/knn_search.rs | 30 +++--- diskann/src/graph/search/mod.rs | 9 +- diskann/src/graph/search/multihop_search.rs | 15 +-- diskann/src/graph/search/range_search.rs | 14 +-- diskann/src/graph/test/provider.rs | 8 +- 8 files changed, 76 insertions(+), 146 deletions(-) diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index baccb9cf5..7e42f14a7 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -333,44 +333,12 @@ where ) -> Result, Self::SearchAccessorError>; } -/// Strategy-level bridge connecting a [`SearchStrategy`] to a specific processor type `P`. -/// -/// This trait is the surface that the search infrastructure (for example, -/// [`super::search::Knn`]) bounds on. -/// -/// The blanket impl covers `P = DefaultPostProcess` for any strategy implementing -/// [`DelegateDefaultPostProcessor`]. Custom processor types (e.g. `DeterminantDiversitySearchParams`) can have -/// their own `PostProcess` impls without coherence conflicts. -pub trait PostProcess::InternalId>: - SearchStrategy -where - Provider: DataProvider, - T: ?Sized, - O: Send, - P: Send + Sync, -{ - /// Run post-processing with the given `processor` on `candidates`, writing - /// results into `output`. - fn post_process_with<'a, I, B>( - &self, - processor: P, - accessor: &mut Self::SearchAccessor<'a>, - query: &T, - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized; -} - /// Opt-in trait for strategies that have a default post-processor. /// /// Strategies implementing this trait work with [`super::search::Knn`] (no explicit -/// processor). The old `SearchStrategy::PostProcessor` associated type is replaced by -/// `DelegateDefaultPostProcessor::Processor`. -pub trait DelegateDefaultPostProcessor::InternalId>: +/// processor). The search infrastructure will call `create_processor()` to obtain the +/// processor and invoke its [`SearchPostProcess::post_process`] method. +pub trait HasDefaultProcessor::InternalId>: SearchStrategy where Provider: DataProvider, @@ -386,7 +354,7 @@ where /// Aggregate trait for strategies that support both search access and a default post-processor. pub trait DefaultSearchStrategy::InternalId>: - SearchStrategy + DelegateDefaultPostProcessor + SearchStrategy + HasDefaultProcessor where Provider: DataProvider, T: ?Sized, @@ -396,25 +364,25 @@ where impl DefaultSearchStrategy for S where - S: SearchStrategy + DelegateDefaultPostProcessor, + S: SearchStrategy + HasDefaultProcessor, Provider: DataProvider, T: ?Sized, O: Send, { } -/// Convenience macro for implementing [`DelegateDefaultPostProcessor`] when the processor +/// Convenience macro for implementing [`HasDefaultProcessor`] when the processor /// is a [`Default`]-constructible type. /// /// # Example /// /// ```ignore -/// impl DelegateDefaultPostProcessor for MyStrategy { -/// delegate_default_post_process!(CopyIds); +/// impl HasDefaultProcessor for MyStrategy { +/// has_default_processor!(CopyIds); /// } /// ``` #[macro_export] -macro_rules! delegate_default_post_process { +macro_rules! has_default_processor { ($Processor:ty) => { type Processor = $Processor; fn create_processor(&self) -> Self::Processor { @@ -423,45 +391,6 @@ macro_rules! delegate_default_post_process { }; } -/// A zero-sized marker representing "use the default post-processor". -/// -/// The blanket `PostProcess` impl covers exactly `P = DefaultPostProcess`. -/// Custom processor types are free to have their own `PostProcess` impls -/// without coherence conflicts. -#[derive(Debug, Default, Clone, Copy)] -pub struct DefaultPostProcess; - -impl PostProcess for S -where - S: DelegateDefaultPostProcessor, - Provider: DataProvider, - T: ?Sized + Sync, - O: Send, -{ - fn post_process_with<'a, I, B>( - &self, - _processor: DefaultPostProcess, - accessor: &mut Self::SearchAccessor<'a>, - query: &T, - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - use crate::error::IntoANNResult; - async move { - self.create_processor() - .post_process(accessor, query, computer, candidates, output) - .send() - .await - .into_ann_result() - } - } -} - /// Perform post-processing on the results of search, storing the results in an output buffer. /// /// Simple implementations include [`CopyIds`], which simply forwards the search results @@ -864,12 +793,14 @@ where type PruneStrategy: PruneStrategy; /// The processor used during the delete-search phase. - type SearchPostProcessor: Send + Sync; + type SearchPostProcessor: for<'a> SearchPostProcess< + >>::SearchAccessor<'a>, + Self::DeleteElement<'a>, + > + Send + + Sync; /// The type of the search strategy to use for graph traversal. - /// It must support [`PostProcess`] with [`Self::SearchPostProcessor`]. - type SearchStrategy: for<'a> SearchStrategy> - + for<'a> PostProcess, Self::SearchPostProcessor>; + type SearchStrategy: for<'a> SearchStrategy>; /// Construct the prune strategy object. fn prune_strategy(&self) -> Self::PruneStrategy; @@ -1158,8 +1089,8 @@ mod tests { } } - impl DelegateDefaultPostProcessor for Strategy { - delegate_default_post_process!(CopyIds); + impl HasDefaultProcessor for Strategy { + has_default_processor!(CopyIds); } // Use the provided implementation. diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index c239e8671..b79b01128 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -27,7 +27,7 @@ use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, glue::{ self, AsElement, ExpandBeam, FillSet, IdIterator, InplaceDeleteStrategy, InsertStrategy, - PostProcess, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, + PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ @@ -1297,16 +1297,16 @@ where // placed into the output. let proxy = v.async_lower(); let post_processor = strategy.search_post_processor(); - let num_results = search_strategy - .post_process_with( - post_processor, + let num_results = post_processor + .post_process( &mut search_accessor, &*proxy, &computer, scratch.best.iter(), &mut neighbor::BackInserter::new(output.as_mut_slice()), ) - .await?; + .await + .into_ann_result()?; let mut undeleted_ids: Vec<_> = output .iter() @@ -2153,20 +2153,13 @@ where ) -> impl SendFuture> where P: super::search::Search, - S: glue::PostProcess, - glue::DefaultPostProcess: Send + Sync, + S: glue::HasDefaultProcessor, O: Send, OB: super::search_output_buffer::SearchOutputBuffer + Send + ?Sized, T: ?Sized, { - self.search_with( - search_params, - strategy, - glue::DefaultPostProcess, - context, - query, - output, - ) + let processor = strategy.create_processor(); + self.search_with(search_params, strategy, processor, context, query, output) } /// Execute a search with an explicit post-processor parameter. @@ -2181,8 +2174,8 @@ where ) -> impl SendFuture> where P: super::search::Search, - S: glue::PostProcess, - PP: Send + Sync, + S: glue::SearchStrategy, + PP: for<'a> glue::SearchPostProcess, T, O> + Send + Sync, O: Send, OB: super::search_output_buffer::SearchOutputBuffer + Send + ?Sized, T: ?Sized, @@ -2229,7 +2222,7 @@ where where T: ?Sized, S: SearchStrategy: IdIterator> - + glue::DelegateDefaultPostProcessor, + + glue::HasDefaultProcessor, I: Iterator::InternalId>, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index e1d63e4be..f7e57f52e 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -14,7 +14,7 @@ use crate::{ error::IntoANNResult, graph::{ DiverseSearchParams, - glue::{PostProcess, SearchExt}, + glue::{SearchExt, SearchPostProcess}, index::{DiskANNIndex, SearchStats}, search_output_buffer::SearchOutputBuffer, }, @@ -95,6 +95,7 @@ where impl Search for Diverse

where DP: DataProvider, + S: crate::graph::glue::SearchStrategy, T: Sync + ?Sized, O: Send, P: AttributeValueProvider, @@ -111,8 +112,7 @@ where output: &mut OB, ) -> impl SendFuture> where - S: PostProcess, - PP: Send + Sync, + PP: for<'a> SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { async move { @@ -139,16 +139,16 @@ where // Post-process diverse results diverse_scratch.best.post_process(); - let result_count = strategy - .post_process_with( - processor, + let result_count = processor + .post_process( &mut accessor, query, &computer, diverse_scratch.best.iter().take(self.inner.l_value().get()), output, ) - .await?; + .await + .into_ann_result()?; Ok(stats.finish(result_count as u32)) } diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index f70dc6aed..090d8d44b 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -15,7 +15,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{PostProcess, SearchExt}, + glue::{SearchExt, SearchPostProcess}, index::{DiskANNIndex, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::SearchOutputBuffer, @@ -157,10 +157,10 @@ impl Knn { where DP: DataProvider, T: Sync + ?Sized, - S: PostProcess, + S: crate::graph::glue::SearchStrategy, O: Send, OB: SearchOutputBuffer + Send + ?Sized, - PP: Send + Sync, + PP: for<'a> SearchPostProcess, T, O> + Send + Sync, { let mut accessor = strategy .search_accessor(&index.data_provider, context) @@ -182,16 +182,16 @@ impl Knn { ) .await?; - let result_count = strategy - .post_process_with( - post_processor, + let result_count = post_processor + .post_process( &mut accessor, query, &computer, scratch.best.iter().take(self.l_value.get().into_usize()), output, ) - .await?; + .await + .into_ann_result()?; Ok(stats.finish(result_count as u32)) } @@ -200,6 +200,7 @@ impl Knn { impl Search for Knn where DP: DataProvider, + S: crate::graph::glue::SearchStrategy, T: Sync + ?Sized, O: Send, { @@ -216,8 +217,7 @@ where output: &mut OB, ) -> impl SendFuture> where - S: PostProcess, - PP: Send + Sync, + PP: for<'a> SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { async move { @@ -252,6 +252,7 @@ impl<'r, SR: ?Sized> RecordedKnn<'r, SR> { impl<'r, DP, S, T, O, SR> Search for RecordedKnn<'r, SR> where DP: DataProvider, + S: crate::graph::glue::SearchStrategy, T: Sync + ?Sized, O: Send, SR: super::record::SearchRecord + ?Sized, @@ -268,8 +269,7 @@ where output: &mut OB, ) -> impl SendFuture> where - S: PostProcess, - PP: Send + Sync, + PP: for<'a> SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { async move { @@ -293,9 +293,8 @@ where ) .await?; - let result_count = strategy - .post_process_with( - processor, + let result_count = processor + .post_process( &mut accessor, query, &computer, @@ -305,7 +304,8 @@ where .take(self.inner.l_value.get().into_usize()), output, ) - .await?; + .await + .into_ann_result()?; Ok(stats.finish(result_count as u32)) } diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index ab32581dc..d43bc77b0 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -51,6 +51,7 @@ pub(crate) mod scratch; pub trait Search where DP: DataProvider, + S: crate::graph::glue::SearchStrategy, O: Send, { /// The result type returned by this search. @@ -88,8 +89,12 @@ where output: &mut OB, ) -> impl SendFuture> where - S: crate::graph::glue::PostProcess, - PP: Send + Sync, + PP: for<'a> crate::graph::glue::SearchPostProcess< + >::SearchAccessor<'a>, + T, + O, + > + Send + + Sync, OB: crate::graph::search_output_buffer::SearchOutputBuffer + Send + ?Sized; } diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 5771cd4de..09ec0a001 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -16,7 +16,8 @@ use crate::{ error::{ErrorExt, IntoANNResult}, graph::{ glue::{ - self, ExpandBeam, HybridPredicate, PostProcess, Predicate, PredicateMut, SearchExt, + self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, + SearchPostProcess, }, index::{ DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats, @@ -55,6 +56,7 @@ impl<'q, InternalId> MultihopSearch<'q, InternalId> { impl<'q, DP, S, T, O> Search for MultihopSearch<'q, DP::InternalId> where DP: DataProvider, + S: glue::SearchStrategy, T: Sync + ?Sized, O: Send, { @@ -70,8 +72,7 @@ where output: &mut OB, ) -> impl SendFuture> where - S: PostProcess, - PP: Send + Sync, + PP: for<'a> SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { async move { @@ -95,16 +96,16 @@ where ) .await?; - let result_count = strategy - .post_process_with( - processor, + let result_count = processor + .post_process( &mut accessor, query, &computer, scratch.best.iter().take(self.inner.l_value().get()), output, ) - .await?; + .await + .into_ann_result()?; Ok(stats.finish(result_count as u32)) } diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 495f80de1..d746d3dca 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -13,7 +13,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{self, ExpandBeam, PostProcess, SearchExt}, + glue::{self, ExpandBeam, SearchExt, SearchPostProcess}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::{self, SearchOutputBuffer}, @@ -170,6 +170,7 @@ impl Range { impl Search for Range where DP: DataProvider, + S: glue::SearchStrategy, T: Sync + ?Sized, O: Send + Default + Clone, { @@ -185,8 +186,7 @@ where output: &mut OB, ) -> impl SendFuture> where - S: PostProcess, - PP: Send + Sync, + PP: for<'a> glue::SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { async move { @@ -255,16 +255,16 @@ where result_dists.as_mut_slice(), ); - let _ = strategy - .post_process_with( - processor, + let _ = processor + .post_process( &mut accessor, query, &computer, scratch.in_range.iter().copied(), &mut output_buffer, ) - .await?; + .await + .into_ann_result()?; // Filter by inner/outer radius let inner_cutoff = if let Some(inner_radius) = self.inner_radius() { diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 786bf35ac..14b5960e5 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -16,7 +16,7 @@ use diskann_vector::distance::Metric; use thiserror::Error; use crate::{ - ANNError, ANNResult, delegate_default_post_process, + ANNError, ANNResult, has_default_processor, error::{Infallible, message}, graph::{AdjacencyList, glue, test::synthetic}, internal::counter::{Counter, LocalCounter}, @@ -964,8 +964,8 @@ impl glue::SearchStrategy for Strategy { } } -impl glue::DelegateDefaultPostProcessor for Strategy { - delegate_default_post_process!(glue::CopyIds); +impl glue::HasDefaultProcessor for Strategy { + has_default_processor!(glue::CopyIds); } impl glue::PruneStrategy for Strategy { @@ -1015,7 +1015,7 @@ impl glue::InplaceDeleteStrategy for Strategy { type DeleteElementError = AccessedInvalidId; type PruneStrategy = Self; type SearchStrategy = Self; - type SearchPostProcessor = glue::DefaultPostProcess; + type SearchPostProcessor = glue::CopyIds; fn prune_strategy(&self) -> Self::PruneStrategy { *self From bb3d60adb79751363474e9a10e521e9d7523456c Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 15:40:06 -0700 Subject: [PATCH 20/47] Remove PostProcess impls across providers Delete all forwarding PostProcess impl blocks (10 impls), rename DelegateDefaultPostProcessor to HasDefaultProcessor across all provider delete CachedPostProcess

newtype (replaced by Pipeline), and replace longhand SearchStrategy + HasDefaultProcessor with DefaultSearchStrategy where appropriate. Net: -383 lines of pure forwarding boilerplate eliminated. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/backend/disk_index/search.rs | 14 ++-- .../src/backend/index/benchmarks.rs | 5 +- diskann-benchmark/src/backend/index/result.rs | 1 - .../src/backend/index/search/knn.rs | 3 - .../src/search/provider/disk_provider.rs | 7 +- diskann-garnet/src/provider.rs | 70 ++---------------- .../inline_beta_search/inline_beta_filter.rs | 6 +- diskann-providers/src/index/diskann_async.rs | 16 ++-- diskann-providers/src/index/wrapped_async.rs | 2 +- .../graph/provider/async_/bf_tree/provider.rs | 73 ++----------------- .../graph/provider/async_/caching/provider.rs | 58 ++------------- .../graph/provider/async_/debug_provider.rs | 64 ++-------------- .../provider/async_/inmem/full_precision.rs | 44 ++--------- .../graph/provider/async_/inmem/product.rs | 48 ++---------- .../graph/provider/async_/inmem/scalar.rs | 16 ++-- .../graph/provider/async_/inmem/spherical.rs | 14 ++-- .../model/graph/provider/async_/inmem/test.rs | 8 +- .../src/model/graph/provider/async_/mod.rs | 1 - .../model/graph/provider/layers/betafilter.rs | 10 +-- diskann/src/graph/search/range_search.rs | 2 +- diskann/src/graph/test/provider.rs | 3 +- post_process_design_sketch.rs | 8 +- 22 files changed, 90 insertions(+), 383 deletions(-) diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index ce51366e7..f3ad744a9 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -270,13 +270,13 @@ where }; let search_result = searcher.search( - q, - search_params.recall_at, - l, - Some(search_params.beam_width), - vector_filter, - search_params.is_flat_search, - ); + q, + search_params.recall_at, + l, + Some(search_params.beam_width), + vector_filter, + search_params.is_flat_search, + ); match search_result { Ok(search_result) => { diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 557e7594e..fa4a77078 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -349,10 +349,7 @@ where DP: DataProvider + provider::SetElement<[T]>, T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, - S: glue::SearchStrategy - + glue::DelegateDefaultPostProcessor - + Clone - + AsyncFriendly, + S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, { match &input { SearchPhase::Topk(search_phase) => { diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index bcf312832..1d6102f9b 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -155,7 +155,6 @@ impl SearchResults { mean_hops: mean_hops as f32, } } - } fn format_search_results_table( diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 368e5f788..5695f4e95 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -52,7 +52,6 @@ pub(crate) trait Knn { ) -> anyhow::Result>; } - /////////// // Impls // /////////// @@ -148,5 +147,3 @@ where Ok(results.into_iter().map(SearchResults::new).collect()) } } - - diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 4fe2a7f0b..870541f0c 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -16,12 +16,11 @@ use std::{ }; use diskann::{ - error::IntoANNResult, graph::{ self, glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, IdIterator, PostProcess, SearchExt, - SearchPostProcess, SearchStrategy, + self, ExpandBeam, HasDefaultProcessor, IdIterator, SearchExt, SearchPostProcess, + SearchStrategy, }, search::Knn, search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, @@ -372,7 +371,7 @@ where } impl<'this, Data, ProviderFactory> - DelegateDefaultPostProcessor< + HasDefaultProcessor< DiskProvider, [Data::VectorDataType], ( diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 26cdb869c..00cfb16f6 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -4,17 +4,15 @@ */ use dashmap::DashMap; -use diskann::delegate_default_post_process; +use diskann::has_default_processor; use diskann::{ ANNError, ANNErrorKind, ANNResult, - error::IntoANNResult, graph::{ AdjacencyList, SearchOutputBuffer, config::defaults::MAX_OCCLUSION_SIZE, glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, - InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchPostProcess, - SearchStrategy, + self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, }, }, neighbor::Neighbor, @@ -771,10 +769,8 @@ impl SearchStrategy, [T], GarnetId> for FullPre } } -impl DelegateDefaultPostProcessor, [T], GarnetId> - for FullPrecision -{ - delegate_default_post_process!(glue::Pipeline); +impl HasDefaultProcessor, [T], GarnetId> for FullPrecision { + has_default_processor!(glue::Pipeline); } impl SearchStrategy, [T], u32> for FullPrecision { @@ -791,60 +787,8 @@ impl SearchStrategy, [T], u32> for FullPrecisio } } -impl DelegateDefaultPostProcessor, [T], u32> for FullPrecision { - delegate_default_post_process!(glue::CopyIds); -} - -impl PostProcess, [T], glue::CopyIds, u32> for FullPrecision { - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: glue::CopyIds, - accessor: &mut Self::SearchAccessor<'a>, - query: &[T], - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - diskann::graph::glue::SearchPostProcess::post_process( - &processor, accessor, query, computer, candidates, output, - ) - .await - .into_ann_result() - } - } -} - -impl PostProcess, [T], CopyExternalIds, GarnetId> - for FullPrecision -{ - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: CopyExternalIds, - accessor: &mut Self::SearchAccessor<'a>, - query: &[T], - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - diskann::graph::glue::SearchPostProcess::post_process( - &processor, accessor, query, computer, candidates, output, - ) - .await - .into_ann_result() - } - } +impl HasDefaultProcessor, [T], u32> for FullPrecision { + has_default_processor!(glue::CopyIds); } impl PruneStrategy> for FullPrecision { diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 0b6a15a39..ea509b3ff 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -61,16 +61,16 @@ where } } -/// [`DelegateDefaultPostProcessor`] delegation for [`InlineBetaStrategy`]. The processor wraps +/// [`HasDefaultProcessor`] delegation for [`InlineBetaStrategy`]. The processor wraps /// the inner strategy's default processor with [`FilterResults`]. impl - diskann::graph::glue::DelegateDefaultPostProcessor< + diskann::graph::glue::HasDefaultProcessor< DocumentProvider>, FilteredQuery, > for InlineBetaStrategy where DP: DataProvider, - Strategy: diskann::graph::glue::DelegateDefaultPostProcessor, + Strategy: diskann::graph::glue::HasDefaultProcessor, Q: AsyncFriendly + Clone, { type Processor = FilterResults; diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 281abdec0..69942563b 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -176,7 +176,7 @@ pub(crate) mod tests { self, AdjacencyList, ConsolidateKind, InplaceDeleteMethod, StartPointStrategy, config::IntraBatchCandidates, glue::{ - AsElement, DelegateDefaultPostProcessor, InplaceDeleteStrategy, InsertStrategy, + AsElement, DefaultSearchStrategy, InplaceDeleteStrategy, InsertStrategy, SearchStrategy, aliases, }, index::{PartitionedNeighbors, QueryLabelProvider, QueryVisitDecision}, @@ -350,7 +350,7 @@ pub(crate) mod tests { mut checker: Checker, ) where DP: DataProvider, - S: SearchStrategy + DelegateDefaultPostProcessor, + S: DefaultSearchStrategy, Q: std::fmt::Debug + Sync + ?Sized, Checker: FnMut(usize, (u32, f32)) -> Result<(), Box>, { @@ -398,7 +398,7 @@ pub(crate) mod tests { filter: &dyn QueryLabelProvider, ) where DP: DataProvider, - S: SearchStrategy + DelegateDefaultPostProcessor, + S: DefaultSearchStrategy, Q: std::fmt::Debug + Sync + ?Sized, Checker: FnMut(usize, (u32, f32)) -> Result<(), Box>, { @@ -504,8 +504,8 @@ pub(crate) mod tests { quant_strategy: QS, ) where DP: DataProvider, - FS: SearchStrategy + DelegateDefaultPostProcessor + Clone + 'static, - QS: SearchStrategy + DelegateDefaultPostProcessor + Clone + 'static, + FS: DefaultSearchStrategy + Clone + 'static, + QS: DefaultSearchStrategy + Clone + 'static, T: Default + Clone + Send + Sync + std::fmt::Debug, { // Assume all vectors have the same length. @@ -927,8 +927,7 @@ pub(crate) mod tests { ) where T: VectorRepr + GenerateSphericalData + Into, S: InsertStrategy, [T]> - + SearchStrategy, [T]> - + DelegateDefaultPostProcessor, [T]> + + DefaultSearchStrategy, [T]> + Clone + 'static, rand::distr::StandardUniform: Distribution, @@ -1055,8 +1054,7 @@ pub(crate) mod tests { ) where T: VectorRepr + GenerateSphericalData + Into, S: InsertStrategy, [T]> - + SearchStrategy, [T]> - + DelegateDefaultPostProcessor, [T]> + + DefaultSearchStrategy, [T]> + Clone + 'static, rand::distr::StandardUniform: Distribution, diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 60ba2f632..3e8a68377 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -232,7 +232,7 @@ where ) -> ANNResult where T: Sync + ?Sized, - S: SearchStrategy + glue::DelegateDefaultPostProcessor, + S: glue::DefaultSearchStrategy, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index cbabe64cd..ab102b6ba 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -16,15 +16,15 @@ use std::{ use serde::{Deserialize, Serialize}; use bf_tree::{BfTree, Config}; -use diskann::delegate_default_post_process; +use diskann::has_default_processor; use diskann::{ ANNError, ANNResult, error::IntoANNResult, graph::{ AdjacencyList, DiskANNIndex, SearchOutputBuffer, glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, - InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, + self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, }, }, neighbor::Neighbor, @@ -1485,43 +1485,13 @@ where } } -impl DelegateDefaultPostProcessor, [T]> for FullPrecision +impl HasDefaultProcessor, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, { - delegate_default_post_process!(RemoveDeletedIdsAndCopy); -} - -impl PostProcess, [T], RemoveDeletedIdsAndCopy> for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, -{ - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: RemoveDeletedIdsAndCopy, - accessor: &mut Self::SearchAccessor<'a>, - query: &[T], - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - glue::SearchPostProcess::post_process( - &processor, accessor, query, computer, candidates, output, - ) - .await - .into_ann_result() - } - } + has_default_processor!(RemoveDeletedIdsAndCopy); } /// An [`glue::SearchPostProcess`] implementation that reranks PQ vectors. @@ -1619,41 +1589,12 @@ where } } -impl DelegateDefaultPostProcessor, [T]> for Hybrid +impl HasDefaultProcessor, [T]> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, { - delegate_default_post_process!(Rerank); -} - -impl PostProcess, [T], Rerank> for Hybrid -where - T: VectorRepr, - D: AsyncFriendly + DeletionCheck, -{ - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: Rerank, - accessor: &mut Self::SearchAccessor<'a>, - query: &[T], - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - glue::SearchPostProcess::post_process( - &processor, accessor, query, computer, candidates, output, - ) - .await - .into_ann_result() - } - } + has_default_processor!(Rerank); } // Pruning diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index b493a7598..007307f15 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -980,13 +980,13 @@ where } } -/// [`DelegateDefaultPostProcessor`] delegation for [`Cached`]. The processor is composed by +/// [`HasDefaultProcessor`] delegation for [`Cached`]. The processor is composed by /// wrapping the inner strategy's processor with [`Unwrap`] via [`Pipeline`]. -impl glue::DelegateDefaultPostProcessor, T> for Cached +impl glue::HasDefaultProcessor, T> for Cached where T: ?Sized, DP: DataProvider, - S: glue::DelegateDefaultPostProcessor + S: glue::HasDefaultProcessor + for<'a> SearchStrategy: CacheableAccessor>, C: for<'a> AsCacheAccessorFor< 'a, @@ -1003,49 +1003,6 @@ where } } -#[derive(Debug, Clone, Copy)] -pub struct CachedPostProcess

(pub P); - -impl glue::PostProcess, T, CachedPostProcess

> - for Cached -where - T: ?Sized, - P: Send + Sync, - DP: DataProvider, - S: glue::PostProcess - + for<'a> SearchStrategy: CacheableAccessor>, - C: for<'a> AsCacheAccessorFor< - 'a, - SearchAccessor<'a, S, DP, T>, - Accessor: NeighborCache, - Error = E, - > + AsyncFriendly, - E: StandardError, -{ - fn post_process_with<'a, I, B>( - &self, - processor: CachedPostProcess

, - accessor: &mut Self::SearchAccessor<'a>, - query: &T, - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - self.strategy.post_process_with( - processor.0, - &mut accessor.inner, - query, - computer, - candidates, - output, - ) - } -} - /// We need `S` to be a [`PruneStrategy`] for the underlying provider. /// /// This strategy has an associated [`PruneElement`] type `E` @@ -1118,11 +1075,6 @@ where DP: DataProvider, S: InplaceDeleteStrategy, Cached: PruneStrategy>, - for<'a> Cached: glue::PostProcess< - CachingProvider, - S::DeleteElement<'a>, - CachedPostProcess, - >, C: AsyncFriendly, { type DeleteElement<'a> = S::DeleteElement<'a>; @@ -1131,7 +1083,7 @@ where type PruneStrategy = Cached; type SearchStrategy = Cached; - type SearchPostProcessor = CachedPostProcess; + type SearchPostProcessor = Pipeline; fn prune_strategy(&self) -> Self::PruneStrategy { Cached { @@ -1146,7 +1098,7 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - CachedPostProcess(self.strategy.search_post_processor()) + Pipeline::new(Unwrap, self.strategy.search_post_processor()) } fn get_delete_element<'a>( diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs index 8590c7bec..763864fd1 100644 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs @@ -11,15 +11,15 @@ use std::{ }, }; -use diskann::delegate_default_post_process; +use diskann::has_default_processor; use diskann::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ AdjacencyList, SearchOutputBuffer, glue::{ - AsElement, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, - InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, + AsElement, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, + InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, }, neighbor::Neighbor, @@ -903,33 +903,8 @@ impl SearchStrategy for FullPrecision { } } -impl DelegateDefaultPostProcessor for FullPrecision { - delegate_default_post_process!(postprocess::RemoveDeletedIdsAndCopy); -} - -impl PostProcess for FullPrecision { - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: postprocess::RemoveDeletedIdsAndCopy, - accessor: &mut Self::SearchAccessor<'a>, - query: &[f32], - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - diskann::graph::glue::SearchPostProcess::post_process( - &processor, accessor, query, computer, candidates, output, - ) - .await - .into_ann_result() - } - } +impl HasDefaultProcessor for FullPrecision { + has_default_processor!(postprocess::RemoveDeletedIdsAndCopy); } impl SearchStrategy for Quantized { @@ -946,33 +921,8 @@ impl SearchStrategy for Quantized { } } -impl DelegateDefaultPostProcessor for Quantized { - delegate_default_post_process!(postprocess::RemoveDeletedIdsAndCopy); -} - -impl PostProcess for Quantized { - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: postprocess::RemoveDeletedIdsAndCopy, - accessor: &mut Self::SearchAccessor<'a>, - query: &[f32], - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - diskann::graph::glue::SearchPostProcess::post_process( - &processor, accessor, query, computer, candidates, output, - ) - .await - .into_ann_result() - } - } +impl HasDefaultProcessor for Quantized { + has_default_processor!(postprocess::RemoveDeletedIdsAndCopy); } impl PruneStrategy for FullPrecision { diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index baf2a6df2..25e0797de 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -5,15 +5,14 @@ use std::{collections::HashMap, fmt::Debug, future::Future}; -use diskann::delegate_default_post_process; +use diskann::has_default_processor; use diskann::{ ANNError, ANNResult, - error::IntoANNResult, graph::{ SearchOutputBuffer, glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, - InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, + self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, }, }, neighbor::Neighbor, @@ -480,47 +479,14 @@ where } } -impl DelegateDefaultPostProcessor, [T]> - for FullPrecision +impl HasDefaultProcessor, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(RemoveDeletedIdsAndCopy); -} - -impl PostProcess, [T], RemoveDeletedIdsAndCopy> - for FullPrecision -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: RemoveDeletedIdsAndCopy, - accessor: &mut Self::SearchAccessor<'a>, - query: &[T], - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - glue::SearchPostProcess::post_process( - &processor, accessor, query, computer, candidates, output, - ) - .await - .into_ann_result() - } - } + has_default_processor!(RemoveDeletedIdsAndCopy); } // Pruning diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index cfe5acc0b..d5aed4c7a 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -5,16 +5,13 @@ use std::{collections::HashMap, future::Future, sync::Arc}; -use diskann::delegate_default_post_process; +use diskann::has_default_processor; use diskann::{ ANNError, ANNResult, - error::IntoANNResult, - graph::SearchOutputBuffer, graph::glue::{ - self, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, - InsertStrategy, PostProcess, PruneStrategy, SearchExt, SearchStrategy, + self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchExt, SearchStrategy, }, - neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, HasId, @@ -486,44 +483,13 @@ where } } -impl DelegateDefaultPostProcessor, [T]> - for Hybrid +impl HasDefaultProcessor, [T]> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(Rerank); -} - -impl PostProcess, [T], Rerank> for Hybrid -where - T: VectorRepr, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - #[allow(clippy::manual_async_fn)] - fn post_process_with<'a, I, B>( - &self, - processor: Rerank, - accessor: &mut Self::SearchAccessor<'a>, - query: &[T], - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - async move { - glue::SearchPostProcess::post_process( - &processor, accessor, query, computer, candidates, output, - ) - .await - .into_ann_result() - } - } + has_default_processor!(Rerank); } impl PruneStrategy> for Hybrid @@ -640,14 +606,14 @@ where } } -impl DelegateDefaultPostProcessor, [T]> +impl HasDefaultProcessor, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(RemoveDeletedIdsAndCopy); + has_default_processor!(RemoveDeletedIdsAndCopy); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index e81c5f4bf..5e6f2aa18 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -6,12 +6,12 @@ use std::{future::Future, sync::Mutex}; use crate::storage::{StorageReadProvider, StorageWriteProvider}; -use diskann::delegate_default_post_process; +use diskann::has_default_processor; use diskann::{ ANNError, ANNResult, graph::glue::{ - DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, - SearchExt, SearchStrategy, + ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, PruneStrategy, SearchExt, + SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, @@ -624,8 +624,7 @@ where } impl - DelegateDefaultPostProcessor, D, Ctx>, [T]> - for Quantized + HasDefaultProcessor, D, Ctx>, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -633,7 +632,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - delegate_default_post_process!(Rerank); + has_default_processor!(Rerank); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -662,8 +661,7 @@ where } impl - DelegateDefaultPostProcessor, D, Ctx>, [T]> - for Quantized + HasDefaultProcessor, D, Ctx>, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -671,7 +669,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - delegate_default_post_process!(RemoveDeletedIdsAndCopy); + has_default_processor!(RemoveDeletedIdsAndCopy); } impl PruneStrategy, D, Ctx>> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 64744366b..bba8ffa2d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -7,13 +7,13 @@ use std::{future::Future, sync::Mutex}; -use diskann::delegate_default_post_process; +use diskann::has_default_processor; use diskann::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::glue::{ - DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, - SearchExt, SearchStrategy, + ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, PruneStrategy, SearchExt, + SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, @@ -572,14 +572,14 @@ where } } -impl DelegateDefaultPostProcessor, [T]> +impl HasDefaultProcessor, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(Rerank); + has_default_processor!(Rerank); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -605,14 +605,14 @@ where } } -impl DelegateDefaultPostProcessor, [T]> +impl HasDefaultProcessor, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - delegate_default_post_process!(RemoveDeletedIdsAndCopy); + has_default_processor!(RemoveDeletedIdsAndCopy); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 8ffc6acf8..d3f1a1c50 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -5,12 +5,12 @@ use std::{future::Future, sync::Mutex}; -use diskann::delegate_default_post_process; +use diskann::has_default_processor; use diskann::{ ANNError, ANNResult, error::{RankedError, ToRanked, TransientError}, graph::glue::{ - AsElement, CopyIds, DelegateDefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, + AsElement, CopyIds, ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, neighbor::Neighbor, @@ -251,8 +251,8 @@ impl SearchStrategy for Flaky { } } -impl DelegateDefaultPostProcessor for Flaky { - delegate_default_post_process!(CopyIds); +impl HasDefaultProcessor for Flaky { + has_default_processor!(CopyIds); } impl FillSet for FlakyAccessor<'_> {} diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 4e703cf33..3d89359e2 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -43,4 +43,3 @@ pub mod caching; // Debug provider for testing. #[cfg(test)] pub mod debug_provider; - diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 5c09cbf8e..0e625f9f2 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -144,16 +144,16 @@ where } } -/// [`DelegateDefaultPostProcessor`] delegation for [`BetaFilter`]. The processor is composed by +/// [`HasDefaultProcessor`] delegation for [`BetaFilter`]. The processor is composed by /// wrapping the inner strategy's processor with [`Unwrap`] via [`Pipeline`]. -impl glue::DelegateDefaultPostProcessor +impl glue::HasDefaultProcessor for BetaFilter where T: ?Sized, I: VectorId, O: Send, Provider: DataProvider, - Strategy: glue::DelegateDefaultPostProcessor, + Strategy: glue::HasDefaultProcessor, { type Processor = glue::Pipeline; @@ -559,8 +559,8 @@ mod tests { } } - impl glue::DelegateDefaultPostProcessor for SimpleStrategy { - diskann::delegate_default_post_process!(CopyIds); + impl glue::HasDefaultProcessor for SimpleStrategy { + diskann::has_default_processor!(CopyIds); } /// A simple `QueryLabelProvider` that matches multiples of 3. diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index d746d3dca..f4bc6bb01 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -13,7 +13,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{self, ExpandBeam, SearchExt, SearchPostProcess}, + glue::{self, ExpandBeam, SearchExt}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::{self, SearchOutputBuffer}, diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 14b5960e5..b2543b013 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -16,9 +16,10 @@ use diskann_vector::distance::Metric; use thiserror::Error; use crate::{ - ANNError, ANNResult, has_default_processor, + ANNError, ANNResult, error::{Infallible, message}, graph::{AdjacencyList, glue, test::synthetic}, + has_default_processor, internal::counter::{Counter, LocalCounter}, provider, utils::VectorRepr, diff --git a/post_process_design_sketch.rs b/post_process_design_sketch.rs index cc1bceff0..994bff558 100644 --- a/post_process_design_sketch.rs +++ b/post_process_design_sketch.rs @@ -161,8 +161,8 @@ where fn create_processor(&self) -> Self::Processor; } -// Convenience macro (same idea as exhibit-B's delegate_default_post_process!). -macro_rules! delegate_default_post_process { +// Convenience macro (same idea as exhibit-B's has_default_processor!). +macro_rules! has_default_processor { ($Processor:ty) => { type Processor = $Processor; fn create_processor(&self) -> Self::Processor { @@ -291,7 +291,7 @@ impl SearchStrategy for MyStrategy { // Opt in to the default: "my default post-processor is CopyIds" impl HasDefaultProcessor for MyStrategy { - delegate_default_post_process!(CopyIds); + has_default_processor!(CopyIds); } // That's it — Knn now works with MyStrategy. @@ -392,7 +392,7 @@ where // A. HRTB on HasDefaultProcessor::Processor // The bound `for<'a> SearchPostProcess, T, O>` // is the same one that lived on SearchStrategy::PostProcessor before. -// It's not new — it just moved. The delegate_default_post_process! macro +// It's not new — it just moved. The has_default_processor! macro // should absorb this. // // B. BetaFilter's generic P delegation From 6b8b1f5dcf18d5aafdaa252293558eb9f2aaf768 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 15:46:30 -0700 Subject: [PATCH 21/47] Replace manual_async_fn with async fn in SearchPostProcess impls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Convert 3 SearchPostProcess implementations from manual async desugaring (fn -> impl Future + Send with async move block) to native async fn. The recursive test_spawning in provider.rs is kept manual because it needs an explicit 'static bound on the returned future. Added T: Sync bound to RemoveDeletedIdsAndCopy impl because async fn captures all parameters (including unused &T) in the future, requiring &T: Send → T: Sync. This is always satisfied at call sites. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../graph/provider/async_/bf_tree/provider.rs | 69 ++++++++--------- .../provider/async_/inmem/full_precision.rs | 77 +++++++++---------- .../graph/provider/async_/postprocess.rs | 47 ++++++----- 3 files changed, 92 insertions(+), 101 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index ab102b6ba..6cd02a680 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -1515,55 +1515,52 @@ where { type Error = ANNError; - #[allow(clippy::manual_async_fn)] - fn post_process( + async fn post_process( &self, accessor: &mut QuantAccessor<'a, T, D>, query: &[T], _computer: &pq::distance::QueryComputer>, candidates: I, output: &mut B, - ) -> impl Future> + Send + ) -> Result where I: Iterator> + Send, B: SearchOutputBuffer + Send + ?Sized, { - async move { - let provider = &accessor.provider; - let f = T::distance(provider.metric, Some(provider.full_vectors.dim())); - let is_not_start_point = if self.filter_start_points { - Some(accessor.is_not_start_point().await?) - } else { - None - }; - let checker = accessor.as_deletion_check(); - - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - return None; - } + let provider = &accessor.provider; + let f = T::distance(provider.metric, Some(provider.full_vectors.dim())); + let is_not_start_point = if self.filter_start_points { + Some(accessor.is_not_start_point().await?) + } else { + None + }; + let checker = accessor.as_deletion_check(); - if let Some(filter) = is_not_start_point.as_ref() - && !filter(n.id) - { - return None; - } + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + return None; + } - #[allow(clippy::expect_used)] - let vec = provider - .full_vectors - .get_vector_sync(n.id.into_usize()) - .expect("Full vector provider failed to retrieve element"); - Some((n.id, f.evaluate_similarity(query, &vec))) - }) - .collect(); + if let Some(filter) = is_not_start_point.as_ref() + && !filter(n.id) + { + return None; + } - reranked.sort_unstable_by(|a, b| { - (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) - }); - Ok(output.extend(reranked)) - } + #[allow(clippy::expect_used)] + let vec = provider + .full_vectors + .get_vector_sync(n.id.into_usize()) + .expect("Full vector provider failed to retrieve element"); + Some((n.id, f.evaluate_similarity(query, &vec))) + }) + .collect(); + + reranked.sort_unstable_by(|a, b| { + (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) + }); + Ok(output.extend(reranked)) } } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 25e0797de..5170a20a1 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -399,56 +399,53 @@ where { type Error = ANNError; - #[allow(clippy::manual_async_fn)] - fn post_process( + async fn post_process( &self, accessor: &mut A, query: &[T], _computer: &A::QueryComputer, candidates: I, output: &mut B, - ) -> impl Future> + Send + ) -> Result where I: Iterator> + Send, B: SearchOutputBuffer + Send + ?Sized, { - async move { - let full = accessor.as_full_precision(); - let f = full.distance(); - let is_not_start_point = if self.filter_start_points { - Some(accessor.is_not_start_point().await?) - } else { - None - }; - let checker = accessor.as_deletion_check(); - - let mut reranked: Vec<(u32, f32)> = candidates - .filter_map(|n| { - if checker.deletion_check(n.id) { - return None; - } - - if let Some(filter) = is_not_start_point.as_ref() - && !filter(n.id) - { - return None; - } - - Some(( - n.id, - f.evaluate_similarity(query, unsafe { - full.get_vector_sync(n.id.into_usize()) - }), - )) - }) - .collect(); - - reranked.sort_unstable_by(|a, b| { - (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) - }); - - Ok(output.extend(reranked)) - } + let full = accessor.as_full_precision(); + let f = full.distance(); + let is_not_start_point = if self.filter_start_points { + Some(accessor.is_not_start_point().await?) + } else { + None + }; + let checker = accessor.as_deletion_check(); + + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if checker.deletion_check(n.id) { + return None; + } + + if let Some(filter) = is_not_start_point.as_ref() + && !filter(n.id) + { + return None; + } + + Some(( + n.id, + f.evaluate_similarity(query, unsafe { + full.get_vector_sync(n.id.into_usize()) + }), + )) + }) + .collect(); + + reranked.sort_unstable_by(|a, b| { + (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) + }); + + Ok(output.extend(reranked)) } } diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index e644a3117..14fac7e63 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -52,46 +52,43 @@ impl glue::SearchPostProcess for RemoveDeletedIdsAndCopy where A: BuildQueryComputer + AsDeletionCheck + glue::SearchExt, ::Checker: Sync, - T: ?Sized, + T: ?Sized + Sync, { type Error = ANNError; - #[allow(clippy::manual_async_fn)] - fn post_process( + async fn post_process( &self, accessor: &mut A, _query: &T, _computer: &>::QueryComputer, candidates: I, output: &mut B, - ) -> impl std::future::Future> + Send + ) -> Result where I: Iterator> + Send, B: SearchOutputBuffer + Send + ?Sized, { - async move { - let is_not_start_point = if self.filter_start_points { - Some(accessor.is_not_start_point().await?) - } else { - None - }; - - let checker = accessor.as_deletion_check(); - let filtered = candidates.filter_map(|n| { - if checker.deletion_check(n.id) { - None - } else { - Some((n.id, n.distance)) - } - }); + let is_not_start_point = if self.filter_start_points { + Some(accessor.is_not_start_point().await?) + } else { + None + }; - let count = if let Some(filter) = is_not_start_point { - output.extend(filtered.filter(|(id, _)| filter(*id))) + let checker = accessor.as_deletion_check(); + let filtered = candidates.filter_map(|n| { + if checker.deletion_check(n.id) { + None } else { - output.extend(filtered) - }; + Some((n.id, n.distance)) + } + }); - Ok(count) - } + let count = if let Some(filter) = is_not_start_point { + output.extend(filtered.filter(|(id, _)| filter(*id))) + } else { + output.extend(filtered) + }; + + Ok(count) } } From 5e2db52c304c3e0fce7391b0bf1ad0fae5e81ba2 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 15:57:42 -0700 Subject: [PATCH 22/47] Fix unused imports from PostProcess removal Remove leftover SearchOutputBuffer, IntoANNResult, and Neighbor imports in debug_provider.rs that were only used by the deleted PostProcess impls. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/model/graph/provider/async_/bf_tree/provider.rs | 5 ++--- .../src/model/graph/provider/async_/debug_provider.rs | 4 +--- .../src/model/graph/provider/async_/inmem/full_precision.rs | 5 ++--- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 6cd02a680..fcc9fb561 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -1557,9 +1557,8 @@ where }) .collect(); - reranked.sort_unstable_by(|a, b| { - (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) - }); + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); Ok(output.extend(reranked)) } } diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs index 763864fd1..d2ee732a2 100644 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs @@ -14,15 +14,13 @@ use std::{ use diskann::has_default_processor; use diskann::{ ANNError, ANNErrorKind, ANNResult, - error::IntoANNResult, graph::{ - AdjacencyList, SearchOutputBuffer, + AdjacencyList, glue::{ AsElement, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, }, - neighbor::Neighbor, provider::{ self, Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultAccessor, DefaultContext, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 5170a20a1..a803fbd1b 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -441,9 +441,8 @@ where }) .collect(); - reranked.sort_unstable_by(|a, b| { - (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) - }); + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); Ok(output.extend(reranked)) } From 349d0b677b69f70bc98c2247cc6cb1bc2c7078d0 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 16:11:11 -0700 Subject: [PATCH 23/47] Remove some more JSONs. --- Cargo.lock | 1 - copilot.txt | 1 - diskann-benchmark-core/Cargo.toml | 1 - ...kipedia_disk_build_and_compare_detdiv.json | 82 ------------------- .../wikipedia_disk_build_baseline.json | 37 --------- .../wikipedia_disk_compare_detdiv.json | 52 ------------ .../wikipedia_disk_load_compare_detdiv.json | 52 ------------ 7 files changed, 226 deletions(-) delete mode 100644 copilot.txt delete mode 100644 diskann-benchmark/example/wikipedia_disk_build_and_compare_detdiv.json delete mode 100644 diskann-benchmark/example/wikipedia_disk_build_baseline.json delete mode 100644 diskann-benchmark/example/wikipedia_disk_compare_detdiv.json delete mode 100644 diskann-benchmark/example/wikipedia_disk_load_compare_detdiv.json diff --git a/Cargo.lock b/Cargo.lock index 8d0c541f9..51b6a0b0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -687,7 +687,6 @@ dependencies = [ "anyhow", "diskann", "diskann-benchmark-runner", - "diskann-providers", "diskann-utils", "diskann-vector", "futures-util", diff --git a/copilot.txt b/copilot.txt deleted file mode 100644 index 722f174ed..000000000 --- a/copilot.txt +++ /dev/null @@ -1 +0,0 @@ -copilot --resume=c5407796-a927-4cca-9c30-d450692d150a diff --git a/diskann-benchmark-core/Cargo.toml b/diskann-benchmark-core/Cargo.toml index b3eabed45..90e64b9e3 100644 --- a/diskann-benchmark-core/Cargo.toml +++ b/diskann-benchmark-core/Cargo.toml @@ -11,7 +11,6 @@ edition = "2024" anyhow.workspace = true diskann.workspace = true diskann-benchmark-runner = { workspace = true } -diskann-providers = { workspace = true } diskann-utils.default-features = false diskann-utils.workspace = true futures-util = { workspace = true, default-features = false } diff --git a/diskann-benchmark/example/wikipedia_disk_build_and_compare_detdiv.json b/diskann-benchmark/example/wikipedia_disk_build_and_compare_detdiv.json deleted file mode 100644 index 166f1edc4..000000000 --- a/diskann-benchmark/example/wikipedia_disk_build_and_compare_detdiv.json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "search_directories": [ - "C:/wikipedia_dataset" - ], - "jobs": [ - { - "type": "disk-index", - "content": { - "source": { - "disk-index-source": "Build", - "data_type": "float32", - "data": "C:/wikipedia_dataset/data.bin", - "distance": "squared_l2", - "dim": 1024, - "max_degree": 64, - "l_build": 100, - "num_threads": 8, - "build_ram_limit_gb": 32.0, - "num_pq_chunks": 128, - "quantization_type": "FP", - "save_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" - }, - "search_phase": { - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "search_list": [20, 30, 40, 50, 100, 200], - "beam_width": 8, - "recall_at": 10, - "num_threads": 8, - "is_flat_search": false, - "distance": "squared_l2", - "vector_filters_file": null - } - } - }, - { - "type": "disk-index", - "content": { - "source": { - "disk-index-source": "Load", - "data_type": "float32", - "load_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" - }, - "search_phase": { - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "search_list": [20, 30, 40, 50, 100, 200], - "beam_width": 8, - "recall_at": 10, - "num_threads": 8, - "is_flat_search": false, - "distance": "squared_l2", - "vector_filters_file": null - } - } - }, - { - "type": "disk-index", - "content": { - "source": { - "disk-index-source": "Load", - "data_type": "float32", - "load_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" - }, - "search_phase": { - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "search_list": [20, 30, 40, 50, 100, 200], - "beam_width": 8, - "recall_at": 10, - "num_threads": 8, - "is_flat_search": false, - "distance": "squared_l2", - "vector_filters_file": null, - "determinant_diversity_eta": 0.01, - "determinant_diversity_power": 1.0, - "determinant_diversity_results_k": 10 - } - } - } - ] -} diff --git a/diskann-benchmark/example/wikipedia_disk_build_baseline.json b/diskann-benchmark/example/wikipedia_disk_build_baseline.json deleted file mode 100644 index a3abf3eb7..000000000 --- a/diskann-benchmark/example/wikipedia_disk_build_baseline.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "search_directories": [ - "C:/wikipedia_dataset" - ], - "jobs": [ - { - "type": "disk-index", - "content": { - "source": { - "disk-index-source": "Build", - "data_type": "float32", - "data": "C:/wikipedia_dataset/data.bin", - "distance": "squared_l2", - "dim": 1024, - "max_degree": 64, - "l_build": 100, - "num_threads": 8, - "build_ram_limit_gb": 32.0, - "num_pq_chunks": 128, - "quantization_type": "FP", - "save_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" - }, - "search_phase": { - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "search_list": [20, 30, 40, 50, 100, 200], - "beam_width": 8, - "recall_at": 10, - "num_threads": 8, - "is_flat_search": false, - "distance": "squared_l2", - "vector_filters_file": null - } - } - } - ] -} diff --git a/diskann-benchmark/example/wikipedia_disk_compare_detdiv.json b/diskann-benchmark/example/wikipedia_disk_compare_detdiv.json deleted file mode 100644 index d3c430a28..000000000 --- a/diskann-benchmark/example/wikipedia_disk_compare_detdiv.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "search_directories": [ - "C:/wikipedia_dataset" - ], - "jobs": [ - { - "type": "disk-index", - "content": { - "source": { - "disk-index-source": "Load", - "data_type": "float32", - "load_path": "C:/wikipedia_dataset/wikipedia_saved_index" - }, - "search_phase": { - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "search_list": [20, 30, 40, 50, 100, 200], - "beam_width": 8, - "recall_at": 10, - "num_threads": 8, - "is_flat_search": false, - "distance": "squared_l2", - "vector_filters_file": null - } - } - }, - { - "type": "disk-index", - "content": { - "source": { - "disk-index-source": "Load", - "data_type": "float32", - "load_path": "C:/wikipedia_dataset/wikipedia_saved_index" - }, - "search_phase": { - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "search_list": [20, 30, 40, 50, 100, 200], - "beam_width": 8, - "recall_at": 10, - "num_threads": 8, - "is_flat_search": false, - "distance": "squared_l2", - "vector_filters_file": null, - "determinant_diversity_eta": 0.01, - "determinant_diversity_power": 1.0, - "determinant_diversity_results_k": 10 - } - } - } - ] -} diff --git a/diskann-benchmark/example/wikipedia_disk_load_compare_detdiv.json b/diskann-benchmark/example/wikipedia_disk_load_compare_detdiv.json deleted file mode 100644 index a7e08c15b..000000000 --- a/diskann-benchmark/example/wikipedia_disk_load_compare_detdiv.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "search_directories": [ - "C:/wikipedia_dataset" - ], - "jobs": [ - { - "type": "disk-index", - "content": { - "source": { - "disk-index-source": "Load", - "data_type": "float32", - "load_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" - }, - "search_phase": { - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "search_list": [20, 30, 40, 50, 100, 200], - "beam_width": 8, - "recall_at": 10, - "num_threads": 8, - "is_flat_search": false, - "distance": "squared_l2", - "vector_filters_file": null - } - } - }, - { - "type": "disk-index", - "content": { - "source": { - "disk-index-source": "Load", - "data_type": "float32", - "load_path": "C:/wikipedia_dataset/wikipedia_saved_disk_index" - }, - "search_phase": { - "queries": "C:/wikipedia_dataset/query.bin", - "groundtruth": "C:/wikipedia_dataset/groundtruth_k100.bin", - "search_list": [20, 30, 40, 50, 100, 200], - "beam_width": 8, - "recall_at": 10, - "num_threads": 8, - "is_flat_search": false, - "distance": "squared_l2", - "vector_filters_file": null, - "determinant_diversity_eta": 0.01, - "determinant_diversity_power": 1.0, - "determinant_diversity_results_k": 10 - } - } - } - ] -} From 749d59bcfbd3ecfea15e2929ecd28b1f876edfcc Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 16:42:47 -0700 Subject: [PATCH 24/47] Replace runtime filter_start_points with type-level Pipeline composition Restore main's structural approach: post-processors (RemoveDeletedIdsAndCopy, Rerank) are clean ZSTs that only handle deletion filtering and reranking. Start-point filtering is composed via Pipeline at the type level: - HasDefaultProcessor returns Pipeline (filters start points during regular search) - InplaceDeleteStrategy returns Base directly (no start-point filtering during delete operations) This eliminates the runtime 'filter_start_points: bool' flag, makes the post-processors synchronous again (no .await needed), and restores their Error types to Infallible/Panics instead of ANNError. Also reverts diskann-benchmark/src/backend/index/search/knn.rs to main. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- copilot.txt | 1 + .../src/backend/disk_index/search.rs | 6 +- .../src/backend/index/search/knn.rs | 52 +++++---------- diskann-providers/src/index/wrapped_async.rs | 5 +- .../graph/provider/async_/bf_tree/provider.rs | 65 ++++++------------- .../graph/provider/async_/debug_provider.rs | 17 ++--- .../provider/async_/inmem/full_precision.rs | 60 +++++------------ .../graph/provider/async_/inmem/product.rs | 8 +-- .../graph/provider/async_/inmem/scalar.rs | 8 +-- .../graph/provider/async_/inmem/spherical.rs | 8 +-- .../graph/provider/async_/postprocess.rs | 45 +++---------- diskann/src/graph/search/mod.rs | 7 +- example.rs | 40 ------------ 13 files changed, 91 insertions(+), 231 deletions(-) create mode 100644 copilot.txt delete mode 100644 example.rs diff --git a/copilot.txt b/copilot.txt new file mode 100644 index 000000000..722f174ed --- /dev/null +++ b/copilot.txt @@ -0,0 +1 @@ +copilot --resume=c5407796-a927-4cca-9c30-d450692d150a diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index f3ad744a9..65e5804a7 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -269,16 +269,14 @@ where as Box bool + Send + Sync>) }; - let search_result = searcher.search( + match searcher.search( q, search_params.recall_at, l, Some(search_params.beam_width), vector_filter, search_params.is_flat_search, - ); - - match search_result { + ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; *rc = search_result.results.len() as u32; diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 5695f4e95..915b8eca6 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -35,40 +35,6 @@ pub(crate) fn run( groundtruth: &dyn benchmark_core::recall::Rows, steps: SearchSteps<'_>, ) -> anyhow::Result> { - run_search(runner, groundtruth, steps, |setup, search_l, search_n| { - let search_params = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); - core_search::Run::new(search_params, setup) - }) -} - -type Run = core_search::Run; -pub(crate) trait Knn { - fn search_all( - &self, - parameters: Vec, - groundtruth: &dyn benchmark_core::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> anyhow::Result>; -} - -/////////// -// Impls // -/////////// - -/// Generic search infrastructure. -/// -/// This helper extracts the common loop logic (iterating over threads and runs, -/// and building a setup) leaving parameter construction to a builder closure. -fn run_search( - runner: &dyn Knn, - groundtruth: &dyn benchmark_core::recall::Rows, - steps: SearchSteps<'_>, - builder: F, -) -> anyhow::Result> -where - F: Fn(core_search::Setup, usize, usize) -> core_search::Run, -{ let mut all = Vec::new(); for threads in steps.num_tasks.iter() { @@ -82,7 +48,12 @@ where let parameters: Vec<_> = run .search_l .iter() - .map(|&search_l| builder(setup.clone(), search_l, run.search_n)) + .map(|search_l| { + let search_params = + diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); + + core_search::Run::new(search_params, setup.clone()) + }) .collect(); all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); @@ -92,6 +63,17 @@ where Ok(all) } +type Run = core_search::Run; +pub(crate) trait Knn { + fn search_all( + &self, + parameters: Vec, + groundtruth: &dyn benchmark_core::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> anyhow::Result>; +} + /////////// // Impls // /////////// diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 3e8a68377..c972df2de 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -10,7 +10,8 @@ use diskann::{ graph::{ self, ConsolidateKind, InplaceDeleteMethod, glue::{ - self, AsElement, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchStrategy, + self, AsElement, DefaultSearchStrategy, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchStrategy, }, index::{DegreeStats, PartitionedNeighbors, SearchState, SearchStats}, search::Knn, @@ -232,7 +233,7 @@ where ) -> ANNResult where T: Sync + ?Sized, - S: glue::DefaultSearchStrategy, + S: DefaultSearchStrategy, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index fcc9fb561..4a8ca44fc 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -16,7 +16,6 @@ use std::{ use serde::{Deserialize, Serialize}; use bf_tree::{BfTree, Config}; -use diskann::has_default_processor; use diskann::{ ANNError, ANNResult, error::IntoANNResult, @@ -27,6 +26,7 @@ use diskann::{ PruneStrategy, SearchExt, SearchStrategy, }, }, + has_default_processor, neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext, @@ -1491,75 +1491,54 @@ where Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, { - has_default_processor!(RemoveDeletedIdsAndCopy); + has_default_processor!(glue::Pipeline); } /// An [`glue::SearchPostProcess`] implementation that reranks PQ vectors. -#[derive(Debug, Clone, Copy)] -pub struct Rerank { - pub filter_start_points: bool, -} - -impl Default for Rerank { - fn default() -> Self { - Self { - filter_start_points: true, - } - } -} +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; impl<'a, T, D> glue::SearchPostProcess, [T]> for Rerank where T: VectorRepr, D: AsyncFriendly + DeletionCheck, { - type Error = ANNError; + type Error = Panics; - async fn post_process( + fn post_process( &self, accessor: &mut QuantAccessor<'a, T, D>, query: &[T], _computer: &pq::distance::QueryComputer>, candidates: I, output: &mut B, - ) -> Result + ) -> impl std::future::Future> + Send where I: Iterator> + Send, B: SearchOutputBuffer + Send + ?Sized, { let provider = &accessor.provider; - let f = T::distance(provider.metric, Some(provider.full_vectors.dim())); - let is_not_start_point = if self.filter_start_points { - Some(accessor.is_not_start_point().await?) - } else { - None - }; let checker = accessor.as_deletion_check(); + let f = T::distance(provider.metric, Some(provider.full_vectors.dim())); let mut reranked: Vec<(u32, f32)> = candidates .filter_map(|n| { if checker.deletion_check(n.id) { - return None; - } - - if let Some(filter) = is_not_start_point.as_ref() - && !filter(n.id) - { - return None; + None + } else { + #[allow(clippy::expect_used)] + let vec = provider + .full_vectors + .get_vector_sync(n.id.into_usize()) + .expect("Full vector provider failed to retrieve element"); + Some((n.id, f.evaluate_similarity(query, &vec))) } - - #[allow(clippy::expect_used)] - let vec = provider - .full_vectors - .get_vector_sync(n.id.into_usize()) - .expect("Full vector provider failed to retrieve element"); - Some((n.id, f.evaluate_similarity(query, &vec))) }) .collect(); reranked .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - Ok(output.extend(reranked)) + std::future::ready(Ok(output.extend(reranked))) } } @@ -1590,7 +1569,7 @@ where T: VectorRepr, D: AsyncFriendly + DeletionCheck, { - has_default_processor!(Rerank); + has_default_processor!(glue::Pipeline); } // Pruning @@ -1712,9 +1691,7 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - RemoveDeletedIdsAndCopy { - filter_start_points: false, - } + RemoveDeletedIdsAndCopy } async fn get_delete_element<'a>( @@ -1753,9 +1730,7 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - Rerank { - filter_start_points: false, - } + Rerank } async fn get_delete_element<'a>( diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs index d2ee732a2..d75b2afdb 100644 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs @@ -17,8 +17,9 @@ use diskann::{ graph::{ AdjacencyList, glue::{ - AsElement, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, - InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, + AsElement, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, + InplaceDeleteStrategy, InsertStrategy, Pipeline, PruneStrategy, SearchExt, + SearchStrategy, }, }, provider::{ @@ -902,7 +903,7 @@ impl SearchStrategy for FullPrecision { } impl HasDefaultProcessor for FullPrecision { - has_default_processor!(postprocess::RemoveDeletedIdsAndCopy); + has_default_processor!(Pipeline); } impl SearchStrategy for Quantized { @@ -920,7 +921,7 @@ impl SearchStrategy for Quantized { } impl HasDefaultProcessor for Quantized { - has_default_processor!(postprocess::RemoveDeletedIdsAndCopy); + has_default_processor!(Pipeline); } impl PruneStrategy for FullPrecision { @@ -1024,9 +1025,7 @@ impl InplaceDeleteStrategy for FullPrecision { } fn search_post_processor(&self) -> Self::SearchPostProcessor { - postprocess::RemoveDeletedIdsAndCopy { - filter_start_points: false, - } + postprocess::RemoveDeletedIdsAndCopy } fn get_delete_element<'a>( @@ -1057,9 +1056,7 @@ impl InplaceDeleteStrategy for Quantized { } fn search_post_processor(&self) -> Self::SearchPostProcessor { - postprocess::RemoveDeletedIdsAndCopy { - filter_start_points: false, - } + postprocess::RemoveDeletedIdsAndCopy } fn get_delete_element<'a>( diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index a803fbd1b..87679154d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -378,73 +378,51 @@ pub trait GetFullPrecision { /// 1. Filters out deleted ids from being returned. /// 2. Reranks a candidate stream using full-precision distances. /// 3. Copies back the results to the output buffer. -#[derive(Debug, Clone, Copy)] -pub struct Rerank { - pub filter_start_points: bool, -} - -impl Default for Rerank { - fn default() -> Self { - Self { - filter_start_points: true, - } - } -} +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; impl glue::SearchPostProcess for Rerank where T: VectorRepr, - A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck + SearchExt, - ::Checker: Sync, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, { - type Error = ANNError; + type Error = Panics; - async fn post_process( + fn post_process( &self, accessor: &mut A, query: &[T], _computer: &A::QueryComputer, candidates: I, output: &mut B, - ) -> Result + ) -> impl std::future::Future> + Send where I: Iterator> + Send, B: SearchOutputBuffer + Send + ?Sized, { let full = accessor.as_full_precision(); - let f = full.distance(); - let is_not_start_point = if self.filter_start_points { - Some(accessor.is_not_start_point().await?) - } else { - None - }; let checker = accessor.as_deletion_check(); + let f = full.distance(); let mut reranked: Vec<(u32, f32)> = candidates .filter_map(|n| { if checker.deletion_check(n.id) { - return None; - } - - if let Some(filter) = is_not_start_point.as_ref() - && !filter(n.id) - { - return None; + None + } else { + Some(( + n.id, + f.evaluate_similarity(query, unsafe { + full.get_vector_sync(n.id.into_usize()) + }), + )) } - - Some(( - n.id, - f.evaluate_similarity(query, unsafe { - full.get_vector_sync(n.id.into_usize()) - }), - )) }) .collect(); reranked .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); - Ok(output.extend(reranked)) + std::future::ready(Ok(output.extend(reranked))) } } @@ -482,7 +460,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(RemoveDeletedIdsAndCopy); + has_default_processor!(glue::Pipeline); } // Pruning @@ -560,9 +538,7 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - RemoveDeletedIdsAndCopy { - filter_start_points: false, - } + RemoveDeletedIdsAndCopy } async fn get_delete_element<'a>( diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index d5aed4c7a..1c8bf7ebd 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -489,7 +489,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(Rerank); + has_default_processor!(glue::Pipeline); } impl PruneStrategy> for Hybrid @@ -565,9 +565,7 @@ where } fn search_post_processor(&self) -> Self::SearchPostProcessor { - Rerank { - filter_start_points: false, - } + Rerank } async fn get_delete_element<'a>( @@ -613,7 +611,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(RemoveDeletedIdsAndCopy); + has_default_processor!(glue::Pipeline); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 5e6f2aa18..e82f70cb7 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -10,8 +10,8 @@ use diskann::has_default_processor; use diskann::{ ANNError, ANNResult, graph::glue::{ - ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, PruneStrategy, SearchExt, - SearchStrategy, + ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InsertStrategy, Pipeline, + PruneStrategy, SearchExt, SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, @@ -632,7 +632,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - has_default_processor!(Rerank); + has_default_processor!(Pipeline); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -669,7 +669,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - has_default_processor!(RemoveDeletedIdsAndCopy); + has_default_processor!(Pipeline); } impl PruneStrategy, D, Ctx>> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index bba8ffa2d..9bc43d7d4 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -12,8 +12,8 @@ use diskann::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::glue::{ - ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, PruneStrategy, SearchExt, - SearchStrategy, + ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InsertStrategy, Pipeline, + PruneStrategy, SearchExt, SearchStrategy, }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, @@ -579,7 +579,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(Rerank); + has_default_processor!(Pipeline); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -612,7 +612,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(RemoveDeletedIdsAndCopy); + has_default_processor!(Pipeline); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index 14fac7e63..bf6a47bba 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -6,7 +6,6 @@ //! Shared search post-processing. use diskann::{ - ANNError, graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, provider::BuildQueryComputer, @@ -35,60 +34,36 @@ pub(crate) trait DeletionCheck { /// A [`SearchPostProcess`] routine that fuses the removal of deleted elements with the /// copying of IDs into an output buffer. -#[derive(Debug, Clone, Copy)] -pub struct RemoveDeletedIdsAndCopy { - pub filter_start_points: bool, -} - -impl Default for RemoveDeletedIdsAndCopy { - fn default() -> Self { - Self { - filter_start_points: true, - } - } -} +#[derive(Debug, Clone, Copy, Default)] +pub struct RemoveDeletedIdsAndCopy; impl glue::SearchPostProcess for RemoveDeletedIdsAndCopy where - A: BuildQueryComputer + AsDeletionCheck + glue::SearchExt, - ::Checker: Sync, - T: ?Sized + Sync, + A: BuildQueryComputer + AsDeletionCheck, + T: ?Sized, { - type Error = ANNError; + type Error = std::convert::Infallible; - async fn post_process( + fn post_process( &self, accessor: &mut A, _query: &T, _computer: &>::QueryComputer, candidates: I, output: &mut B, - ) -> Result + ) -> impl std::future::Future> + Send where I: Iterator> + Send, B: SearchOutputBuffer + Send + ?Sized, { - let is_not_start_point = if self.filter_start_points { - Some(accessor.is_not_start_point().await?) - } else { - None - }; - let checker = accessor.as_deletion_check(); - let filtered = candidates.filter_map(|n| { + let count = output.extend(candidates.filter_map(|n| { if checker.deletion_check(n.id) { None } else { Some((n.id, n.distance)) } - }); - - let count = if let Some(filter) = is_not_start_point { - output.extend(filtered.filter(|(id, _)| filter(*id))) - } else { - output.extend(filtered) - }; - - Ok(count) + })); + std::future::ready(Ok(count)) } } diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index d43bc77b0..a50104ea5 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -89,11 +89,8 @@ where output: &mut OB, ) -> impl SendFuture> where - PP: for<'a> crate::graph::glue::SearchPostProcess< - >::SearchAccessor<'a>, - T, - O, - > + Send + PP: for<'a> crate::graph::glue::SearchPostProcess, T, O> + + Send + Sync, OB: crate::graph::search_output_buffer::SearchOutputBuffer + Send + ?Sized; } diff --git a/example.rs b/example.rs deleted file mode 100644 index dccc80441..000000000 --- a/example.rs +++ /dev/null @@ -1,40 +0,0 @@ -// Default post-process -#[derive(Debug, Clone, Copy)] -pub struct DefaultPostProcess; - -pub trait DelegatePostProcess { - type Delegate: DoesThings; -} - -impl SearchPostProcess for DefaultPostProcess -where - T: DelegatePostProcess -{ - fn post_process(args...) { - T::Delegate::post_process(args...) - } -} - -// Apply the default post-process via the normal search API. -fn search( - dispatch: T, - other_args... -) -where - DefaultPostProcess: SearchPostProcess -{ - search_with(dispatch, other_args..., DefaultPostProcess) -} - -// Second API that allows for overriding the post-processor explicitly. -fn search_with( - dispatch: T, - other_args... - post_process: P -) -where - P: SearchPostProcess -{ - // Do the thing. The `Search` trait will always take a post-processor. -} - From 008c09b5d8f36d3f2f21655ea0898ad74ed45276 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 16:44:44 -0700 Subject: [PATCH 25/47] Remove misc files. --- copilot.txt | 1 - post_process_design_sketch.rs | 418 ---------------------------------- 2 files changed, 419 deletions(-) delete mode 100644 copilot.txt delete mode 100644 post_process_design_sketch.rs diff --git a/copilot.txt b/copilot.txt deleted file mode 100644 index 722f174ed..000000000 --- a/copilot.txt +++ /dev/null @@ -1 +0,0 @@ -copilot --resume=c5407796-a927-4cca-9c30-d450692d150a diff --git a/post_process_design_sketch.rs b/post_process_design_sketch.rs deleted file mode 100644 index 994bff558..000000000 --- a/post_process_design_sketch.rs +++ /dev/null @@ -1,418 +0,0 @@ -// ============================================================================= -// Post-Processing Redesign: Sketch & Rationale -// ============================================================================= -// -// Context -// ------- -// Two competing PRs attempted to refactor how SearchStrategy interacts with -// post-processing. Both had structural problems: -// -// Exhibit-A kept `type PostProcessor` on SearchStrategy and layered a new -// `PostProcess` trait on top. This created two -// parallel "what's the post-processor?" answers on the same type that could -// silently diverge. The GAT associated type became dead weight that every -// implementor still had to fill in. -// -// Exhibit-B removed `PostProcessor` from SearchStrategy (good), but replaced -// it with a `DelegatePostProcess` marker whose blanket impl covered *all* -// processor types `P` at once: -// -// impl PostProcess for S -// where S: SearchStrategy<…> + DelegatePostProcess, -// P: for<'a> SearchPostProcess, T, O> + … -// -// This makes it impossible to override `PostProcess` for a specific `P` -// without opting out of the blanket entirely (removing DelegatePostProcess), -// which then forces manual impls for every processor type — an all-or-nothing -// cliff. It also provided no `KnnWith`-style mechanism for callers to supply -// a custom processor at the search call-site. -// -// Proposed Design -// --------------- -// Flip the blanket. Instead of "strategy S gets PostProcess for all P", -// make it "the DefaultPostProcess ZST gets support for all strategies S -// that opt in via HasDefaultProcessor". -// -// The blanket is narrow (covers exactly one P = DefaultPostProcess), so custom -// PostProcess<…, RagSearchParams, …> impls are coherence-safe. Strategies -// that don't need a default can skip HasDefaultProcessor and still be used via -// KnnWith with an explicit processor. -// -// ============================================================================= -// -// How to read this file -// --------------------- -// This is pseudocode — it won't compile. Signatures use real Rust syntax where -// possible but elide lifetimes, bounds, and async machinery for clarity. -// Comments marked "NOTE" call out places where the real implementation will -// need careful attention to HRTB / GAT interactions. -// -// ============================================================================= - -// --------------------------------------------------------------------------- -// 1. SearchStrategy — clean, no post-processing knowledge -// --------------------------------------------------------------------------- -// -// This is the same as today minus `type PostProcessor` and `fn post_processor`. - -pub trait SearchStrategy::InternalId>: - Send + Sync -where - Provider: DataProvider, - T: ?Sized, - O: Send, -{ - type QueryComputer: /* PreprocessedDistanceFunction bounds */ Send + Sync + 'static; - type SearchAccessorError: StandardError; - - // NOTE: This GAT is the source of most HRTB complexity downstream. - type SearchAccessor<'a>: ExpandBeam - + SearchExt; - - fn search_accessor<'a>( - &'a self, - provider: &'a Provider, - context: &'a Provider::Context, - ) -> Result, Self::SearchAccessorError>; -} - -// --------------------------------------------------------------------------- -// 2. SearchPostProcess — unchanged from today -// --------------------------------------------------------------------------- -// -// Low-level trait, parameterized by the *accessor* (not the strategy). -// CopyIds, Rerank, Pipeline, RemoveDeletedIdsAndCopy, etc. all -// implement this directly. No changes needed here. - -pub trait SearchPostProcess::Id> -where - A: BuildQueryComputer, - T: ?Sized, -{ - type Error: StandardError; - - fn post_process( - &self, - accessor: &mut A, - query: &T, - computer: &>::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized; -} - -// Pipeline, CopyIds, FilterStartPoints, SearchPostProcessStep — all unchanged. - -// --------------------------------------------------------------------------- -// 3. PostProcess — strategy-level bridge, parameterized by processor P -// --------------------------------------------------------------------------- -// -// This trait connects a strategy to a specific processor type. It is the -// surface that the search infrastructure (Knn, KnnWith, RecordedKnn, etc.) -// bounds on. - -pub trait PostProcess::InternalId>: - SearchStrategy -where - Provider: DataProvider, - T: ?Sized, - O: Send, - P: Send + Sync, -{ - fn post_process_with<'a, I, B>( - &self, - processor: &P, - accessor: &mut Self::SearchAccessor<'a>, - query: &T, - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized; -} - -// --------------------------------------------------------------------------- -// 4. HasDefaultProcessor — opt-in "I have a default post-processor" -// --------------------------------------------------------------------------- -// -// Strategies that want to work with Knn (no explicit processor) implement this. -// It replaces the old `type PostProcessor` on SearchStrategy. -// -// NOTE: The `for<'a> SearchPostProcess, T, O>` HRTB -// bound is the same one that lived on SearchStrategy::PostProcessor today. -// It's not new complexity — it just moved here. - -pub trait HasDefaultProcessor::InternalId>: - SearchStrategy -where - Provider: DataProvider, - T: ?Sized, - O: Send, -{ - type Processor: for<'a> SearchPostProcess, T, O> - + Send - + Sync; - - fn create_processor(&self) -> Self::Processor; -} - -// Convenience macro (same idea as exhibit-B's has_default_processor!). -macro_rules! has_default_processor { - ($Processor:ty) => { - type Processor = $Processor; - fn create_processor(&self) -> Self::Processor { - Default::default() - } - }; -} - -// --------------------------------------------------------------------------- -// 5. DefaultPostProcess ZST + THE blanket impl -// --------------------------------------------------------------------------- -// -// KEY DESIGN POINT: The blanket covers exactly P = DefaultPostProcess. -// Custom processor types (RagSearchParams, etc.) are free to have their own -// `impl PostProcess<…, RagSearchParams, …> for MyStrategy` without any -// coherence conflict. - -#[derive(Debug, Default, Clone, Copy)] -pub struct DefaultPostProcess; - -impl PostProcess for S -where - S: HasDefaultProcessor, - Provider: DataProvider, - T: ?Sized + Sync, - O: Send, -{ - async fn post_process_with<'a, I, B>( - &self, - _processor: &DefaultPostProcess, - accessor: &mut Self::SearchAccessor<'a>, - query: &T, - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> ANNResult - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - self.create_processor() - .post_process(accessor, query, computer, candidates, output) - .await - .into_ann_result() - } -} - -// --------------------------------------------------------------------------- -// 6. Search API split: Knn vs KnnWith -// --------------------------------------------------------------------------- -// -// Knn uses the default processor. KnnWith allows an explicit override. -// Both delegate to a shared `search_core` that is parameterized over PP. - -impl Knn { - /// Shared core — the only axis of variation is the processor. - async fn search_core( - &self, - index: &DiskANNIndex, - strategy: &S, - /* … */ - post_processor: &PP, - ) -> ANNResult - where - S: PostProcess, - PP: Send + Sync, - /* … */ - { - let mut accessor = strategy.search_accessor(/* … */)?; - let computer = accessor.build_query_computer(query)?; - /* … search_internal … */ - let count = strategy - .post_process_with(post_processor, &mut accessor, query, &computer, candidates, output) - .await?; - Ok(stats.finish(count as u32)) - } -} - -// Knn: uses DefaultPostProcess -impl Search for Knn -where - S: PostProcess, - // equivalently: S: HasDefaultProcessor -{ - fn search(self, /* … */) -> impl SendFuture> { - async move { - self.search_core(/* … */, &DefaultPostProcess).await - } - } -} - -// KnnWith: uses caller-supplied processor -pub struct KnnWith { - inner: Knn, - post_processor: PP, -} - -impl Search for KnnWith -where - S: PostProcess, - PP: Send + Sync, -{ - fn search(self, /* … */) -> impl SendFuture> { - async move { - self.inner - .search_core(/* … */, &self.post_processor) - .await - } - } -} - -// --------------------------------------------------------------------------- -// 7. Example: implementing a strategy -// --------------------------------------------------------------------------- - -struct MyStrategy { /* … */ } - -impl SearchStrategy for MyStrategy { - type QueryComputer = MyComputer; - type SearchAccessorError = ANNError; - type SearchAccessor<'a> = MyAccessor<'a>; - - fn search_accessor<'a>(/* … */) -> Result, ANNError> { /* … */ } - // No PostProcessor, no post_processor() — clean. -} - -// Opt in to the default: "my default post-processor is CopyIds" -impl HasDefaultProcessor for MyStrategy { - has_default_processor!(CopyIds); -} -// That's it — Knn now works with MyStrategy. - -// Opt in to RAG reranking too (no coherence conflict!): -impl PostProcess for MyStrategy { - async fn post_process_with( - &self, - processor: &RagSearchParams, - accessor: &mut MyAccessor<'_>, - /* … */ - ) -> ANNResult { - // Custom RAG logic here - } -} -// Now `KnnWith::new(knn, rag_params)` also works with MyStrategy. - -// --------------------------------------------------------------------------- -// 8. Decorator strategies (BetaFilter) -// --------------------------------------------------------------------------- -// -// BetaFilter wraps an inner strategy and delegates. The PostProcess<…, P, …> -// impl is generic over P, which is coherence-safe because it's on a concrete -// wrapper type (not a blanket over Self). - -impl PostProcess - for BetaFilter -where - Strategy: PostProcess, - P: Send + Sync, - /* … other bounds … */ -{ - async fn post_process_with( - &self, - processor: &P, - accessor: &mut Self::SearchAccessor<'_>, - /* … */ - ) -> ANNResult { - // Unwrap the layered accessor, delegate to inner strategy - self.strategy - .post_process_with(processor, &mut accessor.inner, /* … */) - .await - } -} - -impl HasDefaultProcessor - for BetaFilter -where - Strategy: HasDefaultProcessor, - /* … */ -{ - type Processor = Strategy::Processor; - fn create_processor(&self) -> Self::Processor { - self.strategy.create_processor() - } -} - -// --------------------------------------------------------------------------- -// 9. InplaceDeleteStrategy -// --------------------------------------------------------------------------- -// -// The delete-search phase needs exactly one processor type. The associated -// type pins it, and the SearchStrategy bound requires PostProcess for that -// specific type. -// -// NOTE: The double `for<'a>` bound is verbose but unavoidable given the GAT. - -pub trait InplaceDeleteStrategy: Send + Sync + 'static -where - Provider: DataProvider, -{ - type DeleteElement<'a>: Send + Sync + ?Sized; - type DeleteElementGuard: /* … AsyncLower … */ + 'static; - type DeleteElementError: StandardError; - type PruneStrategy: PruneStrategy; - - /// The processor used during the delete-search phase. - type SearchPostProcessor: Send + Sync; - - /// The search strategy, which must support PostProcess with the above processor. - type SearchStrategy: for<'a> SearchStrategy> - + for<'a> PostProcess< - Provider, - Self::DeleteElement<'a>, - Self::SearchPostProcessor, - >; - - fn prune_strategy(&self) -> Self::PruneStrategy; - fn search_strategy(&self) -> Self::SearchStrategy; - fn search_post_processor(&self) -> Self::SearchPostProcessor; - - fn get_delete_element<'a>(/* … */) -> impl Future> + Send; -} - -// --------------------------------------------------------------------------- -// 10. Known pain points for the real implementation -// --------------------------------------------------------------------------- -// -// A. HRTB on HasDefaultProcessor::Processor -// The bound `for<'a> SearchPostProcess, T, O>` -// is the same one that lived on SearchStrategy::PostProcessor before. -// It's not new — it just moved. The has_default_processor! macro -// should absorb this. -// -// B. BetaFilter's generic P delegation -// `impl

PostProcess<…, P, …> for BetaFilter where S: PostProcess<…, P, …>` -// is coherence-safe (concrete wrapper, not a blanket over Self), but verify -// that rustc is happy with the HRTB interaction when SearchAccessor<'a> is -// a layered type (BetaAccessor wrapping the inner accessor). -// -// C. Disk provider (DiskSearchStrategy) -// Today it has PostProcessor = RerankAndFilter. Under the new design: -// - impl HasDefaultProcessor → Processor = RerankAndFilter -// - impl PostProcess<…, RagSearchParams, …> → custom RAG reranking -// These are independent impls with no coherence conflict. -// -// D. Caching provider (CachingAccessor) -// Uses Pipeline today. Same pattern: HasDefaultProcessor -// with Processor = Pipeline. The Pipeline type is just -// another SearchPostProcess impl. -// -// E. The .send() / IntoANNResult bridge -// The blanket impl calls `create_processor().post_process(…).await`. -// The SearchPostProcess::Error needs to be convertible to ANNError. Today -// this is handled via IntoANNResult / .send(). Same pattern applies. From 669dec1aff702a3514adf9d625b486547d9fdc19 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 17:00:43 -0700 Subject: [PATCH 26/47] Restore knn_search.rs to main's structure Remove the unnecessary search_core helper that the PR introduced. Inline the search body back into Knn::search, matching main's structure. The only semantic difference from main is the Option A change: processor is now a method parameter instead of coming from strategy.post_processor(). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- copilot.txt | 1 + diskann/src/graph/search/knn_search.rs | 112 +++---- post_process_design_sketch.rs | 418 +++++++++++++++++++++++++ 3 files changed, 475 insertions(+), 56 deletions(-) create mode 100644 copilot.txt create mode 100644 post_process_design_sketch.rs diff --git a/copilot.txt b/copilot.txt new file mode 100644 index 000000000..722f174ed --- /dev/null +++ b/copilot.txt @@ -0,0 +1 @@ +copilot --resume=c5407796-a927-4cca-9c30-d450692d150a diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 090d8d44b..3af34eae8 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -143,60 +143,6 @@ impl Knn { } } -impl Knn { - /// Shared search core parameterised over the post-processor type. - async fn search_core( - &self, - index: &DiskANNIndex, - strategy: &S, - context: &DP::Context, - query: &T, - output: &mut OB, - post_processor: PP, - ) -> ANNResult - where - DP: DataProvider, - T: Sync + ?Sized, - S: crate::graph::glue::SearchStrategy, - O: Send, - OB: SearchOutputBuffer + Send + ?Sized, - PP: for<'a> SearchPostProcess, T, O> + Send + Sync, - { - let mut accessor = strategy - .search_accessor(&index.data_provider, context) - .into_ann_result()?; - - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; - - let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len()); - - let stats = index - .search_internal( - Some(self.beam_width.get()), - &start_ids, - &mut accessor, - &computer, - &mut scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - let result_count = post_processor - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(self.l_value.get().into_usize()), - output, - ) - .await - .into_ann_result()?; - - Ok(stats.finish(result_count as u32)) - } -} - impl Search for Knn where DP: DataProvider, @@ -206,7 +152,31 @@ where { type Output = SearchStats; - /// Execute the k-NN search on the given index using the default post-processor. + /// Execute the k-NN search on the given index. + /// + /// This method executes a search using the provided `strategy` to access and process elements. + /// It computes the similarity between the query vector and the elements in the index, traversing + /// the graph towards the nearest neighbors according to the search parameters. + /// + /// # Arguments + /// + /// * `index` - The DiskANN index to search. + /// * `strategy` - The search strategy to use for accessing and processing elements. + /// * `processor` - The post-processor to apply to the search results. + /// * `context` - The context to pass through to providers. + /// * `query` - The query vector for which nearest neighbors are sought. + /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. + /// + /// # Returns + /// + /// Returns [`SearchStats`] containing: + /// - The number of distance computations performed. + /// - The number of hops (graph traversal steps). + /// - Timing information for the search operation. + /// + /// # Errors + /// + /// Returns an error if there is a failure accessing elements or computing distances. fn search( self, index: &DiskANNIndex, @@ -221,8 +191,38 @@ where OB: SearchOutputBuffer + Send + ?Sized, { async move { - self.search_core(index, strategy, context, query, output, processor) + let mut accessor = strategy + .search_accessor(&index.data_provider, context) + .into_ann_result()?; + + let computer = accessor.build_query_computer(query).into_ann_result()?; + let start_ids = accessor.starting_points().await?; + + let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len()); + + let stats = index + .search_internal( + Some(self.beam_width.get()), + &start_ids, + &mut accessor, + &computer, + &mut scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + let result_count = processor + .post_process( + &mut accessor, + query, + &computer, + scratch.best.iter().take(self.l_value.get().into_usize()), + output, + ) .await + .into_ann_result()?; + + Ok(stats.finish(result_count as u32)) } } } diff --git a/post_process_design_sketch.rs b/post_process_design_sketch.rs new file mode 100644 index 000000000..994bff558 --- /dev/null +++ b/post_process_design_sketch.rs @@ -0,0 +1,418 @@ +// ============================================================================= +// Post-Processing Redesign: Sketch & Rationale +// ============================================================================= +// +// Context +// ------- +// Two competing PRs attempted to refactor how SearchStrategy interacts with +// post-processing. Both had structural problems: +// +// Exhibit-A kept `type PostProcessor` on SearchStrategy and layered a new +// `PostProcess` trait on top. This created two +// parallel "what's the post-processor?" answers on the same type that could +// silently diverge. The GAT associated type became dead weight that every +// implementor still had to fill in. +// +// Exhibit-B removed `PostProcessor` from SearchStrategy (good), but replaced +// it with a `DelegatePostProcess` marker whose blanket impl covered *all* +// processor types `P` at once: +// +// impl PostProcess for S +// where S: SearchStrategy<…> + DelegatePostProcess, +// P: for<'a> SearchPostProcess, T, O> + … +// +// This makes it impossible to override `PostProcess` for a specific `P` +// without opting out of the blanket entirely (removing DelegatePostProcess), +// which then forces manual impls for every processor type — an all-or-nothing +// cliff. It also provided no `KnnWith`-style mechanism for callers to supply +// a custom processor at the search call-site. +// +// Proposed Design +// --------------- +// Flip the blanket. Instead of "strategy S gets PostProcess for all P", +// make it "the DefaultPostProcess ZST gets support for all strategies S +// that opt in via HasDefaultProcessor". +// +// The blanket is narrow (covers exactly one P = DefaultPostProcess), so custom +// PostProcess<…, RagSearchParams, …> impls are coherence-safe. Strategies +// that don't need a default can skip HasDefaultProcessor and still be used via +// KnnWith with an explicit processor. +// +// ============================================================================= +// +// How to read this file +// --------------------- +// This is pseudocode — it won't compile. Signatures use real Rust syntax where +// possible but elide lifetimes, bounds, and async machinery for clarity. +// Comments marked "NOTE" call out places where the real implementation will +// need careful attention to HRTB / GAT interactions. +// +// ============================================================================= + +// --------------------------------------------------------------------------- +// 1. SearchStrategy — clean, no post-processing knowledge +// --------------------------------------------------------------------------- +// +// This is the same as today minus `type PostProcessor` and `fn post_processor`. + +pub trait SearchStrategy::InternalId>: + Send + Sync +where + Provider: DataProvider, + T: ?Sized, + O: Send, +{ + type QueryComputer: /* PreprocessedDistanceFunction bounds */ Send + Sync + 'static; + type SearchAccessorError: StandardError; + + // NOTE: This GAT is the source of most HRTB complexity downstream. + type SearchAccessor<'a>: ExpandBeam + + SearchExt; + + fn search_accessor<'a>( + &'a self, + provider: &'a Provider, + context: &'a Provider::Context, + ) -> Result, Self::SearchAccessorError>; +} + +// --------------------------------------------------------------------------- +// 2. SearchPostProcess — unchanged from today +// --------------------------------------------------------------------------- +// +// Low-level trait, parameterized by the *accessor* (not the strategy). +// CopyIds, Rerank, Pipeline, RemoveDeletedIdsAndCopy, etc. all +// implement this directly. No changes needed here. + +pub trait SearchPostProcess::Id> +where + A: BuildQueryComputer, + T: ?Sized, +{ + type Error: StandardError; + + fn post_process( + &self, + accessor: &mut A, + query: &T, + computer: &>::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized; +} + +// Pipeline, CopyIds, FilterStartPoints, SearchPostProcessStep — all unchanged. + +// --------------------------------------------------------------------------- +// 3. PostProcess — strategy-level bridge, parameterized by processor P +// --------------------------------------------------------------------------- +// +// This trait connects a strategy to a specific processor type. It is the +// surface that the search infrastructure (Knn, KnnWith, RecordedKnn, etc.) +// bounds on. + +pub trait PostProcess::InternalId>: + SearchStrategy +where + Provider: DataProvider, + T: ?Sized, + O: Send, + P: Send + Sync, +{ + fn post_process_with<'a, I, B>( + &self, + processor: &P, + accessor: &mut Self::SearchAccessor<'a>, + query: &T, + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized; +} + +// --------------------------------------------------------------------------- +// 4. HasDefaultProcessor — opt-in "I have a default post-processor" +// --------------------------------------------------------------------------- +// +// Strategies that want to work with Knn (no explicit processor) implement this. +// It replaces the old `type PostProcessor` on SearchStrategy. +// +// NOTE: The `for<'a> SearchPostProcess, T, O>` HRTB +// bound is the same one that lived on SearchStrategy::PostProcessor today. +// It's not new complexity — it just moved here. + +pub trait HasDefaultProcessor::InternalId>: + SearchStrategy +where + Provider: DataProvider, + T: ?Sized, + O: Send, +{ + type Processor: for<'a> SearchPostProcess, T, O> + + Send + + Sync; + + fn create_processor(&self) -> Self::Processor; +} + +// Convenience macro (same idea as exhibit-B's has_default_processor!). +macro_rules! has_default_processor { + ($Processor:ty) => { + type Processor = $Processor; + fn create_processor(&self) -> Self::Processor { + Default::default() + } + }; +} + +// --------------------------------------------------------------------------- +// 5. DefaultPostProcess ZST + THE blanket impl +// --------------------------------------------------------------------------- +// +// KEY DESIGN POINT: The blanket covers exactly P = DefaultPostProcess. +// Custom processor types (RagSearchParams, etc.) are free to have their own +// `impl PostProcess<…, RagSearchParams, …> for MyStrategy` without any +// coherence conflict. + +#[derive(Debug, Default, Clone, Copy)] +pub struct DefaultPostProcess; + +impl PostProcess for S +where + S: HasDefaultProcessor, + Provider: DataProvider, + T: ?Sized + Sync, + O: Send, +{ + async fn post_process_with<'a, I, B>( + &self, + _processor: &DefaultPostProcess, + accessor: &mut Self::SearchAccessor<'a>, + query: &T, + computer: &Self::QueryComputer, + candidates: I, + output: &mut B, + ) -> ANNResult + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + self.create_processor() + .post_process(accessor, query, computer, candidates, output) + .await + .into_ann_result() + } +} + +// --------------------------------------------------------------------------- +// 6. Search API split: Knn vs KnnWith +// --------------------------------------------------------------------------- +// +// Knn uses the default processor. KnnWith allows an explicit override. +// Both delegate to a shared `search_core` that is parameterized over PP. + +impl Knn { + /// Shared core — the only axis of variation is the processor. + async fn search_core( + &self, + index: &DiskANNIndex, + strategy: &S, + /* … */ + post_processor: &PP, + ) -> ANNResult + where + S: PostProcess, + PP: Send + Sync, + /* … */ + { + let mut accessor = strategy.search_accessor(/* … */)?; + let computer = accessor.build_query_computer(query)?; + /* … search_internal … */ + let count = strategy + .post_process_with(post_processor, &mut accessor, query, &computer, candidates, output) + .await?; + Ok(stats.finish(count as u32)) + } +} + +// Knn: uses DefaultPostProcess +impl Search for Knn +where + S: PostProcess, + // equivalently: S: HasDefaultProcessor +{ + fn search(self, /* … */) -> impl SendFuture> { + async move { + self.search_core(/* … */, &DefaultPostProcess).await + } + } +} + +// KnnWith: uses caller-supplied processor +pub struct KnnWith { + inner: Knn, + post_processor: PP, +} + +impl Search for KnnWith +where + S: PostProcess, + PP: Send + Sync, +{ + fn search(self, /* … */) -> impl SendFuture> { + async move { + self.inner + .search_core(/* … */, &self.post_processor) + .await + } + } +} + +// --------------------------------------------------------------------------- +// 7. Example: implementing a strategy +// --------------------------------------------------------------------------- + +struct MyStrategy { /* … */ } + +impl SearchStrategy for MyStrategy { + type QueryComputer = MyComputer; + type SearchAccessorError = ANNError; + type SearchAccessor<'a> = MyAccessor<'a>; + + fn search_accessor<'a>(/* … */) -> Result, ANNError> { /* … */ } + // No PostProcessor, no post_processor() — clean. +} + +// Opt in to the default: "my default post-processor is CopyIds" +impl HasDefaultProcessor for MyStrategy { + has_default_processor!(CopyIds); +} +// That's it — Knn now works with MyStrategy. + +// Opt in to RAG reranking too (no coherence conflict!): +impl PostProcess for MyStrategy { + async fn post_process_with( + &self, + processor: &RagSearchParams, + accessor: &mut MyAccessor<'_>, + /* … */ + ) -> ANNResult { + // Custom RAG logic here + } +} +// Now `KnnWith::new(knn, rag_params)` also works with MyStrategy. + +// --------------------------------------------------------------------------- +// 8. Decorator strategies (BetaFilter) +// --------------------------------------------------------------------------- +// +// BetaFilter wraps an inner strategy and delegates. The PostProcess<…, P, …> +// impl is generic over P, which is coherence-safe because it's on a concrete +// wrapper type (not a blanket over Self). + +impl PostProcess + for BetaFilter +where + Strategy: PostProcess, + P: Send + Sync, + /* … other bounds … */ +{ + async fn post_process_with( + &self, + processor: &P, + accessor: &mut Self::SearchAccessor<'_>, + /* … */ + ) -> ANNResult { + // Unwrap the layered accessor, delegate to inner strategy + self.strategy + .post_process_with(processor, &mut accessor.inner, /* … */) + .await + } +} + +impl HasDefaultProcessor + for BetaFilter +where + Strategy: HasDefaultProcessor, + /* … */ +{ + type Processor = Strategy::Processor; + fn create_processor(&self) -> Self::Processor { + self.strategy.create_processor() + } +} + +// --------------------------------------------------------------------------- +// 9. InplaceDeleteStrategy +// --------------------------------------------------------------------------- +// +// The delete-search phase needs exactly one processor type. The associated +// type pins it, and the SearchStrategy bound requires PostProcess for that +// specific type. +// +// NOTE: The double `for<'a>` bound is verbose but unavoidable given the GAT. + +pub trait InplaceDeleteStrategy: Send + Sync + 'static +where + Provider: DataProvider, +{ + type DeleteElement<'a>: Send + Sync + ?Sized; + type DeleteElementGuard: /* … AsyncLower … */ + 'static; + type DeleteElementError: StandardError; + type PruneStrategy: PruneStrategy; + + /// The processor used during the delete-search phase. + type SearchPostProcessor: Send + Sync; + + /// The search strategy, which must support PostProcess with the above processor. + type SearchStrategy: for<'a> SearchStrategy> + + for<'a> PostProcess< + Provider, + Self::DeleteElement<'a>, + Self::SearchPostProcessor, + >; + + fn prune_strategy(&self) -> Self::PruneStrategy; + fn search_strategy(&self) -> Self::SearchStrategy; + fn search_post_processor(&self) -> Self::SearchPostProcessor; + + fn get_delete_element<'a>(/* … */) -> impl Future> + Send; +} + +// --------------------------------------------------------------------------- +// 10. Known pain points for the real implementation +// --------------------------------------------------------------------------- +// +// A. HRTB on HasDefaultProcessor::Processor +// The bound `for<'a> SearchPostProcess, T, O>` +// is the same one that lived on SearchStrategy::PostProcessor before. +// It's not new — it just moved. The has_default_processor! macro +// should absorb this. +// +// B. BetaFilter's generic P delegation +// `impl

PostProcess<…, P, …> for BetaFilter where S: PostProcess<…, P, …>` +// is coherence-safe (concrete wrapper, not a blanket over Self), but verify +// that rustc is happy with the HRTB interaction when SearchAccessor<'a> is +// a layered type (BetaAccessor wrapping the inner accessor). +// +// C. Disk provider (DiskSearchStrategy) +// Today it has PostProcessor = RerankAndFilter. Under the new design: +// - impl HasDefaultProcessor → Processor = RerankAndFilter +// - impl PostProcess<…, RagSearchParams, …> → custom RAG reranking +// These are independent impls with no coherence conflict. +// +// D. Caching provider (CachingAccessor) +// Uses Pipeline today. Same pattern: HasDefaultProcessor +// with Processor = Pipeline. The Pipeline type is just +// another SearchPostProcess impl. +// +// E. The .send() / IntoANNResult bridge +// The blanket impl calls `create_processor().post_process(…).await`. +// The SearchPostProcess::Error needs to be convertible to ANNError. Today +// this is handled via IntoANNResult / .send(). Same pattern applies. From 46f05b3d8cb384c5ec47617fa28edae99166fcdc Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 18:25:44 -0700 Subject: [PATCH 27/47] Almost there. --- copilot.txt | 1 - .../src/search/graph/range.rs | 8 +- .../src/search/provider/disk_provider.rs | 2 +- diskann-garnet/src/provider.rs | 6 +- diskann-providers/src/index/diskann_async.rs | 33 ++-- .../graph/provider/async_/caching/provider.rs | 1 + diskann/src/graph/index.rs | 5 +- diskann/src/graph/mod.rs | 2 +- diskann/src/graph/search/mod.rs | 7 +- diskann/src/graph/search/range_search.rs | 185 +++++++++++++----- diskann/src/graph/search_output_buffer.rs | 21 -- diskann/src/neighbor/mod.rs | 52 +++++ 12 files changed, 217 insertions(+), 106 deletions(-) delete mode 100644 copilot.txt diff --git a/copilot.txt b/copilot.txt deleted file mode 100644 index 722f174ed..000000000 --- a/copilot.txt +++ /dev/null @@ -1 +0,0 @@ -copilot --resume=c5407796-a927-4cca-9c30-d450692d150a diff --git a/diskann-benchmark-core/src/search/graph/range.rs b/diskann-benchmark-core/src/search/graph/range.rs index 3f9eb3ed2..2f056a4d3 100644 --- a/diskann-benchmark-core/src/search/graph/range.rs +++ b/diskann-benchmark-core/src/search/graph/range.rs @@ -105,20 +105,16 @@ where { let context = DP::Context::default(); let range_search = *parameters; - let result = self + let _ = self .index .search( range_search, self.strategy.get(index)?, &context, self.queries.row(index), - &mut (), + buffer, ) .await?; - buffer.extend(std::iter::zip( - result.ids.into_iter(), - result.distances.into_iter(), - )); Ok(Metrics {}) } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 870541f0c..70438b00d 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -470,7 +470,7 @@ where let load_ids: Box<[_]> = ids.take(io_limit).collect(); self.ensure_loaded(&load_ids)?; - let mut ids = Vec::new(); + let mut ids: Vec = Vec::new(); for i in load_ids { ids.clear(); ids.extend( diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 00cfb16f6..b7f4d7db6 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -4,7 +4,6 @@ */ use dashmap::DashMap; -use diskann::has_default_processor; use diskann::{ ANNError, ANNErrorKind, ANNResult, graph::{ @@ -15,6 +14,7 @@ use diskann::{ PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, }, }, + has_default_processor, neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DelegateNeighbor, @@ -787,10 +787,6 @@ impl SearchStrategy, [T], u32> for FullPrecisio } } -impl HasDefaultProcessor, [T], u32> for FullPrecision { - has_default_processor!(glue::CopyIds); -} - impl PruneStrategy> for FullPrecision { type PruneAccessor<'a> = FullAccessor<'a, T>; type PruneAccessorError = GarnetProviderError; diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 69942563b..58a499d43 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -2283,23 +2283,27 @@ pub(crate) mod tests { { // Full Precision Search. let range_search = Range::new(starting_l_value, radius).unwrap(); - let result = index - .search(range_search, &FullPrecision, ctx, query, &mut ()) + let mut results: Vec> = Vec::new(); + let _ = index + .search(range_search, &FullPrecision, ctx, query, &mut results) .await .unwrap(); - assert_range_results_exactly_match(q, >, &result.ids, radius, None); + let ids: Vec = results.iter().map(|n| n.id).collect(); + assert_range_results_exactly_match(q, >, &ids, radius, None); } { // Quantized Search let range_search = Range::new(starting_l_value, radius).unwrap(); - let result = index - .search(range_search, &Hybrid::new(None), ctx, query, &mut ()) + let mut results: Vec> = Vec::new(); + let _ = index + .search(range_search, &Hybrid::new(None), ctx, query, &mut results) .await .unwrap(); - assert_range_results_exactly_match(q, >, &result.ids, radius, None); + let ids: Vec = results.iter().map(|n| n.id).collect(); + assert_range_results_exactly_match(q, >, &ids, radius, None); } { @@ -2316,27 +2320,30 @@ pub(crate) mod tests { 1.0, ) .unwrap(); - let result = index - .search(range_search, &FullPrecision, ctx, query, &mut ()) + let mut results: Vec> = Vec::new(); + let _ = index + .search(range_search, &FullPrecision, ctx, query, &mut results) .await .unwrap(); - assert_range_results_exactly_match(q, >, &result.ids, radius, Some(inner_radius)); + let ids: Vec = results.iter().map(|n| n.id).collect(); + assert_range_results_exactly_match(q, >, &ids, radius, Some(inner_radius)); } { // Test with a lower initial beam to trigger more two-round searches // We don't expect results to exactly match here let range_search = Range::new(lower_l_value, radius).unwrap(); - let result = index - .search(range_search, &FullPrecision, ctx, query, &mut ()) + let mut results: Vec> = Vec::new(); + let _ = index + .search(range_search, &FullPrecision, ctx, query, &mut results) .await .unwrap(); // check that ids don't have duplicates let mut ids_set = std::collections::HashSet::new(); - for id in &result.ids { - assert!(ids_set.insert(*id)); + for n in &results { + assert!(ids_set.insert(n.id)); } } } diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index 007307f15..d107cb684 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -1075,6 +1075,7 @@ where DP: DataProvider, S: InplaceDeleteStrategy, Cached: PruneStrategy>, + for<'a> Cached: SearchStrategy, S::DeleteElement<'a>>, C: AsyncFriendly, { type DeleteElement<'a> = S::DeleteElement<'a>; diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index b79b01128..0eaaecd2c 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2138,10 +2138,9 @@ where /// let params = Knn::new(10, 100, None)?; /// let stats = index.search(params, &strategy, &context, &query, &mut output).await?; /// - /// // Range search (note: uses () as output buffer, results in Output type) + /// // Range search (results written to output buffer) /// let params = Range::new(100, 0.5)?; - /// let result = index.search(params, &strategy, &context, &query, &mut ()).await?; - /// // result.ids and result.distances contain the matches + /// let stats = index.search(params, &strategy, &context, &query, &mut output).await?; /// ``` pub fn search( &self, diff --git a/diskann/src/graph/mod.rs b/diskann/src/graph/mod.rs index e88ad844c..04b4668a6 100644 --- a/diskann/src/graph/mod.rs +++ b/diskann/src/graph/mod.rs @@ -31,7 +31,7 @@ pub mod search; // Re-export the Search trait and error/output types only. // Search parameter types (Knn, Range, Diverse, etc.) should be accessed via `graph::search::`. -pub use search::{KnnSearchError, RangeSearchError, RangeSearchOutput, Search}; +pub use search::{KnnSearchError, RangeSearchError, Search}; mod internal; diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index a50104ea5..a7235d7ff 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -20,8 +20,7 @@ //! //! // Range search //! let params = Range::new(100, 0.5)?; -//! let result = index.search(params, &strategy, &context, &query, &mut ()).await?; -//! println!("Found {} points within radius", result.ids.len()); +//! let stats = index.search(params, &strategy, &context, &query, &mut output).await?; //! ``` use diskann_utils::future::SendFuture; @@ -74,7 +73,7 @@ where /// # Returns /// /// Returns `Self::Output` which varies by search type (e.g., [`SearchStats`](super::index::SearchStats) - /// for k-NN, [`RangeSearchOutput`] for range search). + /// for k-NN and range search). /// /// # Errors /// @@ -97,7 +96,7 @@ where pub use knn_search::{Knn, KnnSearchError, RecordedKnn}; pub use multihop_search::MultihopSearch; -pub use range_search::{Range, RangeSearchError, RangeSearchOutput}; +pub use range_search::{Range, RangeSearchError}; // Feature-gated diverse search. #[cfg(feature = "experimental_diversity_search")] diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index f4bc6bb01..3453f9f87 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -23,16 +23,6 @@ use crate::{ utils::IntoUsize, }; -/// Result from a range search operation. -pub struct RangeSearchOutput { - /// Search statistics. - pub stats: SearchStats, - /// IDs of points within the radius. - pub ids: Vec, - /// Distances corresponding to each ID. - pub distances: Vec, -} - /// Error type for [`Range`] parameter validation. #[derive(Debug, Error)] pub enum RangeSearchError { @@ -174,7 +164,7 @@ where T: Sync + ?Sized, O: Send + Default + Clone, { - type Output = RangeSearchOutput; + type Output = SearchStats; fn search( self, @@ -246,62 +236,79 @@ where initial_stats }; - // Post-process results - let mut result_ids: Vec = vec![O::default(); scratch.in_range.len()]; - let mut result_dists: Vec = vec![f32::MAX; scratch.in_range.len()]; - - let mut output_buffer = search_output_buffer::IdDistance::new( - result_ids.as_mut_slice(), - result_dists.as_mut_slice(), - ); + // Post-process results directly into the output buffer, filtering by radius. + let radius = self.radius(); + let inner_radius = self.inner_radius(); + let mut filtered = DistanceFiltered::new(output, |dist| { + if let Some(ir) = inner_radius + && dist <= ir + { + return false; + } + dist < radius + }); - let _ = processor + let result_count = processor .post_process( &mut accessor, query, &computer, scratch.in_range.iter().copied(), - &mut output_buffer, + &mut filtered, ) .await .into_ann_result()?; - // Filter by inner/outer radius - let inner_cutoff = if let Some(inner_radius) = self.inner_radius() { - result_dists - .iter() - .position(|dist| *dist > inner_radius) - .unwrap_or(result_dists.len()) - } else { - 0 - }; + Ok(SearchStats { + cmps: stats.cmps, + hops: stats.hops, + result_count: result_count as u32, + range_search_second_round: stats.range_search_second_round, + }) + } + } +} - let outer_cutoff = result_dists - .iter() - .position(|dist| *dist > self.radius()) - .unwrap_or(result_dists.len()); +/// A [`SearchOutputBuffer`] wrapper that filters results by distance before +/// forwarding them to an inner buffer. +struct DistanceFiltered<'a, F, B: ?Sized> { + predicate: F, + inner: &'a mut B, +} - result_ids.truncate(outer_cutoff); - result_ids.drain(0..inner_cutoff); +impl<'a, F, B: ?Sized> DistanceFiltered<'a, F, B> { + fn new(inner: &'a mut B, predicate: F) -> Self { + Self { predicate, inner } + } +} - result_dists.truncate(outer_cutoff); - result_dists.drain(0..inner_cutoff); +impl SearchOutputBuffer for DistanceFiltered<'_, F, B> +where + F: FnMut(f32) -> bool, + B: SearchOutputBuffer + ?Sized, +{ + fn size_hint(&self) -> Option { + self.inner.size_hint() + } - let result_count = result_ids.len(); + fn push(&mut self, id: I, distance: f32) -> search_output_buffer::BufferState { + if (self.predicate)(distance) { + self.inner.push(id, distance) + } else { + search_output_buffer::BufferState::Available + } + } - let _ = output.extend(result_ids.iter().cloned().zip(result_dists.iter().copied())); + fn current_len(&self) -> usize { + self.inner.current_len() + } - Ok(RangeSearchOutput { - stats: SearchStats { - cmps: stats.cmps, - hops: stats.hops, - result_count: result_count as u32, - range_search_second_round: stats.range_search_second_round, - }, - ids: result_ids, - distances: result_dists, - }) - } + fn extend(&mut self, itr: Itr) -> usize + where + Itr: IntoIterator, + { + self.inner + .extend(itr.into_iter().filter(|(_, dist)| (self.predicate)(*dist))) } } @@ -384,6 +391,8 @@ where #[cfg(test)] mod tests { use super::*; + use crate::graph::search_output_buffer::BufferState; + use crate::neighbor::Neighbor; #[test] fn test_range_search_validation() { @@ -400,4 +409,78 @@ mod tests { // Invalid inner radius > radius assert!(Range::with_options(None, 100, None, 0.5, Some(1.0), 1.0, 1.0).is_err()); } + + #[test] + fn distance_filtered_push_accepts_passing_items() { + let mut inner: Vec> = Vec::new(); + let mut filtered = DistanceFiltered::new(&mut inner, |d| d < 1.0); + + assert_eq!(filtered.push(1, 0.5), BufferState::Available); + assert_eq!(filtered.current_len(), 1); + assert_eq!(inner[0].id, 1); + assert_eq!(inner[0].distance, 0.5); + } + + #[test] + fn distance_filtered_push_rejects_failing_items() { + let mut inner: Vec> = Vec::new(); + let mut filtered = DistanceFiltered::new(&mut inner, |d| d < 1.0); + + assert_eq!(filtered.push(1, 1.5), BufferState::Available); + assert_eq!(filtered.current_len(), 0); + } + + #[test] + fn distance_filtered_extend_filters_correctly() { + let mut inner: Vec> = Vec::new(); + let mut filtered = DistanceFiltered::new(&mut inner, |d| d < 1.0); + assert!(filtered.size_hint().is_none()); + + let items = vec![(1u32, 0.3), (2, 1.5), (3, 0.7), (4, 2.0), (5, 0.9)]; + let count = filtered.extend(items); + + assert_eq!(count, 3); + assert_eq!(inner.len(), 3); + assert_eq!(inner[0].id, 1); + assert_eq!(inner[1].id, 3); + assert_eq!(inner[2].id, 5); + } + + #[test] + fn distance_filtered_respects_inner_capacity() { + let mut ids = [0u32; 2]; + let mut dists = [0.0f32; 2]; + let mut inner = search_output_buffer::IdDistance::new(&mut ids, &mut dists); + let mut filtered = DistanceFiltered::new(&mut inner, |d| d < 1.0); + assert_eq!(filtered.size_hint(), Some(2)); + + let items = vec![(1u32, 0.1), (2, 0.2), (3, 0.3)]; + let count = filtered.extend(items); + + assert_eq!(count, 2); + assert_eq!(ids, [1, 2]); + } + + #[test] + fn distance_filtered_inner_radius_pattern() { + let mut inner: Vec> = Vec::new(); + let radius = 1.0f32; + let inner_radius = Some(0.3f32); + let mut filtered = DistanceFiltered::new(&mut inner, |dist| { + if let Some(ir) = inner_radius + && dist <= ir + { + return false; + } + dist < radius + }); + + let items = vec![(1u32, 0.1), (2, 0.5), (3, 0.3), (4, 1.0), (5, 0.8)]; + let count = filtered.extend(items); + + // 0.1 and 0.3 are <= inner_radius, 1.0 is not < radius + assert_eq!(count, 2); + assert_eq!(inner[0].id, 2); + assert_eq!(inner[1].id, 5); + } } diff --git a/diskann/src/graph/search_output_buffer.rs b/diskann/src/graph/search_output_buffer.rs index 8d3469ea8..fddd7c73b 100644 --- a/diskann/src/graph/search_output_buffer.rs +++ b/diskann/src/graph/search_output_buffer.rs @@ -36,27 +36,6 @@ pub trait SearchOutputBuffer { Itr: IntoIterator; } -impl SearchOutputBuffer for () { - fn size_hint(&self) -> Option { - None - } - - fn push(&mut self, _id: I, _distance: D) -> BufferState { - BufferState::Available - } - - fn current_len(&self) -> usize { - 0 - } - - fn extend(&mut self, _itr: Itr) -> usize - where - Itr: IntoIterator, - { - 0 - } -} - /// Indicate whether future calls to [`SearchOutputBuffer::push`] will succeed or not. #[derive(Debug, Clone, Copy, PartialEq)] #[must_use = "This type indicates whether the output buffer is full or not."] diff --git a/diskann/src/neighbor/mod.rs b/diskann/src/neighbor/mod.rs index 29ee87981..f9f769fc4 100644 --- a/diskann/src/neighbor/mod.rs +++ b/diskann/src/neighbor/mod.rs @@ -195,6 +195,36 @@ where } } +impl SearchOutputBuffer for Vec> +where + I: Default + Eq, +{ + fn size_hint(&self) -> Option { + None + } + + fn push(&mut self, id: I, distance: f32) -> search_output_buffer::BufferState { + self.push(Neighbor::new(id, distance)); + search_output_buffer::BufferState::Available + } + + fn current_len(&self) -> usize { + self.len() + } + + fn extend(&mut self, itr: Itr) -> usize + where + Itr: IntoIterator, + { + let before = self.len(); + Extend::extend( + self, + itr.into_iter().map(|(id, dist)| Neighbor::new(id, dist)), + ); + self.len() - before + } +} + #[cfg(test)] mod neighbor_test { use super::*; @@ -346,4 +376,26 @@ mod neighbor_test { assert_eq!(&buffer, &[f(1), f(2), f(3), f(4), f(5)]); } } + + #[test] + fn test_vec_neighbor_search_output_buffer() { + use crate::graph::search_output_buffer::SearchOutputBuffer; + + let mut buf: Vec> = Vec::new(); + assert_eq!(SearchOutputBuffer::::size_hint(&buf), None); + assert_eq!(SearchOutputBuffer::::current_len(&buf), 0); + + // push grows unboundedly + assert!(SearchOutputBuffer::push(&mut buf, 1, 0.5).is_available()); + assert!(SearchOutputBuffer::push(&mut buf, 2, 1.0).is_available()); + assert_eq!(SearchOutputBuffer::::current_len(&buf), 2); + assert_eq!(buf[0], Neighbor::new(1, 0.5)); + assert_eq!(buf[1], Neighbor::new(2, 1.0)); + + // extend appends and returns count + let count = SearchOutputBuffer::extend(&mut buf, vec![(3u32, 1.5), (4, 2.0), (5, 2.5)]); + assert_eq!(count, 3); + assert_eq!(SearchOutputBuffer::::current_len(&buf), 5); + assert_eq!(buf[4], Neighbor::new(5, 2.5)); + } } From 6dec4bc0c80c6ec8df05c985751e63f5b4ff442c Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 13 Mar 2026 18:44:25 -0700 Subject: [PATCH 28/47] Last cleanups (outside of caching). --- .../graph/provider/async_/bf_tree/provider.rs | 14 +- .../provider/async_/inmem/full_precision.rs | 6 +- .../model/graph/provider/async_/inmem/test.rs | 2 +- diskann/src/graph/index.rs | 13 +- diskann/src/graph/search/knn_search.rs | 4 +- diskann/src/graph/search/mod.rs | 12 +- post_process_design_sketch.rs | 418 ------------------ 7 files changed, 27 insertions(+), 442 deletions(-) delete mode 100644 post_process_design_sketch.rs diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 4a8ca44fc..900aedd90 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -1512,15 +1512,16 @@ where _computer: &pq::distance::QueryComputer>, candidates: I, output: &mut B, - ) -> impl std::future::Future> + Send + ) -> impl Future> + Send where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, + I: Iterator>, + B: SearchOutputBuffer + ?Sized, { let provider = &accessor.provider; let checker = accessor.as_deletion_check(); let f = T::distance(provider.metric, Some(provider.full_vectors.dim())); + // Filter before computing the full precision distances. let mut reranked: Vec<(u32, f32)> = candidates .filter_map(|n| { if checker.deletion_check(n.id) { @@ -1536,16 +1537,15 @@ where }) .collect(); + // Sort the full precision distances. reranked .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + // Store the reranked results. std::future::ready(Ok(output.extend(reranked))) } } /// Perform a search entirely in the quantized space. -/// -/// Starting points are are filtered out of the final results and results are reranked using -/// the full-precision data. impl SearchStrategy, [T]> for Hybrid where T: VectorRepr, @@ -1564,6 +1564,8 @@ where } } +/// Starting points are filtered out of the final results and results are reranked using +/// the full-precision data. impl HasDefaultProcessor, [T]> for Hybrid where T: VectorRepr, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 87679154d..dec10cac4 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -395,10 +395,10 @@ where _computer: &A::QueryComputer, candidates: I, output: &mut B, - ) -> impl std::future::Future> + Send + ) -> impl Future> + Send where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, + I: Iterator>, + B: SearchOutputBuffer + ?Sized, { let full = accessor.as_full_precision(); let checker = accessor.as_deletion_check(); diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index d3f1a1c50..45f8ee61d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -5,7 +5,6 @@ use std::{future::Future, sync::Mutex}; -use diskann::has_default_processor; use diskann::{ ANNError, ANNResult, error::{RankedError, ToRanked, TransientError}, @@ -13,6 +12,7 @@ use diskann::{ AsElement, CopyIds, ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, + has_default_processor, neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 0eaaecd2c..dc716ec9f 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -24,7 +24,7 @@ use thiserror::Error; use tokio::task::JoinSet; use super::{ - AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, + AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, Search, glue::{ self, AsElement, ExpandBeam, FillSet, IdIterator, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, aliases, @@ -2151,10 +2151,10 @@ where output: &mut OB, ) -> impl SendFuture> where - P: super::search::Search, + P: Search, S: glue::HasDefaultProcessor, O: Send, - OB: super::search_output_buffer::SearchOutputBuffer + Send + ?Sized, + OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, T: ?Sized, { let processor = strategy.create_processor(); @@ -2172,11 +2172,11 @@ where output: &mut OB, ) -> impl SendFuture> where - P: super::search::Search, + P: Search, S: glue::SearchStrategy, PP: for<'a> glue::SearchPostProcess, T, O> + Send + Sync, O: Send, - OB: super::search_output_buffer::SearchOutputBuffer + Send + ?Sized, + OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, T: ?Sized, { search_params.search(self, strategy, processor, context, query, output) @@ -2220,8 +2220,7 @@ where ) -> ANNResult where T: ?Sized, - S: SearchStrategy: IdIterator> - + glue::HasDefaultProcessor, + S: glue::DefaultSearchStrategy: IdIterator>, I: Iterator::InternalId>, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send, diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 3af34eae8..1f5cba95d 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -15,7 +15,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{SearchExt, SearchPostProcess}, + glue::{self, SearchExt, SearchPostProcess}, index::{DiskANNIndex, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::SearchOutputBuffer, @@ -252,7 +252,7 @@ impl<'r, SR: ?Sized> RecordedKnn<'r, SR> { impl<'r, DP, S, T, O, SR> Search for RecordedKnn<'r, SR> where DP: DataProvider, - S: crate::graph::glue::SearchStrategy, + S: glue::SearchStrategy, T: Sync + ?Sized, O: Send, SR: super::record::SearchRecord + ?Sized, diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index a7235d7ff..f804bf80a 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -25,7 +25,11 @@ use diskann_utils::future::SendFuture; -use crate::{ANNResult, graph::index::DiskANNIndex, provider::DataProvider}; +use crate::{ + ANNResult, + graph::{self, index::DiskANNIndex}, + provider::DataProvider, +}; mod knn_search; mod multihop_search; @@ -88,10 +92,8 @@ where output: &mut OB, ) -> impl SendFuture> where - PP: for<'a> crate::graph::glue::SearchPostProcess, T, O> - + Send - + Sync, - OB: crate::graph::search_output_buffer::SearchOutputBuffer + Send + ?Sized; + PP: for<'a> graph::glue::SearchPostProcess, T, O> + Send + Sync, + OB: graph::search_output_buffer::SearchOutputBuffer + Send + ?Sized; } pub use knn_search::{Knn, KnnSearchError, RecordedKnn}; diff --git a/post_process_design_sketch.rs b/post_process_design_sketch.rs deleted file mode 100644 index 994bff558..000000000 --- a/post_process_design_sketch.rs +++ /dev/null @@ -1,418 +0,0 @@ -// ============================================================================= -// Post-Processing Redesign: Sketch & Rationale -// ============================================================================= -// -// Context -// ------- -// Two competing PRs attempted to refactor how SearchStrategy interacts with -// post-processing. Both had structural problems: -// -// Exhibit-A kept `type PostProcessor` on SearchStrategy and layered a new -// `PostProcess` trait on top. This created two -// parallel "what's the post-processor?" answers on the same type that could -// silently diverge. The GAT associated type became dead weight that every -// implementor still had to fill in. -// -// Exhibit-B removed `PostProcessor` from SearchStrategy (good), but replaced -// it with a `DelegatePostProcess` marker whose blanket impl covered *all* -// processor types `P` at once: -// -// impl PostProcess for S -// where S: SearchStrategy<…> + DelegatePostProcess, -// P: for<'a> SearchPostProcess, T, O> + … -// -// This makes it impossible to override `PostProcess` for a specific `P` -// without opting out of the blanket entirely (removing DelegatePostProcess), -// which then forces manual impls for every processor type — an all-or-nothing -// cliff. It also provided no `KnnWith`-style mechanism for callers to supply -// a custom processor at the search call-site. -// -// Proposed Design -// --------------- -// Flip the blanket. Instead of "strategy S gets PostProcess for all P", -// make it "the DefaultPostProcess ZST gets support for all strategies S -// that opt in via HasDefaultProcessor". -// -// The blanket is narrow (covers exactly one P = DefaultPostProcess), so custom -// PostProcess<…, RagSearchParams, …> impls are coherence-safe. Strategies -// that don't need a default can skip HasDefaultProcessor and still be used via -// KnnWith with an explicit processor. -// -// ============================================================================= -// -// How to read this file -// --------------------- -// This is pseudocode — it won't compile. Signatures use real Rust syntax where -// possible but elide lifetimes, bounds, and async machinery for clarity. -// Comments marked "NOTE" call out places where the real implementation will -// need careful attention to HRTB / GAT interactions. -// -// ============================================================================= - -// --------------------------------------------------------------------------- -// 1. SearchStrategy — clean, no post-processing knowledge -// --------------------------------------------------------------------------- -// -// This is the same as today minus `type PostProcessor` and `fn post_processor`. - -pub trait SearchStrategy::InternalId>: - Send + Sync -where - Provider: DataProvider, - T: ?Sized, - O: Send, -{ - type QueryComputer: /* PreprocessedDistanceFunction bounds */ Send + Sync + 'static; - type SearchAccessorError: StandardError; - - // NOTE: This GAT is the source of most HRTB complexity downstream. - type SearchAccessor<'a>: ExpandBeam - + SearchExt; - - fn search_accessor<'a>( - &'a self, - provider: &'a Provider, - context: &'a Provider::Context, - ) -> Result, Self::SearchAccessorError>; -} - -// --------------------------------------------------------------------------- -// 2. SearchPostProcess — unchanged from today -// --------------------------------------------------------------------------- -// -// Low-level trait, parameterized by the *accessor* (not the strategy). -// CopyIds, Rerank, Pipeline, RemoveDeletedIdsAndCopy, etc. all -// implement this directly. No changes needed here. - -pub trait SearchPostProcess::Id> -where - A: BuildQueryComputer, - T: ?Sized, -{ - type Error: StandardError; - - fn post_process( - &self, - accessor: &mut A, - query: &T, - computer: &>::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized; -} - -// Pipeline, CopyIds, FilterStartPoints, SearchPostProcessStep — all unchanged. - -// --------------------------------------------------------------------------- -// 3. PostProcess — strategy-level bridge, parameterized by processor P -// --------------------------------------------------------------------------- -// -// This trait connects a strategy to a specific processor type. It is the -// surface that the search infrastructure (Knn, KnnWith, RecordedKnn, etc.) -// bounds on. - -pub trait PostProcess::InternalId>: - SearchStrategy -where - Provider: DataProvider, - T: ?Sized, - O: Send, - P: Send + Sync, -{ - fn post_process_with<'a, I, B>( - &self, - processor: &P, - accessor: &mut Self::SearchAccessor<'a>, - query: &T, - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> impl Future> + Send - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized; -} - -// --------------------------------------------------------------------------- -// 4. HasDefaultProcessor — opt-in "I have a default post-processor" -// --------------------------------------------------------------------------- -// -// Strategies that want to work with Knn (no explicit processor) implement this. -// It replaces the old `type PostProcessor` on SearchStrategy. -// -// NOTE: The `for<'a> SearchPostProcess, T, O>` HRTB -// bound is the same one that lived on SearchStrategy::PostProcessor today. -// It's not new complexity — it just moved here. - -pub trait HasDefaultProcessor::InternalId>: - SearchStrategy -where - Provider: DataProvider, - T: ?Sized, - O: Send, -{ - type Processor: for<'a> SearchPostProcess, T, O> - + Send - + Sync; - - fn create_processor(&self) -> Self::Processor; -} - -// Convenience macro (same idea as exhibit-B's has_default_processor!). -macro_rules! has_default_processor { - ($Processor:ty) => { - type Processor = $Processor; - fn create_processor(&self) -> Self::Processor { - Default::default() - } - }; -} - -// --------------------------------------------------------------------------- -// 5. DefaultPostProcess ZST + THE blanket impl -// --------------------------------------------------------------------------- -// -// KEY DESIGN POINT: The blanket covers exactly P = DefaultPostProcess. -// Custom processor types (RagSearchParams, etc.) are free to have their own -// `impl PostProcess<…, RagSearchParams, …> for MyStrategy` without any -// coherence conflict. - -#[derive(Debug, Default, Clone, Copy)] -pub struct DefaultPostProcess; - -impl PostProcess for S -where - S: HasDefaultProcessor, - Provider: DataProvider, - T: ?Sized + Sync, - O: Send, -{ - async fn post_process_with<'a, I, B>( - &self, - _processor: &DefaultPostProcess, - accessor: &mut Self::SearchAccessor<'a>, - query: &T, - computer: &Self::QueryComputer, - candidates: I, - output: &mut B, - ) -> ANNResult - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - self.create_processor() - .post_process(accessor, query, computer, candidates, output) - .await - .into_ann_result() - } -} - -// --------------------------------------------------------------------------- -// 6. Search API split: Knn vs KnnWith -// --------------------------------------------------------------------------- -// -// Knn uses the default processor. KnnWith allows an explicit override. -// Both delegate to a shared `search_core` that is parameterized over PP. - -impl Knn { - /// Shared core — the only axis of variation is the processor. - async fn search_core( - &self, - index: &DiskANNIndex, - strategy: &S, - /* … */ - post_processor: &PP, - ) -> ANNResult - where - S: PostProcess, - PP: Send + Sync, - /* … */ - { - let mut accessor = strategy.search_accessor(/* … */)?; - let computer = accessor.build_query_computer(query)?; - /* … search_internal … */ - let count = strategy - .post_process_with(post_processor, &mut accessor, query, &computer, candidates, output) - .await?; - Ok(stats.finish(count as u32)) - } -} - -// Knn: uses DefaultPostProcess -impl Search for Knn -where - S: PostProcess, - // equivalently: S: HasDefaultProcessor -{ - fn search(self, /* … */) -> impl SendFuture> { - async move { - self.search_core(/* … */, &DefaultPostProcess).await - } - } -} - -// KnnWith: uses caller-supplied processor -pub struct KnnWith { - inner: Knn, - post_processor: PP, -} - -impl Search for KnnWith -where - S: PostProcess, - PP: Send + Sync, -{ - fn search(self, /* … */) -> impl SendFuture> { - async move { - self.inner - .search_core(/* … */, &self.post_processor) - .await - } - } -} - -// --------------------------------------------------------------------------- -// 7. Example: implementing a strategy -// --------------------------------------------------------------------------- - -struct MyStrategy { /* … */ } - -impl SearchStrategy for MyStrategy { - type QueryComputer = MyComputer; - type SearchAccessorError = ANNError; - type SearchAccessor<'a> = MyAccessor<'a>; - - fn search_accessor<'a>(/* … */) -> Result, ANNError> { /* … */ } - // No PostProcessor, no post_processor() — clean. -} - -// Opt in to the default: "my default post-processor is CopyIds" -impl HasDefaultProcessor for MyStrategy { - has_default_processor!(CopyIds); -} -// That's it — Knn now works with MyStrategy. - -// Opt in to RAG reranking too (no coherence conflict!): -impl PostProcess for MyStrategy { - async fn post_process_with( - &self, - processor: &RagSearchParams, - accessor: &mut MyAccessor<'_>, - /* … */ - ) -> ANNResult { - // Custom RAG logic here - } -} -// Now `KnnWith::new(knn, rag_params)` also works with MyStrategy. - -// --------------------------------------------------------------------------- -// 8. Decorator strategies (BetaFilter) -// --------------------------------------------------------------------------- -// -// BetaFilter wraps an inner strategy and delegates. The PostProcess<…, P, …> -// impl is generic over P, which is coherence-safe because it's on a concrete -// wrapper type (not a blanket over Self). - -impl PostProcess - for BetaFilter -where - Strategy: PostProcess, - P: Send + Sync, - /* … other bounds … */ -{ - async fn post_process_with( - &self, - processor: &P, - accessor: &mut Self::SearchAccessor<'_>, - /* … */ - ) -> ANNResult { - // Unwrap the layered accessor, delegate to inner strategy - self.strategy - .post_process_with(processor, &mut accessor.inner, /* … */) - .await - } -} - -impl HasDefaultProcessor - for BetaFilter -where - Strategy: HasDefaultProcessor, - /* … */ -{ - type Processor = Strategy::Processor; - fn create_processor(&self) -> Self::Processor { - self.strategy.create_processor() - } -} - -// --------------------------------------------------------------------------- -// 9. InplaceDeleteStrategy -// --------------------------------------------------------------------------- -// -// The delete-search phase needs exactly one processor type. The associated -// type pins it, and the SearchStrategy bound requires PostProcess for that -// specific type. -// -// NOTE: The double `for<'a>` bound is verbose but unavoidable given the GAT. - -pub trait InplaceDeleteStrategy: Send + Sync + 'static -where - Provider: DataProvider, -{ - type DeleteElement<'a>: Send + Sync + ?Sized; - type DeleteElementGuard: /* … AsyncLower … */ + 'static; - type DeleteElementError: StandardError; - type PruneStrategy: PruneStrategy; - - /// The processor used during the delete-search phase. - type SearchPostProcessor: Send + Sync; - - /// The search strategy, which must support PostProcess with the above processor. - type SearchStrategy: for<'a> SearchStrategy> - + for<'a> PostProcess< - Provider, - Self::DeleteElement<'a>, - Self::SearchPostProcessor, - >; - - fn prune_strategy(&self) -> Self::PruneStrategy; - fn search_strategy(&self) -> Self::SearchStrategy; - fn search_post_processor(&self) -> Self::SearchPostProcessor; - - fn get_delete_element<'a>(/* … */) -> impl Future> + Send; -} - -// --------------------------------------------------------------------------- -// 10. Known pain points for the real implementation -// --------------------------------------------------------------------------- -// -// A. HRTB on HasDefaultProcessor::Processor -// The bound `for<'a> SearchPostProcess, T, O>` -// is the same one that lived on SearchStrategy::PostProcessor before. -// It's not new — it just moved. The has_default_processor! macro -// should absorb this. -// -// B. BetaFilter's generic P delegation -// `impl

PostProcess<…, P, …> for BetaFilter where S: PostProcess<…, P, …>` -// is coherence-safe (concrete wrapper, not a blanket over Self), but verify -// that rustc is happy with the HRTB interaction when SearchAccessor<'a> is -// a layered type (BetaAccessor wrapping the inner accessor). -// -// C. Disk provider (DiskSearchStrategy) -// Today it has PostProcessor = RerankAndFilter. Under the new design: -// - impl HasDefaultProcessor → Processor = RerankAndFilter -// - impl PostProcess<…, RagSearchParams, …> → custom RAG reranking -// These are independent impls with no coherence conflict. -// -// D. Caching provider (CachingAccessor) -// Uses Pipeline today. Same pattern: HasDefaultProcessor -// with Processor = Pipeline. The Pipeline type is just -// another SearchPostProcess impl. -// -// E. The .send() / IntoANNResult bridge -// The blanket impl calls `create_processor().post_process(…).await`. -// The SearchPostProcess::Error needs to be convertible to ANNError. Today -// this is handled via IntoANNResult / .send(). Same pattern applies. From 1cc28163e304f6f5a34166ea1864e4c3c04fc424 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 12:56:13 -0700 Subject: [PATCH 29/47] Renames. --- .../src/search/provider/disk_provider.rs | 4 ++-- diskann-garnet/src/provider.rs | 9 ++++----- .../inline_beta_search/inline_beta_filter.rs | 6 +++--- .../graph/provider/async_/bf_tree/provider.rs | 13 ++++++------- .../graph/provider/async_/caching/provider.rs | 6 +++--- .../graph/provider/async_/debug_provider.rs | 13 ++++++------- .../provider/async_/inmem/full_precision.rs | 8 ++++---- .../graph/provider/async_/inmem/product.rs | 12 ++++++------ .../graph/provider/async_/inmem/scalar.rs | 12 ++++++------ .../graph/provider/async_/inmem/spherical.rs | 12 ++++++------ .../model/graph/provider/async_/inmem/test.rs | 9 ++++----- .../model/graph/provider/layers/betafilter.rs | 10 +++++----- diskann/src/graph/glue.rs | 18 +++++++++--------- diskann/src/graph/index.rs | 2 +- diskann/src/graph/test/provider.rs | 7 +++---- 15 files changed, 68 insertions(+), 73 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 70438b00d..cffc22f10 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -19,7 +19,7 @@ use diskann::{ graph::{ self, glue::{ - self, ExpandBeam, HasDefaultProcessor, IdIterator, SearchExt, SearchPostProcess, + self, DefaultPostProcessor, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy, }, search::Knn, @@ -371,7 +371,7 @@ where } impl<'this, Data, ProviderFactory> - HasDefaultProcessor< + DefaultPostProcessor< DiskProvider, [Data::VectorDataType], ( diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index b7f4d7db6..21825ff69 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -5,16 +5,15 @@ use dashmap::DashMap; use diskann::{ - ANNError, ANNErrorKind, ANNResult, + ANNError, ANNErrorKind, ANNResult, default_post_processor, graph::{ AdjacencyList, SearchOutputBuffer, config::defaults::MAX_OCCLUSION_SIZE, glue::{ - self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + self, DefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, }, }, - has_default_processor, neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DelegateNeighbor, @@ -769,8 +768,8 @@ impl SearchStrategy, [T], GarnetId> for FullPre } } -impl HasDefaultProcessor, [T], GarnetId> for FullPrecision { - has_default_processor!(glue::Pipeline); +impl DefaultPostProcessor, [T], GarnetId> for FullPrecision { + default_post_processor!(glue::Pipeline); } impl SearchStrategy, [T], u32> for FullPrecision { diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index ea509b3ff..25ddd22c3 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -61,16 +61,16 @@ where } } -/// [`HasDefaultProcessor`] delegation for [`InlineBetaStrategy`]. The processor wraps +/// [`DefaultPostProcessor`] delegation for [`InlineBetaStrategy`]. The processor wraps /// the inner strategy's default processor with [`FilterResults`]. impl - diskann::graph::glue::HasDefaultProcessor< + diskann::graph::glue::DefaultPostProcessor< DocumentProvider>, FilteredQuery, > for InlineBetaStrategy where DP: DataProvider, - Strategy: diskann::graph::glue::HasDefaultProcessor, + Strategy: diskann::graph::glue::DefaultPostProcessor, Q: AsyncFriendly + Clone, { type Processor = FilterResults; diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 900aedd90..25f4e2dd6 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -17,16 +17,15 @@ use serde::{Deserialize, Serialize}; use bf_tree::{BfTree, Config}; use diskann::{ - ANNError, ANNResult, + ANNError, ANNResult, default_post_processor, error::IntoANNResult, graph::{ AdjacencyList, DiskANNIndex, SearchOutputBuffer, glue::{ - self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + self, DefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, }, - has_default_processor, neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext, @@ -1485,13 +1484,13 @@ where } } -impl HasDefaultProcessor, [T]> for FullPrecision +impl DefaultPostProcessor, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, { - has_default_processor!(glue::Pipeline); + default_post_processor!(glue::Pipeline); } /// An [`glue::SearchPostProcess`] implementation that reranks PQ vectors. @@ -1566,12 +1565,12 @@ where /// Starting points are filtered out of the final results and results are reranked using /// the full-precision data. -impl HasDefaultProcessor, [T]> for Hybrid +impl DefaultPostProcessor, [T]> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, { - has_default_processor!(glue::Pipeline); + default_post_processor!(glue::Pipeline); } // Pruning diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index d107cb684..087122efe 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -980,13 +980,13 @@ where } } -/// [`HasDefaultProcessor`] delegation for [`Cached`]. The processor is composed by +/// [`DefaultPostProcessor`] delegation for [`Cached`]. The processor is composed by /// wrapping the inner strategy's processor with [`Unwrap`] via [`Pipeline`]. -impl glue::HasDefaultProcessor, T> for Cached +impl glue::DefaultPostProcessor, T> for Cached where T: ?Sized, DP: DataProvider, - S: glue::HasDefaultProcessor + S: glue::DefaultPostProcessor + for<'a> SearchStrategy: CacheableAccessor>, C: for<'a> AsCacheAccessorFor< 'a, diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs index d75b2afdb..492d8f8d6 100644 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs @@ -11,13 +11,12 @@ use std::{ }, }; -use diskann::has_default_processor; use diskann::{ - ANNError, ANNErrorKind, ANNResult, + ANNError, ANNErrorKind, ANNResult, default_post_processor, graph::{ AdjacencyList, glue::{ - AsElement, ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, + AsElement, DefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, }, @@ -902,8 +901,8 @@ impl SearchStrategy for FullPrecision { } } -impl HasDefaultProcessor for FullPrecision { - has_default_processor!(Pipeline); +impl DefaultPostProcessor for FullPrecision { + default_post_processor!(Pipeline); } impl SearchStrategy for Quantized { @@ -920,8 +919,8 @@ impl SearchStrategy for Quantized { } } -impl HasDefaultProcessor for Quantized { - has_default_processor!(Pipeline); +impl DefaultPostProcessor for Quantized { + default_post_processor!(Pipeline); } impl PruneStrategy for FullPrecision { diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index dec10cac4..2de024bbe 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -5,13 +5,13 @@ use std::{collections::HashMap, fmt::Debug, future::Future}; -use diskann::has_default_processor; +use diskann::default_post_processor; use diskann::{ ANNError, ANNResult, graph::{ SearchOutputBuffer, glue::{ - self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + self, DefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, }, @@ -453,14 +453,14 @@ where } } -impl HasDefaultProcessor, [T]> for FullPrecision +impl DefaultPostProcessor, [T]> for FullPrecision where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(glue::Pipeline); + default_post_processor!(glue::Pipeline); } // Pruning diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 1c8bf7ebd..53cb2e827 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -5,11 +5,11 @@ use std::{collections::HashMap, future::Future, sync::Arc}; -use diskann::has_default_processor; +use diskann::default_post_processor; use diskann::{ ANNError, ANNResult, graph::glue::{ - self, ExpandBeam, FillSet, HasDefaultProcessor, InplaceDeleteStrategy, InsertStrategy, + self, DefaultPostProcessor, ExpandBeam, FillSet, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ @@ -483,13 +483,13 @@ where } } -impl HasDefaultProcessor, [T]> for Hybrid +impl DefaultPostProcessor, [T]> for Hybrid where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(glue::Pipeline); + default_post_processor!(glue::Pipeline); } impl PruneStrategy> for Hybrid @@ -604,14 +604,14 @@ where } } -impl HasDefaultProcessor, [T]> +impl DefaultPostProcessor, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(glue::Pipeline); + default_post_processor!(glue::Pipeline); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index e82f70cb7..0cfc6ef49 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -6,11 +6,11 @@ use std::{future::Future, sync::Mutex}; use crate::storage::{StorageReadProvider, StorageWriteProvider}; -use diskann::has_default_processor; +use diskann::default_post_processor; use diskann::{ ANNError, ANNResult, graph::glue::{ - ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InsertStrategy, Pipeline, + DefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ @@ -624,7 +624,7 @@ where } impl - HasDefaultProcessor, D, Ctx>, [T]> for Quantized + DefaultPostProcessor, D, Ctx>, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -632,7 +632,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - has_default_processor!(Pipeline); + default_post_processor!(Pipeline); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -661,7 +661,7 @@ where } impl - HasDefaultProcessor, D, Ctx>, [T]> for Quantized + DefaultPostProcessor, D, Ctx>, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, @@ -669,7 +669,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - has_default_processor!(Pipeline); + default_post_processor!(Pipeline); } impl PruneStrategy, D, Ctx>> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 9bc43d7d4..5598e6b4e 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -7,12 +7,12 @@ use std::{future::Future, sync::Mutex}; -use diskann::has_default_processor; +use diskann::default_post_processor; use diskann::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::glue::{ - ExpandBeam, FillSet, FilterStartPoints, HasDefaultProcessor, InsertStrategy, Pipeline, + DefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, }, provider::{ @@ -572,14 +572,14 @@ where } } -impl HasDefaultProcessor, [T]> +impl DefaultPostProcessor, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(Pipeline); + default_post_processor!(Pipeline); } /// SearchStrategy for quantized search when only the quantized store is present. @@ -605,14 +605,14 @@ where } } -impl HasDefaultProcessor, [T]> +impl DefaultPostProcessor, [T]> for Quantized where T: VectorRepr, D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - has_default_processor!(Pipeline); + default_post_processor!(Pipeline); } impl PruneStrategy> for Quantized diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 45f8ee61d..5b539f360 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -6,13 +6,12 @@ use std::{future::Future, sync::Mutex}; use diskann::{ - ANNError, ANNResult, + ANNError, ANNResult, default_post_processor, error::{RankedError, ToRanked, TransientError}, graph::glue::{ - AsElement, CopyIds, ExpandBeam, FillSet, HasDefaultProcessor, InsertStrategy, + AsElement, CopyIds, DefaultPostProcessor, ExpandBeam, FillSet, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, }, - has_default_processor, neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, @@ -251,8 +250,8 @@ impl SearchStrategy for Flaky { } } -impl HasDefaultProcessor for Flaky { - has_default_processor!(CopyIds); +impl DefaultPostProcessor for Flaky { + default_post_processor!(CopyIds); } impl FillSet for FlakyAccessor<'_> {} diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 0e625f9f2..57595335c 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -144,16 +144,16 @@ where } } -/// [`HasDefaultProcessor`] delegation for [`BetaFilter`]. The processor is composed by +/// [`DefaultPostProcessor`] delegation for [`BetaFilter`]. The processor is composed by /// wrapping the inner strategy's processor with [`Unwrap`] via [`Pipeline`]. -impl glue::HasDefaultProcessor +impl glue::DefaultPostProcessor for BetaFilter where T: ?Sized, I: VectorId, O: Send, Provider: DataProvider, - Strategy: glue::HasDefaultProcessor, + Strategy: glue::DefaultPostProcessor, { type Processor = glue::Pipeline; @@ -559,8 +559,8 @@ mod tests { } } - impl glue::HasDefaultProcessor for SimpleStrategy { - diskann::has_default_processor!(CopyIds); + impl glue::DefaultPostProcessor for SimpleStrategy { + diskann::default_post_processor!(CopyIds); } /// A simple `QueryLabelProvider` that matches multiples of 3. diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 7e42f14a7..fcadc3ec6 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -338,7 +338,7 @@ where /// Strategies implementing this trait work with [`super::search::Knn`] (no explicit /// processor). The search infrastructure will call `create_processor()` to obtain the /// processor and invoke its [`SearchPostProcess::post_process`] method. -pub trait HasDefaultProcessor::InternalId>: +pub trait DefaultPostProcessor::InternalId>: SearchStrategy where Provider: DataProvider, @@ -354,7 +354,7 @@ where /// Aggregate trait for strategies that support both search access and a default post-processor. pub trait DefaultSearchStrategy::InternalId>: - SearchStrategy + HasDefaultProcessor + SearchStrategy + DefaultPostProcessor where Provider: DataProvider, T: ?Sized, @@ -364,25 +364,25 @@ where impl DefaultSearchStrategy for S where - S: SearchStrategy + HasDefaultProcessor, + S: SearchStrategy + DefaultPostProcessor, Provider: DataProvider, T: ?Sized, O: Send, { } -/// Convenience macro for implementing [`HasDefaultProcessor`] when the processor +/// Convenience macro for implementing [`DefaultPostProcessor`] when the processor /// is a [`Default`]-constructible type. /// /// # Example /// /// ```ignore -/// impl HasDefaultProcessor for MyStrategy { -/// has_default_processor!(CopyIds); +/// impl DefaultPostProcessor for MyStrategy { +/// default_post_processor!(CopyIds); /// } /// ``` #[macro_export] -macro_rules! has_default_processor { +macro_rules! default_post_processor { ($Processor:ty) => { type Processor = $Processor; fn create_processor(&self) -> Self::Processor { @@ -1089,8 +1089,8 @@ mod tests { } } - impl HasDefaultProcessor for Strategy { - has_default_processor!(CopyIds); + impl DefaultPostProcessor for Strategy { + default_post_processor!(CopyIds); } // Use the provided implementation. diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index dc716ec9f..e492a461a 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2152,7 +2152,7 @@ where ) -> impl SendFuture> where P: Search, - S: glue::HasDefaultProcessor, + S: glue::DefaultPostProcessor, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, T: ?Sized, diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index b2543b013..00186969d 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -16,10 +16,9 @@ use diskann_vector::distance::Metric; use thiserror::Error; use crate::{ - ANNError, ANNResult, + ANNError, ANNResult, default_post_processor, error::{Infallible, message}, graph::{AdjacencyList, glue, test::synthetic}, - has_default_processor, internal::counter::{Counter, LocalCounter}, provider, utils::VectorRepr, @@ -965,8 +964,8 @@ impl glue::SearchStrategy for Strategy { } } -impl glue::HasDefaultProcessor for Strategy { - has_default_processor!(glue::CopyIds); +impl glue::DefaultPostProcessor for Strategy { + default_post_processor!(glue::CopyIds); } impl glue::PruneStrategy for Strategy { From 2ad1ed4335da41d22ba9cfc77c42d5153e4e7f56 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 13:17:48 -0700 Subject: [PATCH 30/47] Get caching provider working again. --- diskann-garnet/src/provider.rs | 1 + .../graph/provider/async_/bf_tree/provider.rs | 3 +- .../graph/provider/async_/caching/provider.rs | 34 ++++++++++++++++--- .../graph/provider/async_/debug_provider.rs | 2 ++ .../provider/async_/inmem/full_precision.rs | 1 + .../graph/provider/async_/inmem/product.rs | 1 + diskann/src/graph/glue.rs | 23 ++++++++++--- diskann/src/graph/test/provider.rs | 1 + 8 files changed, 56 insertions(+), 10 deletions(-) diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 21825ff69..737189578 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -834,6 +834,7 @@ impl InplaceDeleteStrategy> for FullPrecision { type DeleteElementError = GarnetProviderError; type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = FullAccessor<'a, T>; type SearchPostProcessor = glue::CopyIds; type SearchStrategy = Self; diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 25f4e2dd6..7bb269f7c 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -18,7 +18,6 @@ use serde::{Deserialize, Serialize}; use bf_tree::{BfTree, Config}; use diskann::{ ANNError, ANNResult, default_post_processor, - error::IntoANNResult, graph::{ AdjacencyList, DiskANNIndex, SearchOutputBuffer, glue::{ @@ -1681,6 +1680,7 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = FullAccessor<'a, T, Q, D>; type SearchPostProcessor = RemoveDeletedIdsAndCopy; type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { @@ -1720,6 +1720,7 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = QuantAccessor<'a, T, D>; type SearchPostProcessor = Rerank; type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index 087122efe..68a055cb3 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -1069,20 +1069,46 @@ where } } -/// More surprisingly - the `where` clause for this implementation is **also** straightforward. -impl InplaceDeleteStrategy> for Cached +/// The `where` clause requires that: +/// +/// 1. The inner strategy's [`DeleteSearchAccessor`] is cacheable. +/// 2. The cache `C` can produce a cache-accessor for the inner strategy's accessor. +/// 3. The wrapped search strategy `Cached` remains a valid +/// `SearchStrategy` for `CachingProvider` (needed for the equality constraint on +/// [`InplaceDeleteStrategy::SearchStrategy`]). +impl InplaceDeleteStrategy> for Cached where DP: DataProvider, S: InplaceDeleteStrategy, + for<'a> S::DeleteSearchAccessor<'a>: CacheableAccessor, Cached: PruneStrategy>, - for<'a> Cached: SearchStrategy, S::DeleteElement<'a>>, - C: AsyncFriendly, + for<'a> Cached: SearchStrategy< + CachingProvider, + S::DeleteElement<'a>, + SearchAccessor<'a> = CachingAccessor< + S::DeleteSearchAccessor<'a>, + >>::Accessor, + >, + >, + C: for<'a> AsCacheAccessorFor< + 'a, + S::DeleteSearchAccessor<'a>, + Accessor: NeighborCache, + Error = E, + > + AsyncFriendly, + E: StandardError, { type DeleteElement<'a> = S::DeleteElement<'a>; type DeleteElementGuard = S::DeleteElementGuard; type DeleteElementError = S::DeleteElementError; type PruneStrategy = Cached; + + type DeleteSearchAccessor<'a> = CachingAccessor< + S::DeleteSearchAccessor<'a>, + >>::Accessor, + >; + type SearchStrategy = Cached; type SearchPostProcessor = Pipeline; diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs index 492d8f8d6..8af07aa89 100644 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs @@ -1012,6 +1012,7 @@ impl InplaceDeleteStrategy for FullPrecision { type DeleteElementGuard = Vec; type DeleteElementError = Panics; type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = FullAccessor<'a>; type SearchPostProcessor = postprocess::RemoveDeletedIdsAndCopy; type SearchStrategy = Self; @@ -1043,6 +1044,7 @@ impl InplaceDeleteStrategy for Quantized { type DeleteElementGuard = Vec; type DeleteElementError = Panics; type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = QuantAccessor<'a>; type SearchPostProcessor = postprocess::RemoveDeletedIdsAndCopy; type SearchStrategy = Self; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 2de024bbe..5c259edf5 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -527,6 +527,7 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; type SearchPostProcessor = RemoveDeletedIdsAndCopy; type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 53cb2e827..c3064ec0d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -554,6 +554,7 @@ where type DeleteElement<'a> = [T]; type DeleteElementGuard = Box<[T]>; type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; type SearchPostProcessor = Rerank; type SearchStrategy = Self; fn search_strategy(&self) -> Self::SearchStrategy { diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index fcadc3ec6..829c6c517 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -792,15 +792,28 @@ where /// The pruning strategy to use after the initial search is complete. type PruneStrategy: PruneStrategy; + /// The accessor used during the delete-search phase. + /// + /// This is technically redundant information as in theory, we could project trhough + /// [`Self::SearchStrategy`]. However, when trying to write generic wrappers (read, + /// the "caching" provider), rustc is unable to project all the way through the layers + /// of associated types. + /// + /// Lifting the accessor all the way to the trait level makes the caching provider possible. + type DeleteSearchAccessor<'a>: ExpandBeam, Id = Provider::InternalId> + + SearchExt; + /// The processor used during the delete-search phase. - type SearchPostProcessor: for<'a> SearchPostProcess< - >>::SearchAccessor<'a>, - Self::DeleteElement<'a>, - > + Send + type SearchPostProcessor: for<'a> SearchPostProcess, Self::DeleteElement<'a>> + + Send + Sync; /// The type of the search strategy to use for graph traversal. - type SearchStrategy: for<'a> SearchStrategy>; + type SearchStrategy: for<'a> SearchStrategy< + Provider, + Self::DeleteElement<'a>, + SearchAccessor<'a> = Self::DeleteSearchAccessor<'a>, + >; /// Construct the prune strategy object. fn prune_strategy(&self) -> Self::PruneStrategy; diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 00186969d..2509b3c7c 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1014,6 +1014,7 @@ impl glue::InplaceDeleteStrategy for Strategy { type DeleteElementGuard = Box<[f32]>; type DeleteElementError = AccessedInvalidId; type PruneStrategy = Self; + type DeleteSearchAccessor<'a> = Accessor<'a>; type SearchStrategy = Self; type SearchPostProcessor = glue::CopyIds; From 0ee76b14bb49f03dcf33454ee2d34a69e3af0b05 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 13:35:22 -0700 Subject: [PATCH 31/47] Clean up some stragglers. --- diskann-benchmark/src/inputs/disk.rs | 1 - diskann-disk/src/search/provider/disk_provider.rs | 8 +++++--- .../model/graph/provider/async_/inmem/full_precision.rs | 5 +++-- .../src/model/graph/provider/async_/inmem/product.rs | 5 ++--- .../src/model/graph/provider/async_/inmem/scalar.rs | 3 +-- .../src/model/graph/provider/async_/inmem/spherical.rs | 3 +-- diskann/src/graph/search/knn_search.rs | 2 +- diskann/src/graph/search/mod.rs | 2 +- 8 files changed, 14 insertions(+), 15 deletions(-) diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 0572f99f6..bf843d72f 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -234,7 +234,6 @@ impl CheckDeserialization for DiskSearchPhase { anyhow::bail!("search_io_limit must be positive if specified"); } } - Ok(()) } } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index cffc22f10..d40e921b7 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -23,7 +23,7 @@ use diskann::{ SearchStrategy, }, search::Knn, - search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, + search_output_buffer, AdjacencyList, DiskANNIndex, }, neighbor::Neighbor, provider::{ @@ -303,7 +303,9 @@ where ) -> Result where I: Iterator> + Send, - B: SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send + ?Sized, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, { let provider = accessor.provider; @@ -470,7 +472,7 @@ where let load_ids: Box<[_]> = ids.take(io_limit).collect(); self.ensure_loaded(&load_ids)?; - let mut ids: Vec = Vec::new(); + let mut ids = Vec::new(); for i in load_ids { ids.clear(); ids.extend( diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 5c259edf5..6d2d4c627 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -404,6 +404,7 @@ where let checker = accessor.as_deletion_check(); let f = full.distance(); + // Filter before computing the full precision distances. let mut reranked: Vec<(u32, f32)> = candidates .filter_map(|n| { if checker.deletion_check(n.id) { @@ -419,9 +420,11 @@ where }) .collect(); + // Sort the full precision distances. reranked .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + // Store the reranked results. std::future::ready(Ok(output.extend(reranked))) } } @@ -431,8 +434,6 @@ where //////////////// /// Perform a search entirely in the full-precision space. -/// -/// Starting points are not filtered out of the final results. impl SearchStrategy, [T]> for FullPrecision where T: VectorRepr, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index c3064ec0d..d59307805 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -461,9 +461,6 @@ where //////////// /// Perform a search entirely in the quantized space. -/// -/// Starting points are filtered out of the final results and results are reranked using -/// the full-precision data. impl SearchStrategy, [T]> for Hybrid where T: VectorRepr, @@ -483,6 +480,8 @@ where } } +/// Starting points are filtered out of the final results and results are reranked using +/// the full-precision data. impl DefaultPostProcessor, [T]> for Hybrid where T: VectorRepr, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 0cfc6ef49..0bc909486 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -6,9 +6,8 @@ use std::{future::Future, sync::Mutex}; use crate::storage::{StorageReadProvider, StorageWriteProvider}; -use diskann::default_post_processor; use diskann::{ - ANNError, ANNResult, + ANNError, ANNResult, default_post_processor, graph::glue::{ DefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 5598e6b4e..c17f2b57f 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -7,9 +7,8 @@ use std::{future::Future, sync::Mutex}; -use diskann::default_post_processor; use diskann::{ - ANNError, ANNErrorKind, ANNResult, + ANNError, ANNErrorKind, ANNResult, default_post_processor, error::IntoANNResult, graph::glue::{ DefaultPostProcessor, ExpandBeam, FillSet, FilterStartPoints, InsertStrategy, Pipeline, diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 1f5cba95d..51fb0ec25 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -146,7 +146,7 @@ impl Knn { impl Search for Knn where DP: DataProvider, - S: crate::graph::glue::SearchStrategy, + S: glue::SearchStrategy, T: Sync + ?Sized, O: Send, { diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index f804bf80a..568fdc594 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -54,7 +54,7 @@ pub(crate) mod scratch; pub trait Search where DP: DataProvider, - S: crate::graph::glue::SearchStrategy, + S: graph::glue::SearchStrategy, O: Send, { /// The result type returned by this search. From 5fef9919c9b0c33ca3aa5b17c55386f39c7eba0a Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 13:38:03 -0700 Subject: [PATCH 32/47] Unify naming. --- diskann-disk/src/search/provider/disk_provider.rs | 2 +- .../src/inline_beta_search/inline_beta_filter.rs | 4 ++-- .../src/model/graph/provider/async_/caching/provider.rs | 4 ++-- .../src/model/graph/provider/layers/betafilter.rs | 4 ++-- diskann/src/graph/glue.rs | 8 ++++---- diskann/src/graph/index.rs | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index d40e921b7..5913d0030 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -387,7 +387,7 @@ where { type Processor = RerankAndFilter<'this>; - fn create_processor(&self) -> Self::Processor { + fn default_post_processor(&self) -> Self::Processor { RerankAndFilter::new(self.vector_filter) } } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 25ddd22c3..6e4015948 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -75,9 +75,9 @@ where { type Processor = FilterResults; - fn create_processor(&self) -> Self::Processor { + fn default_post_processor(&self) -> Self::Processor { FilterResults { - inner_post_processor: self.inner.create_processor(), + inner_post_processor: self.inner.default_post_processor(), } } } diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index 68a055cb3..dca030490 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -998,8 +998,8 @@ where { type Processor = Pipeline; - fn create_processor(&self) -> Self::Processor { - Pipeline::new(Unwrap, self.strategy.create_processor()) + fn default_post_processor(&self) -> Self::Processor { + Pipeline::new(Unwrap, self.strategy.default_post_processor()) } } diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 57595335c..e01886943 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -157,8 +157,8 @@ where { type Processor = glue::Pipeline; - fn create_processor(&self) -> Self::Processor { - glue::Pipeline::new(Unwrap, self.strategy.create_processor()) + fn default_post_processor(&self) -> Self::Processor { + glue::Pipeline::new(Unwrap, self.strategy.default_post_processor()) } } diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 829c6c517..6a4a15bc6 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -336,7 +336,7 @@ where /// Opt-in trait for strategies that have a default post-processor. /// /// Strategies implementing this trait work with [`super::search::Knn`] (no explicit -/// processor). The search infrastructure will call `create_processor()` to obtain the +/// processor). The search infrastructure will call `default_post_processor()` to obtain the /// processor and invoke its [`SearchPostProcess::post_process`] method. pub trait DefaultPostProcessor::InternalId>: SearchStrategy @@ -349,7 +349,7 @@ where type Processor: for<'a> SearchPostProcess, T, O> + Send + Sync; /// Create the default post-processor. - fn create_processor(&self) -> Self::Processor; + fn default_post_processor(&self) -> Self::Processor; } /// Aggregate trait for strategies that support both search access and a default post-processor. @@ -385,7 +385,7 @@ where macro_rules! default_post_processor { ($Processor:ty) => { type Processor = $Processor; - fn create_processor(&self) -> Self::Processor { + fn default_post_processor(&self) -> Self::Processor { Default::default() } }; @@ -1154,7 +1154,7 @@ mod tests { let mut output = vec![Neighbor::::default(); output_len]; let count = strategy - .create_processor() + .default_post_processor() .post_process( &mut accessor, &query, diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index e492a461a..90be69569 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2157,7 +2157,7 @@ where OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, T: ?Sized, { - let processor = strategy.create_processor(); + let processor = strategy.default_post_processor(); self.search_with(search_params, strategy, processor, context, query, output) } @@ -2255,7 +2255,7 @@ where } let result_count = strategy - .create_processor() + .default_post_processor() .post_process( &mut accessor, query, From ca84f41366cee1ce3fa2cee9edf049b169d5c4e5 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 14:08:42 -0700 Subject: [PATCH 33/47] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann/src/graph/search/mod.rs | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 568fdc594..844beb51d 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -12,15 +12,26 @@ //! # Usage //! //! ```ignore -//! use diskann::graph::{search::{Knn, Range, MultihopSearch}, Search}; +//! use diskann::graph::{ +//! neighbor::{BackInserter, Neighbor}, +//! search::{Knn, Range, MultihopSearch}, +//! Search, +//! }; //! -//! // Standard k-NN search +//! // Standard k-NN search: use a fixed-capacity output buffer //! let params = Knn::new(10, 100, None)?; -//! let stats = index.search(params, &strategy, &context, &query, &mut output).await?; +//! let mut knn_storage = [Neighbor::default(); 10]; +//! let mut knn_output = BackInserter::new(&mut knn_storage); +//! let stats = index +//! .search(params, &strategy, &context, &query, &mut knn_output) +//! .await?; //! -//! // Range search +//! // Range search: use a growable Vec buffer for an unknown number of results //! let params = Range::new(100, 0.5)?; -//! let stats = index.search(params, &strategy, &context, &query, &mut output).await?; +//! let mut range_output: Vec> = Vec::new(); +//! let stats = index +//! .search(params, &strategy, &context, &query, &mut range_output) +//! .await?; //! ``` use diskann_utils::future::SendFuture; From 0fcfde65af5a80c2206ce97b64f08ca07905f75e Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 14:09:02 -0700 Subject: [PATCH 34/47] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann/src/graph/search/range_search.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 3453f9f87..2d4b2a688 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -245,7 +245,7 @@ where { return false; } - dist < radius + dist <= radius }); let result_count = processor From 7ee952b72a60a91aa84120b65ecffe8c12f6a0eb Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 14:09:55 -0700 Subject: [PATCH 35/47] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann/src/graph/search/range_search.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 2d4b2a688..7f8f3457d 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -295,7 +295,10 @@ where if (self.predicate)(distance) { self.inner.push(id, distance) } else { - search_output_buffer::BufferState::Available + match self.inner.size_hint() { + Some(0) => search_output_buffer::BufferState::Full, + _ => search_output_buffer::BufferState::Available, + } } } From bc11046f77fdf0a0b6de6d065018ae71e32be21e Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 14:10:08 -0700 Subject: [PATCH 36/47] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann/src/graph/glue.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 6a4a15bc6..726d6efdb 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -794,7 +794,7 @@ where /// The accessor used during the delete-search phase. /// - /// This is technically redundant information as in theory, we could project trhough + /// This is technically redundant information as in theory, we could project through /// [`Self::SearchStrategy`]. However, when trying to write generic wrappers (read, /// the "caching" provider), rustc is unable to project all the way through the layers /// of associated types. From 4d44eb99f82c2b977b22cb426f9903eb7ea85a86 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Sat, 14 Mar 2026 14:10:24 -0700 Subject: [PATCH 37/47] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann/src/graph/glue.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 726d6efdb..2ffd13b02 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -335,9 +335,12 @@ where /// Opt-in trait for strategies that have a default post-processor. /// -/// Strategies implementing this trait work with [`super::search::Knn`] (no explicit -/// processor). The search infrastructure will call `default_post_processor()` to obtain the -/// processor and invoke its [`SearchPostProcess::post_process`] method. +/// Strategies implementing this trait can be used with index-level search APIs such as +/// [`crate::index::diskann_async::DiskANNIndex::search`] and +/// [`crate::index::diskann_async::DiskANNIndex::search_with`] when no explicit +/// post-processor is specified. The search infrastructure will call +/// `default_post_processor()` to obtain the processor and invoke its +/// [`SearchPostProcess::post_process`] method. pub trait DefaultPostProcessor::InternalId>: SearchStrategy where From 1f688bb2961edbb1d27931d1273e81b5ec696e00 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 17 Mar 2026 17:01:33 -0700 Subject: [PATCH 38/47] Bump version. --- Cargo.lock | 30 +++++++++++++++--------------- Cargo.toml | 28 ++++++++++++++-------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fd817b46a..d6a50902d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -621,7 +621,7 @@ dependencies = [ [[package]] name = "diskann" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "bytemuck", @@ -645,7 +645,7 @@ dependencies = [ [[package]] name = "diskann-benchmark" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "bf-tree", @@ -682,7 +682,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-core" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "diskann", @@ -699,7 +699,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-runner" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "clap", @@ -713,7 +713,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-simd" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "diskann-benchmark-runner", @@ -730,7 +730,7 @@ dependencies = [ [[package]] name = "diskann-disk" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "bincode", @@ -781,7 +781,7 @@ dependencies = [ [[package]] name = "diskann-label-filter" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "bf-tree", @@ -804,7 +804,7 @@ dependencies = [ [[package]] name = "diskann-linalg" -version = "0.49.1" +version = "0.50.0" dependencies = [ "approx", "cfg-if", @@ -818,7 +818,7 @@ dependencies = [ [[package]] name = "diskann-platform" -version = "0.49.1" +version = "0.50.0" dependencies = [ "io-uring", "libc", @@ -828,7 +828,7 @@ dependencies = [ [[package]] name = "diskann-providers" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "approx", @@ -872,7 +872,7 @@ dependencies = [ [[package]] name = "diskann-quantization" -version = "0.49.1" +version = "0.50.0" dependencies = [ "bytemuck", "cfg-if", @@ -891,7 +891,7 @@ dependencies = [ [[package]] name = "diskann-tools" -version = "0.49.1" +version = "0.50.0" dependencies = [ "anyhow", "bincode", @@ -923,7 +923,7 @@ dependencies = [ [[package]] name = "diskann-utils" -version = "0.49.1" +version = "0.50.0" dependencies = [ "bytemuck", "cfg-if", @@ -938,7 +938,7 @@ dependencies = [ [[package]] name = "diskann-vector" -version = "0.49.1" +version = "0.50.0" dependencies = [ "approx", "cfg-if", @@ -952,7 +952,7 @@ dependencies = [ [[package]] name = "diskann-wide" -version = "0.49.1" +version = "0.50.0" dependencies = [ "cfg-if", "half", diff --git a/Cargo.toml b/Cargo.toml index 91cb564af..4e0cda86a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ default-members = [ resolver = "3" [workspace.package] -version = "0.49.1" # Obeying semver +version = "0.50.0" # Obeying semver description = "DiskANN is a fast approximate nearest neighbor search library for high dimensional data" authors = ["Microsoft"] documentation = "https://github.com/microsoft/DiskANN" @@ -48,22 +48,22 @@ undocumented_unsafe_blocks = "warn" [workspace.dependencies] # Base And Numerics -diskann-wide = { path = "diskann-wide", version = "0.49.1" } -diskann-vector = { path = "diskann-vector", version = "0.49.1" } -diskann-linalg = { path = "diskann-linalg", version = "0.49.1" } -diskann-utils = { path = "diskann-utils", default-features = false, version = "0.49.1" } -diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.49.1" } -diskann-platform = { path = "diskann-platform", version = "0.49.1" } +diskann-wide = { path = "diskann-wide", version = "0.50.0" } +diskann-vector = { path = "diskann-vector", version = "0.50.0" } +diskann-linalg = { path = "diskann-linalg", version = "0.50.0" } +diskann-utils = { path = "diskann-utils", default-features = false, version = "0.50.0" } +diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.50.0" } +diskann-platform = { path = "diskann-platform", version = "0.50.0" } # Algorithm -diskann = { path = "diskann", version = "0.49.1" } +diskann = { path = "diskann", version = "0.50.0" } # Providers -diskann-providers = { path = "diskann-providers", default-features = false, version = "0.49.1" } -diskann-disk = { path = "diskann-disk", version = "0.49.1" } -diskann-label-filter = { path = "diskann-label-filter", version = "0.49.1" } +diskann-providers = { path = "diskann-providers", default-features = false, version = "0.50.0" } +diskann-disk = { path = "diskann-disk", version = "0.50.0" } +diskann-label-filter = { path = "diskann-label-filter", version = "0.50.0" } # Infra -diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.49.1" } -diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.49.1" } -diskann-tools = { path = "diskann-tools", version = "0.49.1" } +diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.50.0" } +diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.50.0" } +diskann-tools = { path = "diskann-tools", version = "0.50.0" } # External dependencies (shared versions) anyhow = "1.0.98" From 635012d0bc846e8bcfcc2a68369feacdaf7769f9 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 18 Mar 2026 08:24:07 -0700 Subject: [PATCH 39/47] Reset back to 0.49.1 --- Cargo.lock | 30 +++++++++++++++--------------- Cargo.toml | 28 ++++++++++++++-------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d6a50902d..fd817b46a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -621,7 +621,7 @@ dependencies = [ [[package]] name = "diskann" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "bytemuck", @@ -645,7 +645,7 @@ dependencies = [ [[package]] name = "diskann-benchmark" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "bf-tree", @@ -682,7 +682,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-core" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "diskann", @@ -699,7 +699,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-runner" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "clap", @@ -713,7 +713,7 @@ dependencies = [ [[package]] name = "diskann-benchmark-simd" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "diskann-benchmark-runner", @@ -730,7 +730,7 @@ dependencies = [ [[package]] name = "diskann-disk" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "bincode", @@ -781,7 +781,7 @@ dependencies = [ [[package]] name = "diskann-label-filter" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "bf-tree", @@ -804,7 +804,7 @@ dependencies = [ [[package]] name = "diskann-linalg" -version = "0.50.0" +version = "0.49.1" dependencies = [ "approx", "cfg-if", @@ -818,7 +818,7 @@ dependencies = [ [[package]] name = "diskann-platform" -version = "0.50.0" +version = "0.49.1" dependencies = [ "io-uring", "libc", @@ -828,7 +828,7 @@ dependencies = [ [[package]] name = "diskann-providers" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "approx", @@ -872,7 +872,7 @@ dependencies = [ [[package]] name = "diskann-quantization" -version = "0.50.0" +version = "0.49.1" dependencies = [ "bytemuck", "cfg-if", @@ -891,7 +891,7 @@ dependencies = [ [[package]] name = "diskann-tools" -version = "0.50.0" +version = "0.49.1" dependencies = [ "anyhow", "bincode", @@ -923,7 +923,7 @@ dependencies = [ [[package]] name = "diskann-utils" -version = "0.50.0" +version = "0.49.1" dependencies = [ "bytemuck", "cfg-if", @@ -938,7 +938,7 @@ dependencies = [ [[package]] name = "diskann-vector" -version = "0.50.0" +version = "0.49.1" dependencies = [ "approx", "cfg-if", @@ -952,7 +952,7 @@ dependencies = [ [[package]] name = "diskann-wide" -version = "0.50.0" +version = "0.49.1" dependencies = [ "cfg-if", "half", diff --git a/Cargo.toml b/Cargo.toml index 4e0cda86a..91cb564af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ default-members = [ resolver = "3" [workspace.package] -version = "0.50.0" # Obeying semver +version = "0.49.1" # Obeying semver description = "DiskANN is a fast approximate nearest neighbor search library for high dimensional data" authors = ["Microsoft"] documentation = "https://github.com/microsoft/DiskANN" @@ -48,22 +48,22 @@ undocumented_unsafe_blocks = "warn" [workspace.dependencies] # Base And Numerics -diskann-wide = { path = "diskann-wide", version = "0.50.0" } -diskann-vector = { path = "diskann-vector", version = "0.50.0" } -diskann-linalg = { path = "diskann-linalg", version = "0.50.0" } -diskann-utils = { path = "diskann-utils", default-features = false, version = "0.50.0" } -diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.50.0" } -diskann-platform = { path = "diskann-platform", version = "0.50.0" } +diskann-wide = { path = "diskann-wide", version = "0.49.1" } +diskann-vector = { path = "diskann-vector", version = "0.49.1" } +diskann-linalg = { path = "diskann-linalg", version = "0.49.1" } +diskann-utils = { path = "diskann-utils", default-features = false, version = "0.49.1" } +diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.49.1" } +diskann-platform = { path = "diskann-platform", version = "0.49.1" } # Algorithm -diskann = { path = "diskann", version = "0.50.0" } +diskann = { path = "diskann", version = "0.49.1" } # Providers -diskann-providers = { path = "diskann-providers", default-features = false, version = "0.50.0" } -diskann-disk = { path = "diskann-disk", version = "0.50.0" } -diskann-label-filter = { path = "diskann-label-filter", version = "0.50.0" } +diskann-providers = { path = "diskann-providers", default-features = false, version = "0.49.1" } +diskann-disk = { path = "diskann-disk", version = "0.49.1" } +diskann-label-filter = { path = "diskann-label-filter", version = "0.49.1" } # Infra -diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.50.0" } -diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.50.0" } -diskann-tools = { path = "diskann-tools", version = "0.50.0" } +diskann-benchmark-runner = { path = "diskann-benchmark-runner", version = "0.49.1" } +diskann-benchmark-core = { path = "diskann-benchmark-core", version = "0.49.1" } +diskann-tools = { path = "diskann-tools", version = "0.49.1" } # External dependencies (shared versions) anyhow = "1.0.98" From 3145904ed3d28e5e0818729ebdcbaab2a783896f Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 18 Mar 2026 18:43:02 -0700 Subject: [PATCH 40/47] SearchStrategy no longer needs the output "ID" type. --- .../src/search/provider/disk_provider.rs | 11 ++--------- diskann-garnet/src/dyn_index.rs | 2 +- diskann-garnet/src/provider.rs | 16 +--------------- diskann-providers/src/index/diskann_async.rs | 1 - .../model/graph/provider/layers/betafilter.rs | 5 ++--- diskann/src/graph/glue.rs | 18 +++++------------- diskann/src/graph/index.rs | 6 +++--- diskann/src/graph/search/diverse_search.rs | 8 ++++---- diskann/src/graph/search/knn_search.rs | 16 ++++++++-------- diskann/src/graph/search/mod.rs | 8 ++++---- diskann/src/graph/search/multihop_search.rs | 8 ++++---- diskann/src/graph/search/range_search.rs | 8 ++++---- 12 files changed, 38 insertions(+), 69 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 5913d0030..c0b16beba 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -340,15 +340,8 @@ where } } -impl<'this, Data, ProviderFactory> - SearchStrategy< - DiskProvider, - [Data::VectorDataType], - ( - as DataProvider>::InternalId, - Data::AssociatedDataType, - ), - > for DiskSearchStrategy<'this, Data, ProviderFactory> +impl<'this, Data, ProviderFactory> SearchStrategy, [Data::VectorDataType]> + for DiskSearchStrategy<'this, Data, ProviderFactory> where Data: GraphDataType, ProviderFactory: VertexProviderFactory, diff --git a/diskann-garnet/src/dyn_index.rs b/diskann-garnet/src/dyn_index.rs index 37a539efa..c4f35e530 100644 --- a/diskann-garnet/src/dyn_index.rs +++ b/diskann-garnet/src/dyn_index.rs @@ -106,7 +106,7 @@ impl DynIndex for DiskANNIndex> { .build() .map_err(|e| ANNError::new(diskann::ANNErrorKind::Opaque, e))?; let mut accessor: provider::FullAccessor<'_, T> = - >::search_accessor( + >::search_accessor( &FullPrecision, self.inner.provider(), context, diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 65da56552..79f0143ef 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -760,7 +760,7 @@ impl<'a, T: VectorRepr> SearchPostProcess, [T], GarnetId> fo } } -impl SearchStrategy, [T], GarnetId> for FullPrecision { +impl SearchStrategy, [T]> for FullPrecision { type SearchAccessor<'a> = FullAccessor<'a, T>; type SearchAccessorError = GarnetProviderError; type QueryComputer = T::QueryDistance; @@ -778,20 +778,6 @@ impl DefaultPostProcessor, [T], GarnetId> for F default_post_processor!(glue::Pipeline); } -impl SearchStrategy, [T], u32> for FullPrecision { - type SearchAccessor<'a> = FullAccessor<'a, T>; - type SearchAccessorError = GarnetProviderError; - type QueryComputer = T::QueryDistance; - - fn search_accessor<'a>( - &'a self, - provider: &'a GarnetProvider, - context: &'a as DataProvider>::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, true)) - } -} - impl PruneStrategy> for FullPrecision { type PruneAccessor<'a> = FullAccessor<'a, T>; type PruneAccessorError = GarnetProviderError; diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 58a499d43..3ad3797b8 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -2808,7 +2808,6 @@ pub(crate) mod tests { let accessor = , [f32], - _, >>::search_accessor(&strategy, index.provider(), ctx) .unwrap(); let computer = accessor.build_query_computer(data.row(0)).unwrap(); diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index e01886943..fddeb74f6 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -112,13 +112,12 @@ where /// /// The [`BetaComputer`] then uses this ID to consult the filter predicate and adjust the /// distance accordingly. -impl SearchStrategy for BetaFilter +impl SearchStrategy for BetaFilter where T: ?Sized, I: VectorId, - O: Send, Provider: DataProvider, - Strategy: SearchStrategy, + Strategy: SearchStrategy, { /// An accessor that returns the ID in addition to the element yielded by the inner /// accessor. diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 2ffd13b02..c5d8fed6c 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -291,19 +291,11 @@ where /// A search strategy for query objects of type `T`. /// /// This trait should be overloaded by data providers wishing to extend -/// -/// The type `O` represents the type written into the output buffer -/// during a search. This is often the same as the provider's internal ID type, -/// but it can differ depending on the use case. For example, it might represent -/// associated data or alternative identifiers. -/// -/// [`crate::index::diskann_async::DiskANNIndex::search`]. -pub trait SearchStrategy::InternalId>: - Send + Sync +/// (search)[`crate::graph::DiskANNIndex::search`]. +pub trait SearchStrategy: Send + Sync where Provider: DataProvider, T: ?Sized, - O: Send, { /// The computer used by the associated accessor. /// @@ -342,7 +334,7 @@ where /// `default_post_processor()` to obtain the processor and invoke its /// [`SearchPostProcess::post_process`] method. pub trait DefaultPostProcessor::InternalId>: - SearchStrategy + SearchStrategy where Provider: DataProvider, T: ?Sized, @@ -357,7 +349,7 @@ where /// Aggregate trait for strategies that support both search access and a default post-processor. pub trait DefaultSearchStrategy::InternalId>: - SearchStrategy + DefaultPostProcessor + SearchStrategy + DefaultPostProcessor where Provider: DataProvider, T: ?Sized, @@ -367,7 +359,7 @@ where impl DefaultSearchStrategy for S where - S: SearchStrategy + DefaultPostProcessor, + S: SearchStrategy + DefaultPostProcessor, Provider: DataProvider, T: ?Sized, O: Send, diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 90be69569..e849ee92f 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2151,7 +2151,7 @@ where output: &mut OB, ) -> impl SendFuture> where - P: Search, + P: Search, S: glue::DefaultPostProcessor, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, @@ -2172,8 +2172,8 @@ where output: &mut OB, ) -> impl SendFuture> where - P: Search, - S: glue::SearchStrategy, + P: Search, + S: glue::SearchStrategy, PP: for<'a> glue::SearchPostProcess, T, O> + Send + Sync, O: Send, OB: search_output_buffer::SearchOutputBuffer + Send + ?Sized, diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index f7e57f52e..942b55ca0 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -92,17 +92,16 @@ where } } -impl Search for Diverse

+impl Search for Diverse

where DP: DataProvider, - S: crate::graph::glue::SearchStrategy, + S: crate::graph::glue::SearchStrategy, T: Sync + ?Sized, - O: Send, P: AttributeValueProvider, { type Output = SearchStats; - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -112,6 +111,7 @@ where output: &mut OB, ) -> impl SendFuture> where + O: Send, PP: for<'a> SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 51fb0ec25..940c9b9ab 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -143,12 +143,11 @@ impl Knn { } } -impl Search for Knn +impl Search for Knn where DP: DataProvider, - S: glue::SearchStrategy, + S: glue::SearchStrategy, T: Sync + ?Sized, - O: Send, { type Output = SearchStats; @@ -177,7 +176,7 @@ where /// # Errors /// /// Returns an error if there is a failure accessing elements or computing distances. - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -187,6 +186,7 @@ where output: &mut OB, ) -> impl SendFuture> where + O: Send, PP: for<'a> SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { @@ -249,17 +249,16 @@ impl<'r, SR: ?Sized> RecordedKnn<'r, SR> { } } -impl<'r, DP, S, T, O, SR> Search for RecordedKnn<'r, SR> +impl<'r, DP, S, T, SR> Search for RecordedKnn<'r, SR> where DP: DataProvider, - S: glue::SearchStrategy, + S: glue::SearchStrategy, T: Sync + ?Sized, - O: Send, SR: super::record::SearchRecord + ?Sized, { type Output = SearchStats; - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -269,6 +268,7 @@ where output: &mut OB, ) -> impl SendFuture> where + O: Send, PP: for<'a> SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index 844beb51d..4931d716a 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -62,11 +62,10 @@ pub(crate) mod scratch; /// - [`Diverse`] - Diversity-aware search (feature-gated) /// - [`MultihopSearch`] - Label-filtered search with multi-hop expansion /// - [`RecordedKnn`] - K-NN search with path recording for debugging -pub trait Search +pub trait Search where DP: DataProvider, - S: graph::glue::SearchStrategy, - O: Send, + S: graph::glue::SearchStrategy, { /// The result type returned by this search. type Output; @@ -93,7 +92,7 @@ where /// # Errors /// /// Returns an error if there is a failure accessing elements or computing distances. - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -103,6 +102,7 @@ where output: &mut OB, ) -> impl SendFuture> where + O: Send, PP: for<'a> graph::glue::SearchPostProcess, T, O> + Send + Sync, OB: graph::search_output_buffer::SearchOutputBuffer + Send + ?Sized; } diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 09ec0a001..2e9d3f69a 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -53,16 +53,15 @@ impl<'q, InternalId> MultihopSearch<'q, InternalId> { } } -impl<'q, DP, S, T, O> Search for MultihopSearch<'q, DP::InternalId> +impl<'q, DP, S, T> Search for MultihopSearch<'q, DP::InternalId> where DP: DataProvider, - S: glue::SearchStrategy, + S: glue::SearchStrategy, T: Sync + ?Sized, - O: Send, { type Output = SearchStats; - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -72,6 +71,7 @@ where output: &mut OB, ) -> impl SendFuture> where + O: Send, PP: for<'a> SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 7f8f3457d..e947cc259 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -157,16 +157,15 @@ impl Range { } } -impl Search for Range +impl Search for Range where DP: DataProvider, - S: glue::SearchStrategy, + S: glue::SearchStrategy, T: Sync + ?Sized, - O: Send + Default + Clone, { type Output = SearchStats; - fn search( + fn search( self, index: &DiskANNIndex, strategy: &S, @@ -176,6 +175,7 @@ where output: &mut OB, ) -> impl SendFuture> where + O: Send, PP: for<'a> glue::SearchPostProcess, T, O> + Send + Sync, OB: SearchOutputBuffer + Send + ?Sized, { From 58579bb6e9f4c177d1f2aaa271cfabcf1dab52e5 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 23 Mar 2026 16:42:19 +0530 Subject: [PATCH 41/47] Add determinant-diversity search post-process and benchmark integration --- Cargo.lock | 1 + diskann-benchmark-core/Cargo.toml | 1 + .../src/search/graph/determinant_diversity.rs | 205 +++++++ .../src/search/graph/mod.rs | 1 + .../src/backend/index/benchmarks.rs | 42 +- diskann-benchmark/src/backend/index/result.rs | 36 ++ .../src/backend/index/search/knn.rs | 133 ++++- .../src/backend/index/spherical.rs | 41 +- diskann-benchmark/src/inputs/async_.rs | 36 ++ .../determinant_diversity_post_process.rs | 503 ++++++++++++++++++ .../src/model/graph/provider/async_/mod.rs | 5 + tmp/wiki_compare_determinant_diversity.json | 63 +++ ...compare_determinant_diversity_results.json | 454 ++++++++++++++++ 13 files changed, 1494 insertions(+), 27 deletions(-) create mode 100644 diskann-benchmark-core/src/search/graph/determinant_diversity.rs create mode 100644 diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs create mode 100644 tmp/wiki_compare_determinant_diversity.json create mode 100644 tmp/wiki_compare_determinant_diversity_results.json diff --git a/Cargo.lock b/Cargo.lock index fd817b46a..e37c42b39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -687,6 +687,7 @@ dependencies = [ "anyhow", "diskann", "diskann-benchmark-runner", + "diskann-providers", "diskann-utils", "diskann-vector", "futures-util", diff --git a/diskann-benchmark-core/Cargo.toml b/diskann-benchmark-core/Cargo.toml index 90e64b9e3..689978c03 100644 --- a/diskann-benchmark-core/Cargo.toml +++ b/diskann-benchmark-core/Cargo.toml @@ -11,6 +11,7 @@ edition = "2024" anyhow.workspace = true diskann.workspace = true diskann-benchmark-runner = { workspace = true } +diskann-providers.workspace = true diskann-utils.default-features = false diskann-utils.workspace = true futures-util = { workspace = true, default-features = false } diff --git a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs new file mode 100644 index 000000000..58008afda --- /dev/null +++ b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs @@ -0,0 +1,205 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::sync::Arc; + +use diskann::{ + ANNResult, + graph::{self, glue}, + provider, +}; +use diskann_benchmark_runner::utils::{MicroSeconds, percentiles}; +use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; +use diskann_utils::{future::AsyncFriendly, views::Matrix}; + +use crate::{ + recall, + search::{self, Search, graph::Strategy}, + utils, +}; + +#[derive(Debug, Clone, Copy)] +pub struct Parameters { + pub inner: graph::search::Knn, + pub processor: DeterminantDiversitySearchParams, +} + +#[derive(Debug)] +pub struct DeterminantDiversity +where + DP: provider::DataProvider, +{ + index: Arc>, + queries: Arc>, + strategy: Strategy, +} + +impl DeterminantDiversity +where + DP: provider::DataProvider, +{ + pub fn new( + index: Arc>, + queries: Arc>, + strategy: Strategy, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + + Ok(Arc::new(Self { + index, + queries, + strategy, + })) + } +} + +impl Search for DeterminantDiversity +where + DP: provider::DataProvider, + S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, + DeterminantDiversitySearchParams: + for<'a> glue::SearchPostProcess, [T], DP::ExternalId> + Send + Sync, + T: AsyncFriendly + Clone, +{ + type Id = DP::ExternalId; + type Parameters = Parameters; + type Output = super::knn::Metrics; + + fn num_queries(&self) -> usize { + self.queries.nrows() + } + + fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { + search::IdCount::Fixed(parameters.inner.k_value()) + } + + async fn search( + &self, + parameters: &Self::Parameters, + buffer: &mut O, + index: usize, + ) -> ANNResult + where + O: graph::SearchOutputBuffer + Send, + { + let context = DP::Context::default(); + let stats = self + .index + .search_with( + parameters.inner, + self.strategy.get(index)?, + parameters.processor, + &context, + self.queries.row(index), + buffer, + ) + .await?; + + Ok(super::knn::Metrics { + comparisons: stats.cmps, + hops: stats.hops, + }) + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct Summary { + pub setup: search::Setup, + pub parameters: Parameters, + pub end_to_end_latencies: Vec, + pub mean_latencies: Vec, + pub p90_latencies: Vec, + pub p99_latencies: Vec, + pub recall: recall::RecallMetrics, + pub mean_cmps: f64, + pub mean_hops: f64, +} + +pub struct Aggregator<'a, I> { + groundtruth: &'a dyn crate::recall::Rows, + recall_k: usize, + recall_n: usize, +} + +impl<'a, I> Aggregator<'a, I> { + pub fn new( + groundtruth: &'a dyn crate::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> Self { + Self { + groundtruth, + recall_k, + recall_n, + } + } +} + +impl search::Aggregate for Aggregator<'_, I> +where + I: crate::recall::RecallCompatible, +{ + type Output = Summary; + + fn aggregate( + &mut self, + run: search::Run, + mut results: Vec>, + ) -> anyhow::Result

{ + let recall = match results.first() { + Some(first) => crate::recall::knn( + self.groundtruth, + None, + first.ids().as_rows(), + self.recall_k, + self.recall_n, + true, + )?, + None => anyhow::bail!("Results must be non-empty"), + }; + + let mut mean_latencies = Vec::with_capacity(results.len()); + let mut p90_latencies = Vec::with_capacity(results.len()); + let mut p99_latencies = Vec::with_capacity(results.len()); + + results.iter_mut().for_each(|r| { + match percentiles::compute_percentiles(r.latencies_mut()) { + Ok(values) => { + let percentiles::Percentiles { mean, p90, p99, .. } = values; + mean_latencies.push(mean); + p90_latencies.push(p90); + p99_latencies.push(p99); + } + Err(_) => { + let zero = MicroSeconds::new(0); + mean_latencies.push(0.0); + p90_latencies.push(zero); + p99_latencies.push(zero); + } + } + }); + + Ok(Summary { + setup: run.setup().clone(), + parameters: *run.parameters(), + end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(), + recall, + mean_latencies, + p90_latencies, + p99_latencies, + mean_cmps: utils::average_all( + results + .iter() + .flat_map(|r| r.output().iter().map(|o| o.comparisons)), + ), + mean_hops: utils::average_all( + results + .iter() + .flat_map(|r| r.output().iter().map(|o| o.hops)), + ), + }) + } +} diff --git a/diskann-benchmark-core/src/search/graph/mod.rs b/diskann-benchmark-core/src/search/graph/mod.rs index eddb4fbcf..cfcecb0db 100644 --- a/diskann-benchmark-core/src/search/graph/mod.rs +++ b/diskann-benchmark-core/src/search/graph/mod.rs @@ -3,6 +3,7 @@ * Licensed under the MIT license. */ +pub mod determinant_diversity; pub mod knn; pub mod multihop; pub mod range; diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index fa4a77078..a733ada8e 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -24,7 +24,10 @@ use diskann_benchmark_runner::{ }; use diskann_providers::{ index::diskann_async, - model::{configuration::IndexConfiguration, graph::provider::async_::common}, + model::{ + configuration::IndexConfiguration, + graph::provider::async_::{common, DeterminantDiversitySearchParams}, + }, }; use diskann_utils::{ future::AsyncFriendly, @@ -350,6 +353,8 @@ where + provider::SetElement<[T]>, T: SampleableForStart + std::fmt::Debug + Copy + AsyncFriendly + bytemuck::Pod, S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, + DeterminantDiversitySearchParams: + for<'a> glue::SearchPostProcess, [T], DP::ExternalId> + Send + Sync, { match &input { SearchPhase::Topk(search_phase) => { @@ -366,19 +371,40 @@ where let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&search_phase.groundtruth))?; - let knn = benchmark_core::search::graph::KNN::new( - index, - queries, - benchmark_core::search::graph::Strategy::broadcast(search_strategy), - )?; - let steps = search::knn::SearchSteps::new( search_phase.reps, &search_phase.num_threads, &search_phase.runs, ); - let search_results = search::knn::run(&knn, &groundtruth, steps)?; + let search_results = if let (Some(eta), Some(power)) = ( + search_phase.determinant_diversity_eta, + search_phase.determinant_diversity_power, + ) { + let knn = + benchmark_core::search::graph::determinant_diversity::DeterminantDiversity::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(search_strategy), + )?; + + search::knn::run_determinant_diversity( + &knn, + &groundtruth, + steps, + eta, + power, + search_phase.determinant_diversity_results_k, + )? + } else { + let knn = benchmark_core::search::graph::KNN::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(search_strategy), + )?; + + search::knn::run(&knn, &groundtruth, steps)? + }; result.append(AggregatedSearchResults::Topk(search_results)); Ok(result) } diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 1d6102f9b..1f9c2e50a 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -155,6 +155,42 @@ impl SearchResults { mean_hops: mean_hops as f32, } } + + pub fn new_determinant_diversity( + summary: benchmark_core::search::graph::determinant_diversity::Summary, + ) -> Self { + let benchmark_core::search::graph::determinant_diversity::Summary { + setup, + parameters, + end_to_end_latencies, + mean_latencies, + p90_latencies, + p99_latencies, + recall, + mean_cmps, + mean_hops, + .. + } = summary; + + let qps = end_to_end_latencies + .iter() + .map(|latency| recall.num_queries as f64 / latency.as_seconds()) + .collect(); + + Self { + num_tasks: setup.tasks.into(), + search_n: parameters.inner.k_value().get(), + search_l: parameters.inner.l_value().get(), + qps, + search_latencies: end_to_end_latencies, + mean_latencies, + p90_latencies, + p99_latencies, + recall: (&recall).into(), + mean_cmps: mean_cmps as f32, + mean_hops: mean_hops as f32, + } + } } fn format_search_results_table( diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 915b8eca6..30560a6cd 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -6,6 +6,7 @@ use std::{num::NonZeroUsize, sync::Arc}; use diskann_benchmark_core::{self as benchmark_core, search as core_search}; +use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use crate::{backend::index::result::SearchResults, inputs::async_::GraphSearch}; @@ -35,6 +36,60 @@ pub(crate) fn run( groundtruth: &dyn benchmark_core::recall::Rows, steps: SearchSteps<'_>, ) -> anyhow::Result> { + run_search(runner, groundtruth, steps, |setup, search_l, search_n| { + let search_params = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); + core_search::Run::new(search_params, setup) + }) +} + +type Run = core_search::Run; +pub(crate) trait Knn { + fn search_all( + &self, + parameters: Vec, + groundtruth: &dyn benchmark_core::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> anyhow::Result>; +} + +type DeterminantRun = + core_search::Run; + +pub(crate) fn run_determinant_diversity( + runner: &dyn DeterminantDiversityKnn, + groundtruth: &dyn benchmark_core::recall::Rows, + steps: SearchSteps<'_>, + eta: f64, + power: f64, + results_k: Option, +) -> anyhow::Result> { + run_search_determinant_diversity(runner, groundtruth, steps, |setup, search_l, search_n| { + let base = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); + let processor = + DeterminantDiversitySearchParams::new(results_k.unwrap_or(search_n), eta, power) + .map_err(|err| { + anyhow::anyhow!("Invalid determinant-diversity parameters: {err}") + })?; + + let search_params = + diskann_benchmark_core::search::graph::determinant_diversity::Parameters { + inner: base, + processor, + }; + Ok(core_search::Run::new(search_params, setup)) + }) +} + +fn run_search( + runner: &dyn Knn, + groundtruth: &dyn benchmark_core::recall::Rows, + steps: SearchSteps<'_>, + builder: F, +) -> anyhow::Result> +where + F: Fn(core_search::Setup, usize, usize) -> Run, +{ let mut all = Vec::new(); for threads in steps.num_tasks.iter() { @@ -48,12 +103,7 @@ pub(crate) fn run( let parameters: Vec<_> = run .search_l .iter() - .map(|search_l| { - let search_params = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); - - core_search::Run::new(search_params, setup.clone()) - }) + .map(|&search_l| builder(setup.clone(), search_l, run.search_n)) .collect(); all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); @@ -63,11 +113,42 @@ pub(crate) fn run( Ok(all) } -type Run = core_search::Run; -pub(crate) trait Knn { +fn run_search_determinant_diversity( + runner: &dyn DeterminantDiversityKnn, + groundtruth: &dyn benchmark_core::recall::Rows, + steps: SearchSteps<'_>, + builder: F, +) -> anyhow::Result> +where + F: Fn(core_search::Setup, usize, usize) -> anyhow::Result, +{ + let mut all = Vec::new(); + + for threads in steps.num_tasks.iter() { + for run in steps.runs.iter() { + let setup = core_search::Setup { + threads: *threads, + tasks: *threads, + reps: steps.reps, + }; + + let parameters: Vec<_> = run + .search_l + .iter() + .map(|&search_l| builder(setup.clone(), search_l, run.search_n)) + .collect::>>()?; + + all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); + } + } + + Ok(all) +} + +pub(crate) trait DeterminantDiversityKnn { fn search_all( &self, - parameters: Vec, + parameters: Vec, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, @@ -129,3 +210,37 @@ where Ok(results.into_iter().map(SearchResults::new).collect()) } } + +impl DeterminantDiversityKnn + for Arc> +where + DP: diskann::provider::DataProvider, + core_search::graph::determinant_diversity::DeterminantDiversity: core_search::Search< + Id = DP::InternalId, + Parameters = diskann_benchmark_core::search::graph::determinant_diversity::Parameters, + Output = core_search::graph::knn::Metrics, + >, +{ + fn search_all( + &self, + parameters: Vec, + groundtruth: &dyn benchmark_core::recall::Rows, + recall_k: usize, + recall_n: usize, + ) -> anyhow::Result> { + let results = core_search::search_all( + self.clone(), + parameters.into_iter(), + core_search::graph::determinant_diversity::Aggregator::new( + groundtruth, + recall_k, + recall_n, + ), + )?; + + Ok(results + .into_iter() + .map(SearchResults::new_determinant_diversity) + .collect()) + } +} diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 82bb37dae..1918d6676 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -66,7 +66,9 @@ mod imp { }; use diskann_providers::{ index::diskann_async::{self}, - model::graph::provider::async_::{common::NoDeletes, inmem}, + model::graph::provider::async_::{ + common::NoDeletes, inmem, DeterminantDiversitySearchParams, + }, }; use diskann_quantization::alloc::GlobalAllocator; use diskann_utils::views::Matrix; @@ -331,15 +333,34 @@ mod imp { ); for &layout in self.input.query_layouts.iter() { - let knn = benchmark_core::search::graph::KNN::new( - index.clone(), - queries.clone(), - benchmark_core::search::graph::Strategy::broadcast( - inmem::spherical::Quantized::search(layout.into()), - ), - )?; - - let search_results = search::knn::run(&knn, &groundtruth, steps)?; + let strategy = inmem::spherical::Quantized::search(layout.into()); + let search_results = if let (Some(eta), Some(power)) = ( + search_phase.determinant_diversity_eta, + search_phase.determinant_diversity_power, + ) { + let knn = benchmark_core::search::graph::determinant_diversity::DeterminantDiversity::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::broadcast(strategy), + )?; + + search::knn::run_determinant_diversity( + &knn, + &groundtruth, + steps, + eta, + power, + search_phase.determinant_diversity_results_k, + )? + } else { + let knn = benchmark_core::search::graph::KNN::new( + index.clone(), + queries.clone(), + benchmark_core::search::graph::Strategy::broadcast(strategy), + )?; + + search::knn::run(&knn, &groundtruth, steps)? + }; result.append(SearchRun { layout, results: AggregatedSearchResults::Topk(search_results), diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 19230977d..54b795635 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -123,6 +123,9 @@ pub(crate) struct TopkSearchPhase { pub(crate) queries: InputFile, pub(crate) groundtruth: InputFile, pub(crate) reps: NonZeroUsize, + pub(crate) determinant_diversity_eta: Option, + pub(crate) determinant_diversity_power: Option, + pub(crate) determinant_diversity_results_k: Option, // Enable sweeping threads pub(crate) num_threads: Vec, pub(crate) runs: Vec, @@ -139,6 +142,36 @@ impl CheckDeserialization for TopkSearchPhase { .with_context(|| format!("search run {}", i))?; } + if self.determinant_diversity_eta.is_some() != self.determinant_diversity_power.is_some() { + return Err(anyhow!( + "determinant_diversity_eta and determinant_diversity_power must either both be set or both be omitted" + )); + } + + if let Some(eta) = self.determinant_diversity_eta { + if eta < 0.0 { + return Err(anyhow!( + "determinant_diversity_eta must be >= 0.0, got {}", + eta + )); + } + } + + if let Some(power) = self.determinant_diversity_power { + if power <= 0.0 { + return Err(anyhow!( + "determinant_diversity_power must be > 0.0, got {}", + power + )); + } + } + + if let Some(k) = self.determinant_diversity_results_k { + if k == 0 { + return Err(anyhow!("determinant_diversity_results_k must be > 0")); + } + } + Ok(()) } } @@ -164,6 +197,9 @@ impl Example for TopkSearchPhase { queries: InputFile::new("path/to/queries"), groundtruth: InputFile::new("path/to/groundtruth"), reps: REPS, + determinant_diversity_eta: None, + determinant_diversity_power: None, + determinant_diversity_results_k: None, num_threads: THREAD_COUNTS.to_vec(), runs, } diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs new file mode 100644 index 000000000..75e84bbf1 --- /dev/null +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -0,0 +1,503 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Determinant-diversity search post-processing. + +use std::future::Future; + +use diskann::{ + ANNError, + graph::{SearchOutputBuffer, glue}, + neighbor::Neighbor, + provider::BuildQueryComputer, + utils::{IntoUsize, VectorRepr}, +}; +use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; + +use super::{ + inmem::GetFullPrecision, + postprocess::{AsDeletionCheck, DeletionCheck}, +}; + +#[derive(Debug)] +pub enum DeterminantDiversityError { + InvalidTopK { top_k: usize }, + InvalidEta { eta: f64 }, + InvalidPower { power: f64 }, +} + +impl std::fmt::Display for DeterminantDiversityError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidTopK { top_k } => write!(f, "top_k must be > 0, got {top_k}"), + Self::InvalidEta { eta } => write!(f, "eta must be >= 0.0, got {eta}"), + Self::InvalidPower { power } => write!(f, "power must be > 0.0, got {power}"), + } + } +} + +impl std::error::Error for DeterminantDiversityError {} + +#[derive(Debug, Clone, Copy)] +pub struct DeterminantDiversitySearchParams { + pub top_k: usize, + pub determinant_diversity_eta: f64, + pub determinant_diversity_power: f64, +} + +impl DeterminantDiversitySearchParams { + pub fn new( + top_k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, + ) -> Result { + if top_k == 0 { + return Err(DeterminantDiversityError::InvalidTopK { top_k }); + } + + if determinant_diversity_eta < 0.0 || !determinant_diversity_eta.is_finite() { + return Err(DeterminantDiversityError::InvalidEta { + eta: determinant_diversity_eta, + }); + } + + if determinant_diversity_power <= 0.0 || !determinant_diversity_power.is_finite() { + return Err(DeterminantDiversityError::InvalidPower { + power: determinant_diversity_power, + }); + } + + Ok(Self { + top_k, + determinant_diversity_eta, + determinant_diversity_power, + }) + } +} + +impl glue::SearchPostProcess for DeterminantDiversitySearchParams +where + T: VectorRepr, + A: BuildQueryComputer<[T], Id = u32> + GetFullPrecision + AsDeletionCheck, +{ + type Error = ANNError; + + fn post_process( + &self, + accessor: &mut A, + query: &[T], + _computer: &A::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let result = (|| { + let query_f32 = T::as_f32(query).map_err(Into::into)?; + let full = accessor.as_full_precision(); + let checker = accessor.as_deletion_check(); + + let mut candidates_with_vectors = Vec::new(); + for candidate in candidates { + if checker.deletion_check(candidate.id) { + continue; + } + + let vector = unsafe { full.get_vector_sync(candidate.id.into_usize()) }; + let vector_f32 = T::as_f32(vector).map_err(Into::into)?; + candidates_with_vectors.push(( + candidate.id, + candidate.distance, + vector_f32.to_vec(), + )); + } + + let borrowed: Vec<(u32, f32, &[f32])> = candidates_with_vectors + .iter() + .map(|(id, distance, vector)| (*id, *distance, vector.as_slice())) + .collect(); + + let reranked = determinant_diversity_post_process( + borrowed, + &query_f32[..], + self.top_k, + self.determinant_diversity_eta, + self.determinant_diversity_power, + ); + + Ok(output.extend(reranked)) + })(); + + std::future::ready(result) + } +} + +pub fn determinant_diversity_post_process( + candidates: Vec<(Id, f32, &[f32])>, + query: &[f32], + k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let k = k.min(candidates.len()); + if k == 0 { + return Vec::new(); + } + + let candidates_f32: Vec<(Id, f32, Vec)> = candidates + .into_iter() + .map(|(id, dist, v)| (id, dist, v.to_vec())) + .collect(); + + if candidates_f32[0].2.is_empty() { + return Vec::new(); + } + + if determinant_diversity_eta > 0.0 { + post_process_with_eta_f32( + candidates_f32, + query, + k, + determinant_diversity_eta, + determinant_diversity_power, + ) + } else { + post_process_greedy_orthogonalization_f32( + candidates_f32, + query, + k, + determinant_diversity_power, + ) + } +} + +fn post_process_with_eta_f32( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_eta: f64, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + let eta = determinant_diversity_eta as f32; + let power = determinant_diversity_power; + + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let n = candidates.len(); + let k = k.min(n); + if k == 0 { + return Vec::new(); + } + + if candidates[0].2.is_empty() { + return Vec::new(); + } + + let inv_sqrt_eta = 1.0 / eta.sqrt(); + let mut residuals = Vec::with_capacity(n); + let mut norms_sq = Vec::with_capacity(n); + + for (_, _, v) in &candidates { + let similarity = dot_product(v, query); + let scale = similarity.max(0.0).powf(power as f32) * inv_sqrt_eta; + let residual: Vec = v.iter().map(|&x| x * scale).collect(); + let norm_sq = dot_product(&residual, &residual); + residuals.push(residual); + norms_sq.push(norm_sq); + } + + let mut available = vec![true; n]; + let mut selected = Vec::with_capacity(k); + + for _ in 0..k { + let best_idx = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i); + + let Some(selected_index) = best_idx else { + break; + }; + + selected.push(selected_index); + available[selected_index] = false; + + if selected.len() == k { + break; + } + + let norm_factor = 1.0 / (1.0 + norms_sq[selected_index]).sqrt(); + let q_scaled: Vec = residuals[selected_index] + .iter() + .map(|&x| x * norm_factor) + .collect(); + + let mut projections = Vec::with_capacity(n); + for i in 0..n { + if !available[i] { + projections.push(0.0); + } else { + let alpha = dot_product(&residuals[selected_index], &residuals[i]) + * norm_factor + * norm_factor; + projections.push(alpha); + } + } + + for i in 0..n { + if !available[i] { + continue; + } + + let alpha = projections[i]; + for (residual, &q_value) in residuals[i].iter_mut().zip(q_scaled.iter()) { + *residual -= alpha * q_value; + } + + norms_sq[i] = (norms_sq[i] - alpha * alpha).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = &candidates[idx]; + (*id, *dist) + }) + .collect() +} + +fn post_process_greedy_orthogonalization_f32( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_power: f64, +) -> Vec<(Id, f32)> { + let power = determinant_diversity_power; + + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let n = candidates.len(); + let k = k.min(n); + if k == 0 { + return Vec::new(); + } + + let mut residuals = Vec::with_capacity(n); + let mut norms_sq = Vec::with_capacity(n); + + for (_, _, v) in &candidates { + let similarity = dot_product(v, query); + let scale = similarity.max(0.0).powf(power as f32); + let residual: Vec = v.iter().map(|&x| x * scale).collect(); + let norm_sq = dot_product(&residual, &residual); + residuals.push(residual); + norms_sq.push(norm_sq); + } + + let mut available = vec![true; n]; + let mut selected = Vec::with_capacity(k); + + for _ in 0..k { + let best = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let Some((best_index, _)) = best else { + break; + }; + + let best_norm_sq = norms_sq[best_index]; + selected.push(best_index); + available[best_index] = false; + + if selected.len() == k { + break; + } + + if best_norm_sq <= 0.0 { + continue; + } + + let inv_norm_sq_star = 1.0 / best_norm_sq; + let r_star_copy = residuals[best_index].clone(); + + let mut projections = Vec::with_capacity(n); + for j in 0..n { + if !available[j] { + projections.push(0.0); + } else { + let projection = dot_product(&residuals[j], &r_star_copy) * inv_norm_sq_star; + projections.push(projection); + } + } + + for j in 0..n { + if !available[j] { + continue; + } + + let projection = projections[j]; + for (residual, &star) in residuals[j].iter_mut().zip(r_star_copy.iter()) { + *residual -= projection * star; + } + + norms_sq[j] = (norms_sq[j] - projection * projection * best_norm_sq).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = &candidates[idx]; + (*id, *dist) + }) + .collect() +} + +#[inline] +fn dot_product(a: &[f32], b: &[f32]) -> f32 { + >>::evaluate(a, b) + .into_inner() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validation_valid_params() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, 2.0); + assert!(result.is_ok()); + } + + #[test] + fn test_validation_zero_top_k() { + let result = DeterminantDiversitySearchParams::new(0, 0.01, 2.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidTopK { top_k: 0 }) + )); + } + + #[test] + fn test_validation_negative_eta() { + let result = DeterminantDiversitySearchParams::new(10, -0.01, 2.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidEta { .. }) + )); + } + + #[test] + fn test_validation_zero_power() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, 0.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidPower { .. }) + )); + } + + #[test] + fn test_validation_negative_power() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, -1.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidPower { .. }) + )); + } + + #[test] + fn test_validation_nan_eta() { + let result = DeterminantDiversitySearchParams::new(10, f64::NAN, 2.0); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidEta { .. }) + )); + } + + #[test] + fn test_validation_infinity_power() { + let result = DeterminantDiversitySearchParams::new(10, 0.01, f64::INFINITY); + assert!(matches!( + result, + Err(DeterminantDiversityError::InvalidPower { .. }) + )); + } + + #[test] + fn test_determinant_diversity_post_process_with_eta() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.0f32, 1.0, 0.0]; + let v3 = vec![0.0f32, 0.0, 1.0]; + let candidates = vec![ + (1u32, 0.5f32, v1.as_slice()), + (2u32, 0.3f32, v2.as_slice()), + (3u32, 0.7f32, v3.as_slice()), + ]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); + assert_eq!(result.len(), 3); + } + + #[test] + fn test_determinant_diversity_post_process_enabled_greedy() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.99f32, 0.1, 0.0]; + let v3 = vec![0.0f32, 1.0, 0.0]; + let candidates = vec![ + (1u32, 0.5f32, v1.as_slice()), + (2u32, 0.3f32, v2.as_slice()), + (3u32, 0.4f32, v3.as_slice()), + ]; + let query = vec![1.0, 1.0, 0.0]; + + let result = determinant_diversity_post_process(candidates, &query, 2, 0.0, 1.0); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_determinant_diversity_post_process_empty() { + let candidates: Vec<(u32, f32, &[f32])> = vec![]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); + assert!(result.is_empty()); + } + + #[test] + fn test_determinant_diversity_post_process_k_larger_than_candidates() { + let v1 = vec![1.0f32, 0.0, 0.0]; + let v2 = vec![0.0f32, 1.0, 0.0]; + let candidates = vec![(1u32, 0.5f32, v1.as_slice()), (2u32, 0.3f32, v2.as_slice())]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 10, 0.01, 2.0); + assert_eq!(result.len(), 2); + } +} diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 3d89359e2..774c5530d 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -9,6 +9,11 @@ pub use common::{PrefetchCacheLineLevel, StartPoints, VectorGuard}; pub(crate) mod postprocess; +mod determinant_diversity_post_process; +pub use determinant_diversity_post_process::{ + DeterminantDiversityError, DeterminantDiversitySearchParams, determinant_diversity_post_process, +}; + pub mod distances; pub mod memory_vector_provider; diff --git a/tmp/wiki_compare_determinant_diversity.json b/tmp/wiki_compare_determinant_diversity.json new file mode 100644 index 000000000..c2c4972c3 --- /dev/null +++ b/tmp/wiki_compare_determinant_diversity.json @@ -0,0 +1,63 @@ +{ + "search_directories": [ + "C:\\wikipedia_dataset" + ], + "jobs": [ + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Load", + "data_type": "float32", + "distance": "squared_l2", + "load_path": "C:\\wikipedia_dataset\\wikipedia_saved_index" + }, + "search_phase": { + "search-type": "topk", + "queries": "C:\\wikipedia_dataset\\query.bin", + "groundtruth": "C:\\wikipedia_dataset\\groundtruth_k100.bin", + "reps": 1, + "determinant_diversity_eta": null, + "determinant_diversity_power": null, + "determinant_diversity_results_k": null, + "num_threads": [8], + "runs": [ + { + "search_n": 10, + "search_l": [20, 30, 40, 50, 100, 200], + "recall_k": 10 + } + ] + } + } + }, + { + "type": "async-index-build", + "content": { + "source": { + "index-source": "Load", + "data_type": "float32", + "distance": "squared_l2", + "load_path": "C:\\wikipedia_dataset\\wikipedia_saved_index" + }, + "search_phase": { + "search-type": "topk", + "queries": "C:\\wikipedia_dataset\\query.bin", + "groundtruth": "C:\\wikipedia_dataset\\groundtruth_k100.bin", + "reps": 1, + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 2.0, + "determinant_diversity_results_k": 10, + "num_threads": [8], + "runs": [ + { + "search_n": 10, + "search_l": [20, 30, 40, 50, 100, 200], + "recall_k": 10 + } + ] + } + } + } + ] +} diff --git a/tmp/wiki_compare_determinant_diversity_results.json b/tmp/wiki_compare_determinant_diversity_results.json new file mode 100644 index 000000000..27672d734 --- /dev/null +++ b/tmp/wiki_compare_determinant_diversity_results.json @@ -0,0 +1,454 @@ +[ + { + "input": { + "content": { + "search_phase": { + "determinant_diversity_eta": null, + "determinant_diversity_power": null, + "determinant_diversity_results_k": null, + "groundtruth": "C:\\wikipedia_dataset\\groundtruth_k100.bin", + "num_threads": [ + 8 + ], + "queries": "C:\\wikipedia_dataset\\query.bin", + "reps": 1, + "runs": [ + { + "recall_k": 10, + "search_l": [ + 20, + 30, + 40, + 50, + 100, + 200 + ], + "search_n": 10 + } + ], + "search-type": "topk" + }, + "source": { + "data_type": "float32", + "distance": "squared_l2", + "index-source": "Load", + "load_path": "C:\\wikipedia_dataset\\wikipedia_saved_index" + } + }, + "type": "async-index-build" + }, + "results": { + "build": null, + "search": { + "Topk": [ + { + "mean_cmps": 1514.68994140625, + "mean_hops": 28.469999313354492, + "mean_latencies": [ + 1413.43 + ], + "num_tasks": 8, + "p90_latencies": [ + 2157 + ], + "p99_latencies": [ + 3712 + ], + "qps": [ + 4886.15264340858 + ], + "recall": { + "average": 0.406, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 20, + "search_latencies": [ + 20466 + ], + "search_n": 10 + }, + { + "mean_cmps": 1916.7900390625, + "mean_hops": 37.97999954223633, + "mean_latencies": [ + 1921.54 + ], + "num_tasks": 8, + "p90_latencies": [ + 2990 + ], + "p99_latencies": [ + 5072 + ], + "qps": [ + 3545.8478122118995 + ], + "recall": { + "average": 0.409, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 30, + "search_latencies": [ + 28202 + ], + "search_n": 10 + }, + { + "mean_cmps": 2288.7900390625, + "mean_hops": 47.43000030517578, + "mean_latencies": [ + 2935.55 + ], + "num_tasks": 8, + "p90_latencies": [ + 4750 + ], + "p99_latencies": [ + 7347 + ], + "qps": [ + 2267.5222784063853 + ], + "recall": { + "average": 0.417, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 40, + "search_latencies": [ + 44101 + ], + "search_n": 10 + }, + { + "mean_cmps": 2652.449951171875, + "mean_hops": 57.040000915527344, + "mean_latencies": [ + 2716.51 + ], + "num_tasks": 8, + "p90_latencies": [ + 4333 + ], + "p99_latencies": [ + 6903 + ], + "qps": [ + 2401.5946588534784 + ], + "recall": { + "average": 0.419, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 50, + "search_latencies": [ + 41639 + ], + "search_n": 10 + }, + { + "mean_cmps": 4426.0400390625, + "mean_hops": 106.4000015258789, + "mean_latencies": [ + 4522.62 + ], + "num_tasks": 8, + "p90_latencies": [ + 6900 + ], + "p99_latencies": [ + 8430 + ], + "qps": [ + 1539.7644160443451 + ], + "recall": { + "average": 0.425, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 100, + "search_latencies": [ + 64945 + ], + "search_n": 10 + }, + { + "mean_cmps": 7640.009765625, + "mean_hops": 205.7100067138672, + "mean_latencies": [ + 8594.32 + ], + "num_tasks": 8, + "p90_latencies": [ + 12535 + ], + "p99_latencies": [ + 21987 + ], + "qps": [ + 831.1861025683651 + ], + "recall": { + "average": 0.432, + "maximum": 10, + "minimum": 1, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 200, + "search_latencies": [ + 120310 + ], + "search_n": 10 + } + ] + } + } + }, + { + "input": { + "content": { + "search_phase": { + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 2.0, + "determinant_diversity_results_k": 10, + "groundtruth": "C:\\wikipedia_dataset\\groundtruth_k100.bin", + "num_threads": [ + 8 + ], + "queries": "C:\\wikipedia_dataset\\query.bin", + "reps": 1, + "runs": [ + { + "recall_k": 10, + "search_l": [ + 20, + 30, + 40, + 50, + 100, + 200 + ], + "search_n": 10 + } + ], + "search-type": "topk" + }, + "source": { + "data_type": "float32", + "distance": "squared_l2", + "index-source": "Load", + "load_path": "C:\\wikipedia_dataset\\wikipedia_saved_index" + } + }, + "type": "async-index-build" + }, + "results": { + "build": null, + "search": { + "Topk": [ + { + "mean_cmps": 1514.68994140625, + "mean_hops": 28.469999313354492, + "mean_latencies": [ + 2416.91 + ], + "num_tasks": 8, + "p90_latencies": [ + 3698 + ], + "p99_latencies": [ + 8820 + ], + "qps": [ + 2810.883741848437 + ], + "recall": { + "average": 0.408, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 20, + "search_latencies": [ + 35576 + ], + "search_n": 10 + }, + { + "mean_cmps": 1916.7900390625, + "mean_hops": 37.97999954223633, + "mean_latencies": [ + 3083.14 + ], + "num_tasks": 8, + "p90_latencies": [ + 4464 + ], + "p99_latencies": [ + 6221 + ], + "qps": [ + 2288.8532845044633 + ], + "recall": { + "average": 0.412, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 30, + "search_latencies": [ + 43690 + ], + "search_n": 10 + }, + { + "mean_cmps": 2288.7900390625, + "mean_hops": 47.43000030517578, + "mean_latencies": [ + 4265.17 + ], + "num_tasks": 8, + "p90_latencies": [ + 6233 + ], + "p99_latencies": [ + 10556 + ], + "qps": [ + 1710.1617813045114 + ], + "recall": { + "average": 0.418, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 40, + "search_latencies": [ + 58474 + ], + "search_n": 10 + }, + { + "mean_cmps": 2652.449951171875, + "mean_hops": 57.040000915527344, + "mean_latencies": [ + 4419.41 + ], + "num_tasks": 8, + "p90_latencies": [ + 6687 + ], + "p99_latencies": [ + 10292 + ], + "qps": [ + 1623.666563834451 + ], + "recall": { + "average": 0.422, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 50, + "search_latencies": [ + 61589 + ], + "search_n": 10 + }, + { + "mean_cmps": 4426.0400390625, + "mean_hops": 106.4000015258789, + "mean_latencies": [ + 8935.56 + ], + "num_tasks": 8, + "p90_latencies": [ + 14054 + ], + "p99_latencies": [ + 17442 + ], + "qps": [ + 818.7393052178256 + ], + "recall": { + "average": 0.428, + "maximum": 10, + "minimum": 0, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 100, + "search_latencies": [ + 122139 + ], + "search_n": 10 + }, + { + "mean_cmps": 7640.009765625, + "mean_hops": 205.7100067138672, + "mean_latencies": [ + 16669.73 + ], + "num_tasks": 8, + "p90_latencies": [ + 22001 + ], + "p99_latencies": [ + 72498 + ], + "qps": [ + 436.14603913974554 + ], + "recall": { + "average": 0.437, + "maximum": 10, + "minimum": 1, + "num_queries": 100, + "recall_k": 10, + "recall_n": 10 + }, + "search_l": 200, + "search_latencies": [ + 229281 + ], + "search_n": 10 + } + ] + } + } + } +] \ No newline at end of file From 329793b46f660523123f7af80a8c0665f3db4fcc Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 23 Mar 2026 16:50:29 +0530 Subject: [PATCH 42/47] Remove temporary benchmark results artifact --- ...compare_determinant_diversity_results.json | 454 ------------------ 1 file changed, 454 deletions(-) delete mode 100644 tmp/wiki_compare_determinant_diversity_results.json diff --git a/tmp/wiki_compare_determinant_diversity_results.json b/tmp/wiki_compare_determinant_diversity_results.json deleted file mode 100644 index 27672d734..000000000 --- a/tmp/wiki_compare_determinant_diversity_results.json +++ /dev/null @@ -1,454 +0,0 @@ -[ - { - "input": { - "content": { - "search_phase": { - "determinant_diversity_eta": null, - "determinant_diversity_power": null, - "determinant_diversity_results_k": null, - "groundtruth": "C:\\wikipedia_dataset\\groundtruth_k100.bin", - "num_threads": [ - 8 - ], - "queries": "C:\\wikipedia_dataset\\query.bin", - "reps": 1, - "runs": [ - { - "recall_k": 10, - "search_l": [ - 20, - 30, - 40, - 50, - 100, - 200 - ], - "search_n": 10 - } - ], - "search-type": "topk" - }, - "source": { - "data_type": "float32", - "distance": "squared_l2", - "index-source": "Load", - "load_path": "C:\\wikipedia_dataset\\wikipedia_saved_index" - } - }, - "type": "async-index-build" - }, - "results": { - "build": null, - "search": { - "Topk": [ - { - "mean_cmps": 1514.68994140625, - "mean_hops": 28.469999313354492, - "mean_latencies": [ - 1413.43 - ], - "num_tasks": 8, - "p90_latencies": [ - 2157 - ], - "p99_latencies": [ - 3712 - ], - "qps": [ - 4886.15264340858 - ], - "recall": { - "average": 0.406, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 20, - "search_latencies": [ - 20466 - ], - "search_n": 10 - }, - { - "mean_cmps": 1916.7900390625, - "mean_hops": 37.97999954223633, - "mean_latencies": [ - 1921.54 - ], - "num_tasks": 8, - "p90_latencies": [ - 2990 - ], - "p99_latencies": [ - 5072 - ], - "qps": [ - 3545.8478122118995 - ], - "recall": { - "average": 0.409, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 30, - "search_latencies": [ - 28202 - ], - "search_n": 10 - }, - { - "mean_cmps": 2288.7900390625, - "mean_hops": 47.43000030517578, - "mean_latencies": [ - 2935.55 - ], - "num_tasks": 8, - "p90_latencies": [ - 4750 - ], - "p99_latencies": [ - 7347 - ], - "qps": [ - 2267.5222784063853 - ], - "recall": { - "average": 0.417, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 40, - "search_latencies": [ - 44101 - ], - "search_n": 10 - }, - { - "mean_cmps": 2652.449951171875, - "mean_hops": 57.040000915527344, - "mean_latencies": [ - 2716.51 - ], - "num_tasks": 8, - "p90_latencies": [ - 4333 - ], - "p99_latencies": [ - 6903 - ], - "qps": [ - 2401.5946588534784 - ], - "recall": { - "average": 0.419, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 50, - "search_latencies": [ - 41639 - ], - "search_n": 10 - }, - { - "mean_cmps": 4426.0400390625, - "mean_hops": 106.4000015258789, - "mean_latencies": [ - 4522.62 - ], - "num_tasks": 8, - "p90_latencies": [ - 6900 - ], - "p99_latencies": [ - 8430 - ], - "qps": [ - 1539.7644160443451 - ], - "recall": { - "average": 0.425, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 100, - "search_latencies": [ - 64945 - ], - "search_n": 10 - }, - { - "mean_cmps": 7640.009765625, - "mean_hops": 205.7100067138672, - "mean_latencies": [ - 8594.32 - ], - "num_tasks": 8, - "p90_latencies": [ - 12535 - ], - "p99_latencies": [ - 21987 - ], - "qps": [ - 831.1861025683651 - ], - "recall": { - "average": 0.432, - "maximum": 10, - "minimum": 1, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 200, - "search_latencies": [ - 120310 - ], - "search_n": 10 - } - ] - } - } - }, - { - "input": { - "content": { - "search_phase": { - "determinant_diversity_eta": 0.01, - "determinant_diversity_power": 2.0, - "determinant_diversity_results_k": 10, - "groundtruth": "C:\\wikipedia_dataset\\groundtruth_k100.bin", - "num_threads": [ - 8 - ], - "queries": "C:\\wikipedia_dataset\\query.bin", - "reps": 1, - "runs": [ - { - "recall_k": 10, - "search_l": [ - 20, - 30, - 40, - 50, - 100, - 200 - ], - "search_n": 10 - } - ], - "search-type": "topk" - }, - "source": { - "data_type": "float32", - "distance": "squared_l2", - "index-source": "Load", - "load_path": "C:\\wikipedia_dataset\\wikipedia_saved_index" - } - }, - "type": "async-index-build" - }, - "results": { - "build": null, - "search": { - "Topk": [ - { - "mean_cmps": 1514.68994140625, - "mean_hops": 28.469999313354492, - "mean_latencies": [ - 2416.91 - ], - "num_tasks": 8, - "p90_latencies": [ - 3698 - ], - "p99_latencies": [ - 8820 - ], - "qps": [ - 2810.883741848437 - ], - "recall": { - "average": 0.408, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 20, - "search_latencies": [ - 35576 - ], - "search_n": 10 - }, - { - "mean_cmps": 1916.7900390625, - "mean_hops": 37.97999954223633, - "mean_latencies": [ - 3083.14 - ], - "num_tasks": 8, - "p90_latencies": [ - 4464 - ], - "p99_latencies": [ - 6221 - ], - "qps": [ - 2288.8532845044633 - ], - "recall": { - "average": 0.412, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 30, - "search_latencies": [ - 43690 - ], - "search_n": 10 - }, - { - "mean_cmps": 2288.7900390625, - "mean_hops": 47.43000030517578, - "mean_latencies": [ - 4265.17 - ], - "num_tasks": 8, - "p90_latencies": [ - 6233 - ], - "p99_latencies": [ - 10556 - ], - "qps": [ - 1710.1617813045114 - ], - "recall": { - "average": 0.418, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 40, - "search_latencies": [ - 58474 - ], - "search_n": 10 - }, - { - "mean_cmps": 2652.449951171875, - "mean_hops": 57.040000915527344, - "mean_latencies": [ - 4419.41 - ], - "num_tasks": 8, - "p90_latencies": [ - 6687 - ], - "p99_latencies": [ - 10292 - ], - "qps": [ - 1623.666563834451 - ], - "recall": { - "average": 0.422, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 50, - "search_latencies": [ - 61589 - ], - "search_n": 10 - }, - { - "mean_cmps": 4426.0400390625, - "mean_hops": 106.4000015258789, - "mean_latencies": [ - 8935.56 - ], - "num_tasks": 8, - "p90_latencies": [ - 14054 - ], - "p99_latencies": [ - 17442 - ], - "qps": [ - 818.7393052178256 - ], - "recall": { - "average": 0.428, - "maximum": 10, - "minimum": 0, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 100, - "search_latencies": [ - 122139 - ], - "search_n": 10 - }, - { - "mean_cmps": 7640.009765625, - "mean_hops": 205.7100067138672, - "mean_latencies": [ - 16669.73 - ], - "num_tasks": 8, - "p90_latencies": [ - 22001 - ], - "p99_latencies": [ - 72498 - ], - "qps": [ - 436.14603913974554 - ], - "recall": { - "average": 0.437, - "maximum": 10, - "minimum": 1, - "num_queries": 100, - "recall_k": 10, - "recall_n": 10 - }, - "search_l": 200, - "search_latencies": [ - 229281 - ], - "search_n": 10 - } - ] - } - } - } -] \ No newline at end of file From 595e94955e4daec902e5f7f8b4d4e6ecb34e5909 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Mon, 23 Mar 2026 17:18:21 +0530 Subject: [PATCH 43/47] Rename rag terminology to determinant_diversity --- diskann-benchmark-runner/src/any.rs | 2 +- ...ai-disk-determinant-diversity-compare.json | 52 ++++++ .../src/backend/disk_index/search.rs | 25 +++ diskann-benchmark/src/inputs/disk.rs | 54 ++++++ diskann-disk/src/build/builder/core.rs | 3 + .../src/search/provider/disk_provider.rs | 173 +++++++++++++++++- diskann-tools/src/utils/search_disk_index.rs | 3 + 7 files changed, 301 insertions(+), 11 deletions(-) create mode 100644 diskann-benchmark/example/openai-disk-determinant-diversity-compare.json diff --git a/diskann-benchmark-runner/src/any.rs b/diskann-benchmark-runner/src/any.rs index dd57dbf69..c5e19b0e3 100644 --- a/diskann-benchmark-runner/src/any.rs +++ b/diskann-benchmark-runner/src/any.rs @@ -397,7 +397,7 @@ mod tests { let _: Type = value.convert::().unwrap(); // An invalid match should return an error. - let value = Any::new(0usize, "random-rag"); + let value = Any::new(0usize, "random-determinant-diversity"); let err = value.convert::>().unwrap_err(); let msg = err.to_string(); assert!(msg.contains("invalid dispatch"), "{}", msg); diff --git a/diskann-benchmark/example/openai-disk-determinant-diversity-compare.json b/diskann-benchmark/example/openai-disk-determinant-diversity-compare.json new file mode 100644 index 000000000..e4c61c24f --- /dev/null +++ b/diskann-benchmark/example/openai-disk-determinant-diversity-compare.json @@ -0,0 +1,52 @@ +{ + "search_directories": [ + "C:/data/openai" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/data/openai/openai_index_normal" + }, + "search_phase": { + "queries": "openai_query.bin", + "groundtruth": "openai_gt_50.bin", + "search_list": [100, 200, 400], + "beam_width": 4, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null + } + } + }, + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Load", + "data_type": "float32", + "load_path": "C:/data/openai/openai_index_normal" + }, + "search_phase": { + "queries": "openai_query.bin", + "groundtruth": "openai_gt_50.bin", + "search_list": [100, 200, 400], + "beam_width": 4, + "recall_at": 10, + "num_threads": 8, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "is_determinant_diversity_search": true, + "determinant_diversity_eta": 0.01, + "determinant_diversity_power": 2.0 + } + } + } + ] +} \ No newline at end of file diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 65e5804a7..ea1f0f130 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -42,6 +42,9 @@ pub(super) struct DiskSearchStats { pub(super) beam_width: usize, pub(super) recall_at: u32, pub(crate) is_flat_search: bool, + pub(crate) is_determinant_diversity_search: bool, + pub(crate) determinant_diversity_eta: Option, + pub(crate) determinant_diversity_power: Option, pub(crate) distance: SimilarityMeasure, pub(crate) uses_vector_filters: bool, pub(super) num_nodes_to_cache: Option, @@ -276,6 +279,9 @@ where Some(search_params.beam_width), vector_filter, search_params.is_flat_search, + search_params.is_determinant_diversity_search, + search_params.determinant_diversity_eta, + search_params.determinant_diversity_power, ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; @@ -341,6 +347,9 @@ where beam_width: search_params.beam_width, recall_at: search_params.recall_at, is_flat_search: search_params.is_flat_search, + is_determinant_diversity_search: search_params.is_determinant_diversity_search, + determinant_diversity_eta: search_params.determinant_diversity_eta, + determinant_diversity_power: search_params.determinant_diversity_power, distance: search_params.distance, uses_vector_filters: search_params.vector_filters_file.is_some(), num_nodes_to_cache: search_params.num_nodes_to_cache, @@ -425,6 +434,22 @@ impl fmt::Display for DiskSearchStats { writeln!(f, "Beam width, : {}", self.beam_width)?; writeln!(f, "Recall at, : {}", self.recall_at)?; writeln!(f, "Flat search, : {}", self.is_flat_search)?; + writeln!( + f, + "Det-div search, : {}", + self.is_determinant_diversity_search + )?; + writeln!( + f, + "Det-div params, : {}", + match ( + self.determinant_diversity_eta, + self.determinant_diversity_power, + ) { + (Some(eta), Some(power)) => format!("eta={eta}, power={power}"), + _ => "None".to_string(), + } + )?; writeln!(f, "Distance, : {}", self.distance)?; writeln!(f, "Vector filters, : {}", self.uses_vector_filters)?; writeln!( diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index bf843d72f..f5058f321 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -81,6 +81,12 @@ pub(crate) struct DiskSearchPhase { pub(crate) search_list: Vec, pub(crate) recall_at: u32, pub(crate) is_flat_search: bool, + #[serde(default)] + pub(crate) is_determinant_diversity_search: bool, + #[serde(default)] + pub(crate) determinant_diversity_eta: Option, + #[serde(default)] + pub(crate) determinant_diversity_power: Option, pub(crate) distance: SimilarityMeasure, pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, @@ -224,6 +230,35 @@ impl CheckDeserialization for DiskSearchPhase { if self.num_threads == 0 { anyhow::bail!("num_threads must be positive"); } + + if self.is_determinant_diversity_search { + if self.is_flat_search { + anyhow::bail!( + "is_determinant_diversity_search is not supported when is_flat_search is true" + ); + } + + let eta = self.determinant_diversity_eta.unwrap_or(0.01); + let power = self.determinant_diversity_power.unwrap_or(2.0); + + if eta < 0.0 || !eta.is_finite() { + anyhow::bail!("determinant_diversity_eta must be >= 0.0 and finite, got {eta}"); + } + + if power <= 0.0 || !power.is_finite() { + anyhow::bail!("determinant_diversity_power must be > 0.0 and finite, got {power}"); + } + + self.determinant_diversity_eta = Some(eta); + self.determinant_diversity_power = Some(power); + } else if self.determinant_diversity_eta.is_some() + || self.determinant_diversity_power.is_some() + { + anyhow::bail!( + "determinant_diversity_eta/determinant_diversity_power may only be set when is_determinant_diversity_search is true" + ); + } + if let Some(n) = self.num_nodes_to_cache { if n == 0 { anyhow::bail!("num_nodes_to_cache must be positive if specified"); @@ -268,6 +303,9 @@ impl Example for DiskIndexOperation { recall_at: 10, num_threads: 8, is_flat_search: false, + is_determinant_diversity_search: false, + determinant_diversity_eta: None, + determinant_diversity_power: None, distance: SimilarityMeasure::SquaredL2, vector_filters_file: None, num_nodes_to_cache: None, @@ -384,6 +422,22 @@ impl DiskSearchPhase { write_field!(f, "Recall@", self.recall_at)?; write_field!(f, "Threads", self.num_threads)?; write_field!(f, "Flat Search", self.is_flat_search)?; + write_field!( + f, + "Determinant Diversity Search", + self.is_determinant_diversity_search + )?; + match ( + self.determinant_diversity_eta, + self.determinant_diversity_power, + ) { + (Some(eta), Some(power)) => write_field!( + f, + "Determinant Diversity Params", + format!("eta={eta}, power={power}") + )?, + _ => write_field!(f, "Determinant Diversity Params", "none")?, + } write_field!(f, "Distance", self.distance)?; match &self.vector_filters_file { Some(vf) => write_field!(f, "Vector Filters File", vf.display())?, diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index c7f21b682..79075592a 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1103,6 +1103,9 @@ pub(crate) mod disk_index_builder_tests { &mut associated_data, &|_| true, false, + false, + None, + None, ); diskann_providers::test_utils::assert_top_k_exactly_match( diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index c0b16beba..b7226e939 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -39,8 +39,10 @@ use diskann::{ use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ model::{ - compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType, - pq::quantizer_preprocess, PQData, PQScratch, + compute_pq_distance, compute_pq_distance_for_pq_coordinates, + graph::{provider::async_::determinant_diversity_post_process, traits::GraphDataType}, + pq::quantizer_preprocess, + PQData, PQScratch, }, storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith}, }; @@ -279,6 +281,30 @@ impl<'a> RerankAndFilter<'a> { } } +#[derive(Clone, Copy)] +pub struct DeterminantDiversityRerankAndFilter<'a> { + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + top_k: usize, + eta: f64, + power: f64, +} + +impl<'a> DeterminantDiversityRerankAndFilter<'a> { + fn new( + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + top_k: usize, + eta: f64, + power: f64, + ) -> Self { + Self { + filter, + top_k, + eta, + power, + } + } +} + impl SearchPostProcess< DiskAccessor<'_, Data, VP>, @@ -340,6 +366,84 @@ where } } +impl + SearchPostProcess< + DiskAccessor<'_, Data, VP>, + [Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DeterminantDiversityRerankAndFilter<'_> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type Error = ANNError; + + async fn post_process( + &self, + accessor: &mut DiskAccessor<'_, Data, VP>, + query: &[Data::VectorDataType], + _computer: &DiskQueryComputer, + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, + { + let provider = accessor.provider; + let query_f32 = Data::VectorDataType::as_f32(query).map_err(Into::into)?; + + let candidate_ids: Vec = candidates + .map(|candidate| candidate.id) + .filter(|id| (self.filter)(id)) + .collect(); + + if candidate_ids.is_empty() { + return Ok(0); + } + + ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &candidate_ids)?; + + let mut candidate_vectors = Vec::with_capacity(candidate_ids.len()); + let mut associated_data = HashMap::with_capacity(candidate_ids.len()); + + for id in candidate_ids { + let vector = accessor.scratch.vertex_provider.get_vector(&id)?; + let distance = provider + .distance_comparer + .evaluate_similarity(query, vector); + let vector_f32 = Data::VectorDataType::as_f32(vector).map_err(Into::into)?; + let data = accessor.scratch.vertex_provider.get_associated_data(&id)?; + + candidate_vectors.push((id, distance, vector_f32.to_vec())); + associated_data.insert(id, *data); + } + + let borrowed: Vec<(u32, f32, &[f32])> = candidate_vectors + .iter() + .map(|(id, distance, vector)| (*id, *distance, vector.as_slice())) + .collect(); + + let reranked = determinant_diversity_post_process( + borrowed, &query_f32, self.top_k, self.eta, self.power, + ); + + Ok( + output.extend(reranked.into_iter().filter_map(|(id, distance)| { + associated_data + .get(&id) + .copied() + .map(|data| ((id, data), distance)) + })), + ) + } +} + impl<'this, Data, ProviderFactory> SearchStrategy, [Data::VectorDataType]> for DiskSearchStrategy<'this, Data, ProviderFactory> where @@ -933,6 +1037,9 @@ where beam_width: Option, vector_filter: Option>, is_flat_search: bool, + is_determinant_diversity_search: bool, + determinant_diversity_eta: Option, + determinant_diversity_power: Option, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); let mut indices = vec![0u32; return_list_size as usize]; @@ -951,6 +1058,9 @@ where &mut associated_data, &vector_filter.unwrap_or(default_vector_filter::()), is_flat_search, + is_determinant_diversity_search, + determinant_diversity_eta, + determinant_diversity_power, )?; let mut search_result = SearchResult { @@ -988,6 +1098,9 @@ where associated_data: &mut [Data::AssociatedDataType], vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), is_flat_search: bool, + is_determinant_diversity_search: bool, + determinant_diversity_eta: Option, + determinant_diversity_power: Option, ) -> ANNResult { let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices[..k_value], @@ -1010,13 +1123,31 @@ where ))? } else { let knn_search = Knn::new(k, l, beam_width)?; - self.runtime.block_on(self.index.search( - knn_search, - &strategy, - &DefaultContext, - strategy.query, - &mut result_output_buffer, - ))? + if is_determinant_diversity_search { + let processor = DeterminantDiversityRerankAndFilter::new( + vector_filter, + k, + determinant_diversity_eta.unwrap_or(0.01), + determinant_diversity_power.unwrap_or(2.0), + ); + + self.runtime.block_on(self.index.search_with( + knn_search, + &strategy, + processor, + &DefaultContext, + strategy.query, + &mut result_output_buffer, + ))? + } else { + self.runtime.block_on(self.index.search( + knn_search, + &strategy, + &DefaultContext, + strategy.query, + &mut result_output_buffer, + ))? + } }; query_stats.total_comparisons = stats.cmps; query_stats.search_hops = stats.hops; @@ -1493,7 +1624,17 @@ mod disk_provider_tests { let query = &aligned_box.as_slice()[1..]; let result = params .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, None, false) + .search( + query, + params.k as u32, + params.l as u32, + beam_width, + None, + false, + false, + None, + None, + ) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1605,6 +1746,9 @@ mod disk_provider_tests { &mut associated_data, &|_| true, false, + false, + None, + None, ); assert!(result.is_err()); @@ -1674,6 +1818,9 @@ mod disk_provider_tests { Some(4), None, false, + false, + None, + None, ); assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); @@ -2013,6 +2160,9 @@ mod disk_provider_tests { &mut associated_data, &vector_filter, is_flat_search, + false, + None, + None, ); assert!(result.is_ok(), "Expected search to succeed"); @@ -2034,6 +2184,9 @@ mod disk_provider_tests { None, // beam_width Some(Box::new(vector_filter)), is_flat_search, + false, + None, + None, ); assert!(result_with_filter.is_ok(), "Expected search to succeed"); diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 888988ca0..73192c582 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -256,6 +256,9 @@ where Some(parameters.beam_width as usize), Some(vector_filter_function), parameters.is_flat_search, + false, + None, + None, ); match result { From 43ccd91053f85422444c1c59438f8942b4da40f6 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 26 Mar 2026 15:00:42 +0530 Subject: [PATCH 44/47] Generalize benchmark KNN post-processing --- Cargo.lock | 1 - diskann-benchmark-core/Cargo.toml | 1 - .../src/search/graph/determinant_diversity.rs | 205 ------------------ .../src/search/graph/knn.rs | 98 +++++++++ .../src/search/graph/mod.rs | 1 - .../src/backend/index/benchmarks.rs | 31 ++- diskann-benchmark/src/backend/index/result.rs | 36 --- .../src/backend/index/search/knn.rs | 92 +------- .../src/backend/index/spherical.rs | 28 ++- .../src/search/provider/disk_provider.rs | 4 + 10 files changed, 149 insertions(+), 348 deletions(-) delete mode 100644 diskann-benchmark-core/src/search/graph/determinant_diversity.rs diff --git a/Cargo.lock b/Cargo.lock index 43a01b6f4..9bce846e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -687,7 +687,6 @@ dependencies = [ "anyhow", "diskann", "diskann-benchmark-runner", - "diskann-providers", "diskann-utils", "diskann-vector", "futures-util", diff --git a/diskann-benchmark-core/Cargo.toml b/diskann-benchmark-core/Cargo.toml index 689978c03..90e64b9e3 100644 --- a/diskann-benchmark-core/Cargo.toml +++ b/diskann-benchmark-core/Cargo.toml @@ -11,7 +11,6 @@ edition = "2024" anyhow.workspace = true diskann.workspace = true diskann-benchmark-runner = { workspace = true } -diskann-providers.workspace = true diskann-utils.default-features = false diskann-utils.workspace = true futures-util = { workspace = true, default-features = false } diff --git a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs b/diskann-benchmark-core/src/search/graph/determinant_diversity.rs deleted file mode 100644 index 58008afda..000000000 --- a/diskann-benchmark-core/src/search/graph/determinant_diversity.rs +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::sync::Arc; - -use diskann::{ - ANNResult, - graph::{self, glue}, - provider, -}; -use diskann_benchmark_runner::utils::{MicroSeconds, percentiles}; -use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; -use diskann_utils::{future::AsyncFriendly, views::Matrix}; - -use crate::{ - recall, - search::{self, Search, graph::Strategy}, - utils, -}; - -#[derive(Debug, Clone, Copy)] -pub struct Parameters { - pub inner: graph::search::Knn, - pub processor: DeterminantDiversitySearchParams, -} - -#[derive(Debug)] -pub struct DeterminantDiversity -where - DP: provider::DataProvider, -{ - index: Arc>, - queries: Arc>, - strategy: Strategy, -} - -impl DeterminantDiversity -where - DP: provider::DataProvider, -{ - pub fn new( - index: Arc>, - queries: Arc>, - strategy: Strategy, - ) -> anyhow::Result> { - strategy.length_compatible(queries.nrows())?; - - Ok(Arc::new(Self { - index, - queries, - strategy, - })) - } -} - -impl Search for DeterminantDiversity -where - DP: provider::DataProvider, - S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, - DeterminantDiversitySearchParams: - for<'a> glue::SearchPostProcess, [T], DP::ExternalId> + Send + Sync, - T: AsyncFriendly + Clone, -{ - type Id = DP::ExternalId; - type Parameters = Parameters; - type Output = super::knn::Metrics; - - fn num_queries(&self) -> usize { - self.queries.nrows() - } - - fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { - search::IdCount::Fixed(parameters.inner.k_value()) - } - - async fn search( - &self, - parameters: &Self::Parameters, - buffer: &mut O, - index: usize, - ) -> ANNResult - where - O: graph::SearchOutputBuffer + Send, - { - let context = DP::Context::default(); - let stats = self - .index - .search_with( - parameters.inner, - self.strategy.get(index)?, - parameters.processor, - &context, - self.queries.row(index), - buffer, - ) - .await?; - - Ok(super::knn::Metrics { - comparisons: stats.cmps, - hops: stats.hops, - }) - } -} - -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct Summary { - pub setup: search::Setup, - pub parameters: Parameters, - pub end_to_end_latencies: Vec, - pub mean_latencies: Vec, - pub p90_latencies: Vec, - pub p99_latencies: Vec, - pub recall: recall::RecallMetrics, - pub mean_cmps: f64, - pub mean_hops: f64, -} - -pub struct Aggregator<'a, I> { - groundtruth: &'a dyn crate::recall::Rows, - recall_k: usize, - recall_n: usize, -} - -impl<'a, I> Aggregator<'a, I> { - pub fn new( - groundtruth: &'a dyn crate::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> Self { - Self { - groundtruth, - recall_k, - recall_n, - } - } -} - -impl search::Aggregate for Aggregator<'_, I> -where - I: crate::recall::RecallCompatible, -{ - type Output = Summary; - - fn aggregate( - &mut self, - run: search::Run, - mut results: Vec>, - ) -> anyhow::Result { - let recall = match results.first() { - Some(first) => crate::recall::knn( - self.groundtruth, - None, - first.ids().as_rows(), - self.recall_k, - self.recall_n, - true, - )?, - None => anyhow::bail!("Results must be non-empty"), - }; - - let mut mean_latencies = Vec::with_capacity(results.len()); - let mut p90_latencies = Vec::with_capacity(results.len()); - let mut p99_latencies = Vec::with_capacity(results.len()); - - results.iter_mut().for_each(|r| { - match percentiles::compute_percentiles(r.latencies_mut()) { - Ok(values) => { - let percentiles::Percentiles { mean, p90, p99, .. } = values; - mean_latencies.push(mean); - p90_latencies.push(p90); - p99_latencies.push(p99); - } - Err(_) => { - let zero = MicroSeconds::new(0); - mean_latencies.push(0.0); - p90_latencies.push(zero); - p99_latencies.push(zero); - } - } - }); - - Ok(Summary { - setup: run.setup().clone(), - parameters: *run.parameters(), - end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(), - recall, - mean_latencies, - p90_latencies, - p99_latencies, - mean_cmps: utils::average_all( - results - .iter() - .flat_map(|r| r.output().iter().map(|o| o.comparisons)), - ), - mean_hops: utils::average_all( - results - .iter() - .flat_map(|r| r.output().iter().map(|o| o.hops)), - ), - }) - } -} diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 6cc2c9673..45dff153c 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -41,6 +41,18 @@ where strategy: Strategy, } +/// A [`KNN`] variant that uses explicit post-processing during search. +#[derive(Debug)] +pub struct KNNWithPostProcessor +where + DP: provider::DataProvider, +{ + index: Arc>, + queries: Arc>, + strategy: Strategy, + post_processor: Strategy

, +} + impl KNN where DP: provider::DataProvider, @@ -71,6 +83,39 @@ where } } +impl KNNWithPostProcessor +where + DP: provider::DataProvider, +{ + /// Construct a new [`KNNWithPostProcessor`] searcher. + /// + /// If `strategy` or `post_processor` is one of the container variants of [`Strategy`], + /// its length must match the number of rows in `queries`. If this is the case, then the + /// strategies/processors will have a querywise correspondence (see [`search::SearchResults`]) + /// with the query matrix. + /// + /// # Errors + /// + /// Returns an error if the number of elements in `strategy` or `post_processor` is not + /// compatible with the number of rows in `queries`. + pub fn new( + index: Arc>, + queries: Arc>, + strategy: Strategy, + post_processor: Strategy

, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + post_processor.length_compatible(queries.nrows())?; + + Ok(Arc::new(Self { + index, + queries, + strategy, + post_processor, + })) + } +} + /// Additional metrics collected during [`KNN`] search. /// /// # Note @@ -132,6 +177,59 @@ where } } +impl Search for KNNWithPostProcessor +where + DP: provider::DataProvider, + S: glue::DefaultSearchStrategy + Clone + AsyncFriendly, + T: AsyncFriendly + Clone, + P: for<'a> glue::SearchPostProcess, [T], DP::ExternalId> + + Clone + + Send + + Sync + + AsyncFriendly, +{ + type Id = DP::ExternalId; + type Parameters = graph::search::Knn; + type Output = Metrics; + + fn num_queries(&self) -> usize { + self.queries.nrows() + } + + fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount { + search::IdCount::Fixed(parameters.k_value()) + } + + async fn search( + &self, + parameters: &Self::Parameters, + buffer: &mut O, + index: usize, + ) -> ANNResult + where + O: graph::SearchOutputBuffer + Send, + { + let context = DP::Context::default(); + let knn_search = *parameters; + let stats = self + .index + .search_with( + knn_search, + self.strategy.get(index)?, + self.post_processor.get(index)?.clone(), + &context, + self.queries.row(index), + buffer, + ) + .await?; + + Ok(Metrics { + comparisons: stats.cmps, + hops: stats.hops, + }) + } +} + /// An [`search::Aggregate`]d summary of multiple [`KNN`] search runs /// returned by the provided [`Aggregator`]. /// diff --git a/diskann-benchmark-core/src/search/graph/mod.rs b/diskann-benchmark-core/src/search/graph/mod.rs index cfcecb0db..eddb4fbcf 100644 --- a/diskann-benchmark-core/src/search/graph/mod.rs +++ b/diskann-benchmark-core/src/search/graph/mod.rs @@ -3,7 +3,6 @@ * Licensed under the MIT license. */ -pub mod determinant_diversity; pub mod knn; pub mod multihop; pub mod range; diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index a733ada8e..68e0c96c9 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -381,21 +381,32 @@ where search_phase.determinant_diversity_eta, search_phase.determinant_diversity_power, ) { - let knn = - benchmark_core::search::graph::determinant_diversity::DeterminantDiversity::new( + let processor = DeterminantDiversitySearchParams::new( + search_phase + .determinant_diversity_results_k + .unwrap_or_else(|| { + search_phase + .runs + .iter() + .map(|run| run.search_n) + .max() + .unwrap_or(1) + }), + eta, + power, + ) + .map_err(|err| { + anyhow::anyhow!("Invalid determinant-diversity parameters: {err}") + })?; + + let knn = benchmark_core::search::graph::knn::KNNWithPostProcessor::new( index, queries, benchmark_core::search::graph::Strategy::broadcast(search_strategy), + benchmark_core::search::graph::Strategy::broadcast(processor), )?; - search::knn::run_determinant_diversity( - &knn, - &groundtruth, - steps, - eta, - power, - search_phase.determinant_diversity_results_k, - )? + search::knn::run(&knn, &groundtruth, steps)? } else { let knn = benchmark_core::search::graph::KNN::new( index, diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index 1f9c2e50a..1d6102f9b 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -155,42 +155,6 @@ impl SearchResults { mean_hops: mean_hops as f32, } } - - pub fn new_determinant_diversity( - summary: benchmark_core::search::graph::determinant_diversity::Summary, - ) -> Self { - let benchmark_core::search::graph::determinant_diversity::Summary { - setup, - parameters, - end_to_end_latencies, - mean_latencies, - p90_latencies, - p99_latencies, - recall, - mean_cmps, - mean_hops, - .. - } = summary; - - let qps = end_to_end_latencies - .iter() - .map(|latency| recall.num_queries as f64 / latency.as_seconds()) - .collect(); - - Self { - num_tasks: setup.tasks.into(), - search_n: parameters.inner.k_value().get(), - search_l: parameters.inner.l_value().get(), - qps, - search_latencies: end_to_end_latencies, - mean_latencies, - p90_latencies, - p99_latencies, - recall: (&recall).into(), - mean_cmps: mean_cmps as f32, - mean_hops: mean_hops as f32, - } - } } fn format_search_results_table( diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 30560a6cd..357e982c2 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -6,7 +6,6 @@ use std::{num::NonZeroUsize, sync::Arc}; use diskann_benchmark_core::{self as benchmark_core, search as core_search}; -use diskann_providers::model::graph::provider::async_::DeterminantDiversitySearchParams; use crate::{backend::index::result::SearchResults, inputs::async_::GraphSearch}; @@ -53,34 +52,6 @@ pub(crate) trait Knn { ) -> anyhow::Result>; } -type DeterminantRun = - core_search::Run; - -pub(crate) fn run_determinant_diversity( - runner: &dyn DeterminantDiversityKnn, - groundtruth: &dyn benchmark_core::recall::Rows, - steps: SearchSteps<'_>, - eta: f64, - power: f64, - results_k: Option, -) -> anyhow::Result> { - run_search_determinant_diversity(runner, groundtruth, steps, |setup, search_l, search_n| { - let base = diskann::graph::search::Knn::new(search_n, search_l, None).unwrap(); - let processor = - DeterminantDiversitySearchParams::new(results_k.unwrap_or(search_n), eta, power) - .map_err(|err| { - anyhow::anyhow!("Invalid determinant-diversity parameters: {err}") - })?; - - let search_params = - diskann_benchmark_core::search::graph::determinant_diversity::Parameters { - inner: base, - processor, - }; - Ok(core_search::Run::new(search_params, setup)) - }) -} - fn run_search( runner: &dyn Knn, groundtruth: &dyn benchmark_core::recall::Rows, @@ -113,48 +84,6 @@ where Ok(all) } -fn run_search_determinant_diversity( - runner: &dyn DeterminantDiversityKnn, - groundtruth: &dyn benchmark_core::recall::Rows, - steps: SearchSteps<'_>, - builder: F, -) -> anyhow::Result> -where - F: Fn(core_search::Setup, usize, usize) -> anyhow::Result, -{ - let mut all = Vec::new(); - - for threads in steps.num_tasks.iter() { - for run in steps.runs.iter() { - let setup = core_search::Setup { - threads: *threads, - tasks: *threads, - reps: steps.reps, - }; - - let parameters: Vec<_> = run - .search_l - .iter() - .map(|&search_l| builder(setup.clone(), search_l, run.search_n)) - .collect::>>()?; - - all.extend(runner.search_all(parameters, groundtruth, run.recall_k, run.search_n)?); - } - } - - Ok(all) -} - -pub(crate) trait DeterminantDiversityKnn { - fn search_all( - &self, - parameters: Vec, - groundtruth: &dyn benchmark_core::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> anyhow::Result>; -} - /////////// // Impls // /////////// @@ -211,19 +140,19 @@ where } } -impl DeterminantDiversityKnn - for Arc> +impl Knn + for Arc> where DP: diskann::provider::DataProvider, - core_search::graph::determinant_diversity::DeterminantDiversity: core_search::Search< + core_search::graph::knn::KNNWithPostProcessor: core_search::Search< Id = DP::InternalId, - Parameters = diskann_benchmark_core::search::graph::determinant_diversity::Parameters, + Parameters = diskann::graph::search::Knn, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, @@ -231,16 +160,9 @@ where let results = core_search::search_all( self.clone(), parameters.into_iter(), - core_search::graph::determinant_diversity::Aggregator::new( - groundtruth, - recall_k, - recall_n, - ), + core_search::graph::knn::Aggregator::new(groundtruth, recall_k, recall_n), )?; - Ok(results - .into_iter() - .map(SearchResults::new_determinant_diversity) - .collect()) + Ok(results.into_iter().map(SearchResults::new).collect()) } } diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 1918d6676..4b3e68ba1 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -338,20 +338,30 @@ mod imp { search_phase.determinant_diversity_eta, search_phase.determinant_diversity_power, ) { - let knn = benchmark_core::search::graph::determinant_diversity::DeterminantDiversity::new( + let processor = DeterminantDiversitySearchParams::new( + search_phase + .determinant_diversity_results_k + .unwrap_or_else(|| { + search_phase + .runs + .iter() + .map(|run| run.search_n) + .max() + .unwrap_or(1) + }), + eta, + power, + ) + .map_err(|err| anyhow::anyhow!("Invalid determinant-diversity parameters: {err}"))?; + + let knn = benchmark_core::search::graph::knn::KNNWithPostProcessor::new( index.clone(), queries.clone(), benchmark_core::search::graph::Strategy::broadcast(strategy), + benchmark_core::search::graph::Strategy::broadcast(processor), )?; - search::knn::run_determinant_diversity( - &knn, - &groundtruth, - steps, - eta, - power, - search_phase.determinant_diversity_results_k, - )? + search::knn::run(&knn, &groundtruth, steps)? } else { let knn = benchmark_core::search::graph::KNN::new( index.clone(), diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index b7226e939..683df44a8 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -1029,6 +1029,7 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. + #[allow(clippy::too_many_arguments)] pub fn search( &self, query: &[Data::VectorDataType], @@ -1569,6 +1570,9 @@ mod disk_provider_tests { &mut associated_data, &(|_| true), false, + false, + None, + None, ); // Calculate the range of the truth_result for this query From 3ff406d031bc14e591b49309ce0f0f03787654f1 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 26 Mar 2026 15:04:29 +0530 Subject: [PATCH 45/47] bug fix in determinant diversity algorithm. --- .../determinant_diversity_post_process.rs | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index 75e84bbf1..ff0a7a756 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -242,21 +242,21 @@ fn post_process_with_eta_f32( break; } - let norm_factor = 1.0 / (1.0 + norms_sq[selected_index]).sqrt(); - let q_scaled: Vec = residuals[selected_index] - .iter() - .map(|&x| x * norm_factor) - .collect(); + let best_norm_sq = norms_sq[selected_index]; + if best_norm_sq <= 0.0 { + continue; + } + + let inv_norm_sq = 1.0 / best_norm_sq; + let r_star_copy = residuals[selected_index].clone(); let mut projections = Vec::with_capacity(n); for i in 0..n { if !available[i] { projections.push(0.0); } else { - let alpha = dot_product(&residuals[selected_index], &residuals[i]) - * norm_factor - * norm_factor; - projections.push(alpha); + let projection = dot_product(&residuals[i], &r_star_copy) * inv_norm_sq; + projections.push(projection); } } @@ -265,12 +265,12 @@ fn post_process_with_eta_f32( continue; } - let alpha = projections[i]; - for (residual, &q_value) in residuals[i].iter_mut().zip(q_scaled.iter()) { - *residual -= alpha * q_value; + let projection = projections[i]; + for (residual, &star) in residuals[i].iter_mut().zip(r_star_copy.iter()) { + *residual -= projection * star; } - norms_sq[i] = (norms_sq[i] - alpha * alpha).max(0.0); + norms_sq[i] = (norms_sq[i] - projection * projection * best_norm_sq).max(0.0); } } From 75f9aaaf5d41fa670f7285cc95f8c922e25a5642 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 26 Mar 2026 16:21:13 +0530 Subject: [PATCH 46/47] Reduce determinant rerank allocations --- .../src/search/provider/disk_provider.rs | 11 +- .../determinant_diversity_post_process.rs | 151 ++++++++---------- 2 files changed, 74 insertions(+), 88 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 683df44a8..c3c9e7400 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -424,13 +424,12 @@ where associated_data.insert(id, *data); } - let borrowed: Vec<(u32, f32, &[f32])> = candidate_vectors - .iter() - .map(|(id, distance, vector)| (*id, *distance, vector.as_slice())) - .collect(); - let reranked = determinant_diversity_post_process( - borrowed, &query_f32, self.top_k, self.eta, self.power, + candidate_vectors, + &query_f32, + self.top_k, + self.eta, + self.power, ); Ok( diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index ff0a7a756..52856b694 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -116,13 +116,8 @@ where )); } - let borrowed: Vec<(u32, f32, &[f32])> = candidates_with_vectors - .iter() - .map(|(id, distance, vector)| (*id, *distance, vector.as_slice())) - .collect(); - let reranked = determinant_diversity_post_process( - borrowed, + candidates_with_vectors, &query_f32[..], self.top_k, self.determinant_diversity_eta, @@ -137,7 +132,7 @@ where } pub fn determinant_diversity_post_process( - candidates: Vec<(Id, f32, &[f32])>, + candidates: Vec<(Id, f32, Vec)>, query: &[f32], k: usize, determinant_diversity_eta: f64, @@ -152,30 +147,20 @@ pub fn determinant_diversity_post_process( return Vec::new(); } - let candidates_f32: Vec<(Id, f32, Vec)> = candidates - .into_iter() - .map(|(id, dist, v)| (id, dist, v.to_vec())) - .collect(); - - if candidates_f32[0].2.is_empty() { + if candidates[0].2.is_empty() { return Vec::new(); } if determinant_diversity_eta > 0.0 { post_process_with_eta_f32( - candidates_f32, + candidates, query, k, determinant_diversity_eta, determinant_diversity_power, ) } else { - post_process_greedy_orthogonalization_f32( - candidates_f32, - query, - k, - determinant_diversity_power, - ) + post_process_greedy_orthogonalization_f32(candidates, query, k, determinant_diversity_power) } } @@ -396,57 +381,67 @@ mod tests { } #[test] - fn test_validation_zero_top_k() { - let result = DeterminantDiversitySearchParams::new(0, 0.01, 2.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidTopK { top_k: 0 }) - )); - } - - #[test] - fn test_validation_negative_eta() { - let result = DeterminantDiversitySearchParams::new(10, -0.01, 2.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidEta { .. }) - )); - } - - #[test] - fn test_validation_zero_power() { - let result = DeterminantDiversitySearchParams::new(10, 0.01, 0.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidPower { .. }) - )); - } - - #[test] - fn test_validation_negative_power() { - let result = DeterminantDiversitySearchParams::new(10, 0.01, -1.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidPower { .. }) - )); - } - - #[test] - fn test_validation_nan_eta() { - let result = DeterminantDiversitySearchParams::new(10, f64::NAN, 2.0); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidEta { .. }) - )); - } + fn test_validation_invalid_params() { + let test_cases = [ + ( + DeterminantDiversitySearchParams::new(0, 0.01, 2.0), + DeterminantDiversityError::InvalidTopK { top_k: 0 }, + ), + ( + DeterminantDiversitySearchParams::new(10, -0.01, 2.0), + DeterminantDiversityError::InvalidEta { eta: -0.01 }, + ), + ( + DeterminantDiversitySearchParams::new(10, f64::NAN, 2.0), + DeterminantDiversityError::InvalidEta { eta: f64::NAN }, + ), + ( + DeterminantDiversitySearchParams::new(10, 0.01, 0.0), + DeterminantDiversityError::InvalidPower { power: 0.0 }, + ), + ( + DeterminantDiversitySearchParams::new(10, 0.01, -1.0), + DeterminantDiversityError::InvalidPower { power: -1.0 }, + ), + ( + DeterminantDiversitySearchParams::new(10, 0.01, f64::INFINITY), + DeterminantDiversityError::InvalidPower { + power: f64::INFINITY, + }, + ), + ]; - #[test] - fn test_validation_infinity_power() { - let result = DeterminantDiversitySearchParams::new(10, 0.01, f64::INFINITY); - assert!(matches!( - result, - Err(DeterminantDiversityError::InvalidPower { .. }) - )); + for (result, expected) in test_cases { + match (result, expected) { + ( + Err(DeterminantDiversityError::InvalidTopK { top_k: actual }), + DeterminantDiversityError::InvalidTopK { top_k: expected }, + ) => assert_eq!(actual, expected), + ( + Err(DeterminantDiversityError::InvalidEta { eta: actual }), + DeterminantDiversityError::InvalidEta { eta: expected }, + ) => { + if expected.is_nan() { + assert!(actual.is_nan()); + } else { + assert_eq!(actual, expected); + } + } + ( + Err(DeterminantDiversityError::InvalidPower { power: actual }), + DeterminantDiversityError::InvalidPower { power: expected }, + ) => { + if expected.is_infinite() { + assert!(actual.is_infinite()); + } else { + assert_eq!(actual, expected); + } + } + (other, expected) => { + panic!("Unexpected result {:?} for expected {:?}", other, expected) + } + } + } } #[test] @@ -454,11 +449,7 @@ mod tests { let v1 = vec![1.0f32, 0.0, 0.0]; let v2 = vec![0.0f32, 1.0, 0.0]; let v3 = vec![0.0f32, 0.0, 1.0]; - let candidates = vec![ - (1u32, 0.5f32, v1.as_slice()), - (2u32, 0.3f32, v2.as_slice()), - (3u32, 0.7f32, v3.as_slice()), - ]; + let candidates = vec![(1u32, 0.5f32, v1), (2u32, 0.3f32, v2), (3u32, 0.7f32, v3)]; let query = vec![1.0, 1.0, 1.0]; let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); @@ -470,11 +461,7 @@ mod tests { let v1 = vec![1.0f32, 0.0, 0.0]; let v2 = vec![0.99f32, 0.1, 0.0]; let v3 = vec![0.0f32, 1.0, 0.0]; - let candidates = vec![ - (1u32, 0.5f32, v1.as_slice()), - (2u32, 0.3f32, v2.as_slice()), - (3u32, 0.4f32, v3.as_slice()), - ]; + let candidates = vec![(1u32, 0.5f32, v1), (2u32, 0.3f32, v2), (3u32, 0.4f32, v3)]; let query = vec![1.0, 1.0, 0.0]; let result = determinant_diversity_post_process(candidates, &query, 2, 0.0, 1.0); @@ -483,7 +470,7 @@ mod tests { #[test] fn test_determinant_diversity_post_process_empty() { - let candidates: Vec<(u32, f32, &[f32])> = vec![]; + let candidates: Vec<(u32, f32, Vec)> = vec![]; let query = vec![1.0, 1.0, 1.0]; let result = determinant_diversity_post_process(candidates, &query, 3, 0.01, 2.0); @@ -494,7 +481,7 @@ mod tests { fn test_determinant_diversity_post_process_k_larger_than_candidates() { let v1 = vec![1.0f32, 0.0, 0.0]; let v2 = vec![0.0f32, 1.0, 0.0]; - let candidates = vec![(1u32, 0.5f32, v1.as_slice()), (2u32, 0.3f32, v2.as_slice())]; + let candidates = vec![(1u32, 0.5f32, v1), (2u32, 0.3f32, v2)]; let query = vec![1.0, 1.0, 1.0]; let result = determinant_diversity_post_process(candidates, &query, 10, 0.01, 2.0); From 58e8917bbec7ab947a7606da7326a90b87261328 Mon Sep 17 00:00:00 2001 From: Naren Datha Date: Thu, 26 Mar 2026 20:29:37 +0530 Subject: [PATCH 47/47] determinant diversity: disk index support and refactoring --- diskann-benchmark/src/inputs/async_.rs | 16 +++-- diskann-benchmark/src/inputs/disk.rs | 4 +- .../src/search/provider/disk_provider.rs | 23 +++++-- .../determinant_diversity_post_process.rs | 63 +++++++++++++------ 4 files changed, 75 insertions(+), 31 deletions(-) diff --git a/diskann-benchmark/src/inputs/async_.rs b/diskann-benchmark/src/inputs/async_.rs index 54b795635..b160cf4b8 100644 --- a/diskann-benchmark/src/inputs/async_.rs +++ b/diskann-benchmark/src/inputs/async_.rs @@ -149,18 +149,18 @@ impl CheckDeserialization for TopkSearchPhase { } if let Some(eta) = self.determinant_diversity_eta { - if eta < 0.0 { + if !eta.is_finite() || eta < 0.0 { return Err(anyhow!( - "determinant_diversity_eta must be >= 0.0, got {}", + "determinant_diversity_eta must be finite and >= 0.0, got {}", eta )); } } if let Some(power) = self.determinant_diversity_power { - if power <= 0.0 { + if !power.is_finite() || power < 0.0 { return Err(anyhow!( - "determinant_diversity_power must be > 0.0, got {}", + "determinant_diversity_power must be finite and >= 0.0, got {}", power )); } @@ -170,6 +170,14 @@ impl CheckDeserialization for TopkSearchPhase { if k == 0 { return Err(anyhow!("determinant_diversity_results_k must be > 0")); } + + if self.determinant_diversity_eta.is_none() + || self.determinant_diversity_power.is_none() + { + return Err(anyhow!( + "determinant_diversity_results_k requires determinant_diversity_eta and determinant_diversity_power to both be set" + )); + } } Ok(()) diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index f5058f321..7121e098b 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -245,8 +245,8 @@ impl CheckDeserialization for DiskSearchPhase { anyhow::bail!("determinant_diversity_eta must be >= 0.0 and finite, got {eta}"); } - if power <= 0.0 || !power.is_finite() { - anyhow::bail!("determinant_diversity_power must be > 0.0 and finite, got {power}"); + if power < 0.0 || !power.is_finite() { + anyhow::bail!("determinant_diversity_power must be >= 0.0 and finite, got {power}"); } self.determinant_diversity_eta = Some(eta); diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index c3c9e7400..78cc1fec1 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -1124,12 +1124,23 @@ where } else { let knn_search = Knn::new(k, l, beam_width)?; if is_determinant_diversity_search { - let processor = DeterminantDiversityRerankAndFilter::new( - vector_filter, - k, - determinant_diversity_eta.unwrap_or(0.01), - determinant_diversity_power.unwrap_or(2.0), - ); + let eta = determinant_diversity_eta.unwrap_or(0.01); + let power = determinant_diversity_power.unwrap_or(2.0); + + if !eta.is_finite() || eta < 0.0 { + return Err(ANNError::log_index_error(format!( + "determinant_diversity_eta must be finite and >= 0.0, got {eta}" + ))); + } + + if !power.is_finite() || power < 0.0 { + return Err(ANNError::log_index_error(format!( + "determinant_diversity_power must be finite and >= 0.0, got {power}" + ))); + } + + let processor = + DeterminantDiversityRerankAndFilter::new(vector_filter, k, eta, power); self.runtime.block_on(self.index.search_with( knn_search, diff --git a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs index 52856b694..066f6880d 100644 --- a/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs +++ b/diskann-providers/src/model/graph/provider/async_/determinant_diversity_post_process.rs @@ -14,7 +14,9 @@ use diskann::{ provider::BuildQueryComputer, utils::{IntoUsize, VectorRepr}, }; -use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; +use diskann_vector::{ + DistanceFunction, MathematicalValue, PureDistanceFunction, distance::InnerProduct, +}; use super::{ inmem::GetFullPrecision, @@ -33,7 +35,7 @@ impl std::fmt::Display for DeterminantDiversityError { match self { Self::InvalidTopK { top_k } => write!(f, "top_k must be > 0, got {top_k}"), Self::InvalidEta { eta } => write!(f, "eta must be >= 0.0, got {eta}"), - Self::InvalidPower { power } => write!(f, "power must be > 0.0, got {power}"), + Self::InvalidPower { power } => write!(f, "power must be >= 0.0, got {power}"), } } } @@ -63,7 +65,7 @@ impl DeterminantDiversitySearchParams { }); } - if determinant_diversity_power <= 0.0 || !determinant_diversity_power.is_finite() { + if determinant_diversity_power < 0.0 || !determinant_diversity_power.is_finite() { return Err(DeterminantDiversityError::InvalidPower { power: determinant_diversity_power, }); @@ -97,9 +99,9 @@ where B: SearchOutputBuffer + Send + ?Sized, { let result = (|| { - let query_f32 = T::as_f32(query).map_err(Into::into)?; let full = accessor.as_full_precision(); let checker = accessor.as_deletion_check(); + let distance = full.distance(); let mut candidates_with_vectors = Vec::new(); for candidate in candidates { @@ -109,13 +111,16 @@ where let vector = unsafe { full.get_vector_sync(candidate.id.into_usize()) }; let vector_f32 = T::as_f32(vector).map_err(Into::into)?; + let full_precision_distance = distance.evaluate_similarity(query, vector); candidates_with_vectors.push(( candidate.id, - candidate.distance, + full_precision_distance, vector_f32.to_vec(), )); } + let query_f32 = T::as_f32(query).map_err(Into::into)?; + let reranked = determinant_diversity_post_process( candidates_with_vectors, &query_f32[..], @@ -142,6 +147,15 @@ pub fn determinant_diversity_post_process( return Vec::new(); } + let candidates: Vec<_> = candidates + .into_iter() + .filter(|(_, _, vector)| vector.len() == query.len()) + .collect(); + + if candidates.is_empty() { + return Vec::new(); + } + let k = k.min(candidates.len()); if k == 0 { return Vec::new(); @@ -192,8 +206,8 @@ fn post_process_with_eta_f32( let mut residuals = Vec::with_capacity(n); let mut norms_sq = Vec::with_capacity(n); - for (_, _, v) in &candidates { - let similarity = dot_product(v, query); + for (_, similarity_to_query, v) in &candidates { + let similarity = *similarity_to_query; let scale = similarity.max(0.0).powf(power as f32) * inv_sqrt_eta; let residual: Vec = v.iter().map(|&x| x * scale).collect(); let norm_sq = dot_product(&residual, &residual); @@ -203,6 +217,7 @@ fn post_process_with_eta_f32( let mut available = vec![true; n]; let mut selected = Vec::with_capacity(k); + let mut projections = vec![0.0f32; n]; for _ in 0..k { let best_idx = available @@ -235,13 +250,12 @@ fn post_process_with_eta_f32( let inv_norm_sq = 1.0 / best_norm_sq; let r_star_copy = residuals[selected_index].clone(); - let mut projections = Vec::with_capacity(n); for i in 0..n { if !available[i] { - projections.push(0.0); + projections[i] = 0.0; } else { let projection = dot_product(&residuals[i], &r_star_copy) * inv_norm_sq; - projections.push(projection); + projections[i] = projection; } } @@ -289,8 +303,8 @@ fn post_process_greedy_orthogonalization_f32( let mut residuals = Vec::with_capacity(n); let mut norms_sq = Vec::with_capacity(n); - for (_, _, v) in &candidates { - let similarity = dot_product(v, query); + for (_, similarity_to_query, v) in &candidates { + let similarity = *similarity_to_query; let scale = similarity.max(0.0).powf(power as f32); let residual: Vec = v.iter().map(|&x| x * scale).collect(); let norm_sq = dot_product(&residual, &residual); @@ -300,6 +314,7 @@ fn post_process_greedy_orthogonalization_f32( let mut available = vec![true; n]; let mut selected = Vec::with_capacity(k); + let mut projections = vec![0.0f32; n]; for _ in 0..k { let best = available @@ -331,13 +346,12 @@ fn post_process_greedy_orthogonalization_f32( let inv_norm_sq_star = 1.0 / best_norm_sq; let r_star_copy = residuals[best_index].clone(); - let mut projections = Vec::with_capacity(n); for j in 0..n { if !available[j] { - projections.push(0.0); + projections[j] = 0.0; } else { let projection = dot_product(&residuals[j], &r_star_copy) * inv_norm_sq_star; - projections.push(projection); + projections[j] = projection; } } @@ -378,6 +392,9 @@ mod tests { fn test_validation_valid_params() { let result = DeterminantDiversitySearchParams::new(10, 0.01, 2.0); assert!(result.is_ok()); + + let result = DeterminantDiversitySearchParams::new(10, 0.01, 0.0); + assert!(result.is_ok()); } #[test] @@ -395,10 +412,6 @@ mod tests { DeterminantDiversitySearchParams::new(10, f64::NAN, 2.0), DeterminantDiversityError::InvalidEta { eta: f64::NAN }, ), - ( - DeterminantDiversitySearchParams::new(10, 0.01, 0.0), - DeterminantDiversityError::InvalidPower { power: 0.0 }, - ), ( DeterminantDiversitySearchParams::new(10, 0.01, -1.0), DeterminantDiversityError::InvalidPower { power: -1.0 }, @@ -487,4 +500,16 @@ mod tests { let result = determinant_diversity_post_process(candidates, &query, 10, 0.01, 2.0); assert_eq!(result.len(), 2); } + + #[test] + fn test_determinant_diversity_post_process_dimension_mismatch_is_skipped() { + let bad = vec![1.0f32, 0.0]; + let good = vec![0.0f32, 1.0, 0.0]; + let candidates = vec![(1u32, 0.5f32, bad), (2u32, 0.3f32, good)]; + let query = vec![1.0, 1.0, 1.0]; + + let result = determinant_diversity_post_process(candidates, &query, 10, 0.01, 2.0); + assert_eq!(result.len(), 1); + assert_eq!(result[0].0, 2); + } }