From a4966281d17b5ee485bb31ec0c31ddb5021a0363 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 13 May 2026 17:49:54 -0700 Subject: [PATCH 01/17] Checkpoint. --- diskann-benchmark/Cargo.toml | 2 +- .../src/search/provider/disk_provider.rs | 88 ++- .../encoded_document_accessor.rs | 128 ++-- diskann-label-filter/src/lib.rs | 10 +- diskann-providers/src/index/diskann_async.rs | 144 ++-- .../provider/async_/inmem/full_precision.rs | 225 ++++-- .../graph/provider/async_/inmem/product.rs | 195 +++-- .../graph/provider/async_/inmem/scalar.rs | 225 ++++-- .../graph/provider/async_/inmem/spherical.rs | 153 ++-- .../model/graph/provider/async_/inmem/test.rs | 654 ++++++++--------- .../model/graph/provider/layers/betafilter.rs | 681 +++++++++--------- .../src/storage/index_storage.rs | 108 +-- diskann/src/graph/glue.rs | 519 +++++++------ diskann/src/graph/index.rs | 231 +++--- diskann/src/graph/search/knn_search.rs | 2 - diskann/src/graph/search/multihop_search.rs | 24 +- diskann/src/graph/search/range_search.rs | 3 +- diskann/src/graph/test/cases/index.rs | 160 ++-- diskann/src/graph/test/provider.rs | 147 +++- diskann/src/graph/workingset/map.rs | 526 ++++++-------- diskann/src/provider.rs | 484 ++++++------- 21 files changed, 2420 insertions(+), 2289 deletions(-) diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index bebaf4b8e..59c8591fe 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -28,7 +28,6 @@ thiserror.workspace = true tokio = { workspace = true, features = ["rt-multi-thread"] } diskann-vector.workspace = true diskann-wide.workspace = true -diskann-label-filter.workspace = true diskann-tools = { workspace = true } diskann-disk = { workspace = true, optional = true } cfg-if.workspace = true @@ -38,6 +37,7 @@ opentelemetry_sdk = { workspace = true, optional = true } scopeguard = { version = "1.2", optional = true } diskann-benchmark-core = { workspace = true, features = ["bigann"] } itertools.workspace 
= true +diskann-label-filter.workspace = true [lints] clippy.undocumented_unsafe_blocks = "warn" diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 33938caea..64144014f 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -429,18 +429,18 @@ where }) } - async fn distances_unordered( - &mut self, - vec_id_itr: Itr, - _computer: &Self::QueryComputer, - f: F, - ) -> Result<(), Self::GetError> - where - F: Send + FnMut(f32, Self::Id), - Itr: Iterator, - { - self.pq_distances(&vec_id_itr.collect::>(), f) - } + // async fn distances_unordered( + // &mut self, + // vec_id_itr: Itr, + // _computer: &Self::QueryComputer, + // f: F, + // ) -> Result<(), Self::GetError> + // where + // F: Send + FnMut(f32, Self::Id), + // Itr: Iterator, + // { + // self.pq_distances(&vec_id_itr.collect::>(), f) + // } } impl ExpandBeam<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> @@ -454,7 +454,7 @@ where _computer: &Self::QueryComputer, mut pred: P, mut f: F, - ) -> impl std::future::Future> + Send + ) -> impl std::future::Future> + Send where Itr: Iterator + Send, P: glue::HybridPredicate + Send + Sync, @@ -588,7 +588,7 @@ where } } -impl SearchExt for DiskAccessor<'_, Data, VP> +impl SearchExt<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, @@ -598,6 +598,21 @@ where Ok(vec![start_vertex_id]) } + async fn start_point_distances( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> ANNResult<()> + where + F: FnMut(Self::Id, f32) + Send + { + let start_vertex_id = self.provider.graph_header.metadata().medoid as u32; + let vector = self.provider.pq_data.get_compressed_vector(start_vertex_id.into_usize())?; + let distance = computer.evaluate_similarity(vector); + f(start_vertex_id, distance); + Ok(()) + } + fn terminate_early(&mut self) -> bool { self.io_tracker.io_count() > 
self.provider.search_io_limit } @@ -694,25 +709,25 @@ where Data: GraphDataType, VP: VertexProvider, { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = &'a [u8] - where - Self: 'a; + // /// This accessor returns raw slices. There *is* a chance of racing when the fast + // /// providers are used. We just have to live with it. + // type Element<'a> + // = &'a [u8] + // where + // Self: 'a; /// `ElementRef` can have arbitrary lifetimes. type ElementRef<'a> = &'a [u8]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = ANNError; + // /// Choose to panic on an out-of-bounds access rather than propagate an error. + // type GetError = ANNError; - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize)) - } + // fn get_element( + // &mut self, + // id: Self::Id, + // ) -> impl Future, Self::GetError>> + Send { + // std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize)) + // } } impl IdIterator> for DiskAccessor<'_, Data, VP> @@ -993,14 +1008,15 @@ where let k = k_value; let l = search_list_size as usize; let stats = if is_flat_search { - self.runtime.block_on(self.index.flat_search( - &strategy, - &DefaultContext, - strategy.query, - vector_filter, - &Knn::new(k, l, beam_width)?, - &mut result_output_buffer, - ))? + todo!(); + // self.runtime.block_on(self.index.flat_search( + // &strategy, + // &DefaultContext, + // strategy.query, + // vector_filter, + // &Knn::new(k, l, beam_width)?, + // &mut result_output_buffer, + // ))? 
} else { let knn_search = Knn::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 0d16248dd..6f269ab42 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -72,71 +72,71 @@ impl Accessor for EncodedDocumentAccessor where IA: Accessor, { - type Element<'a> - = EncodedDocument, RoaringTreemap> - where - Self: 'a; + // type Element<'a> + // = EncodedDocument, RoaringTreemap> + // where + // Self: 'a; type ElementRef<'a> = EncodedDocument, &'a RoaringTreemap>; - type GetError = ANNError; - - async fn get_element(&mut self, id: Self::Id) -> Result, Self::GetError> { - let future = self.inner_accessor.get_element(id); - let elem = future.await.escalate("Did not find the vector element")?; - - let attrs = self - .attribute_accessor - .visit_labels_of_point(id, |_, opt_set| { - match opt_set { - //TODO: Currently, there is no way but to copy. So we copy the set from the Cow into a - //hydrated object. - //IMP NOTE: Removing the copy will also change the signature of "Element" and may cause other - //downstream issues, so should be done with care! 
- Some(set) => Ok(set.into_owned()), - None => Err(ANNError::message( - ANNErrorKind::IndexError, - "No labels were found for vector", - )), - } - })?; - - Ok(EncodedDocument::new(elem, attrs?)) - } - - async fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> Result<(), Self::GetError> - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), - { - for i in itr { - let vec = self - .inner_accessor - .get_element(i) - .await - .escalate("Failed to get vector from inner accessor")?; - let _ = self - .attribute_accessor - .visit_labels_of_point(i, |_, opt_set| { - let set = match opt_set { - Some(set) => set, - None => { - return Err(ANNError::message( - ANNErrorKind::IndexError, - format!("No attributes found for point.{}", i), - )); - } - }; - let elem = EncodedDocument::new(vec.reborrow(), &*set); - f(elem, i); - Ok(()) - }); - } - Ok(()) - } + // type GetError = ANNError; + + // async fn get_element(&mut self, id: Self::Id) -> Result, Self::GetError> { + // let future = self.inner_accessor.get_element(id); + // let elem = future.await.escalate("Did not find the vector element")?; + + // let attrs = self + // .attribute_accessor + // .visit_labels_of_point(id, |_, opt_set| { + // match opt_set { + // //TODO: Currently, there is no way but to copy. So we copy the set from the Cow into a + // //hydrated object. + // //IMP NOTE: Removing the copy will also change the signature of "Element" and may cause other + // //downstream issues, so should be done with care! 
+ // Some(set) => Ok(set.into_owned()), + // None => Err(ANNError::message( + // ANNErrorKind::IndexError, + // "No labels were found for vector", + // )), + // } + // })?; + + // Ok(EncodedDocument::new(elem, attrs?)) + // } + + // async fn on_elements_unordered( + // &mut self, + // itr: Itr, + // mut f: F, + // ) -> Result<(), Self::GetError> + // where + // Self: Sync, + // Itr: Iterator + Send, + // F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), + // { + // for i in itr { + // let vec = self + // .inner_accessor + // .get_element(i) + // .await + // .escalate("Failed to get vector from inner accessor")?; + // let _ = self + // .attribute_accessor + // .visit_labels_of_point(i, |_, opt_set| { + // let set = match opt_set { + // Some(set) => set, + // None => { + // return Err(ANNError::message( + // ANNErrorKind::IndexError, + // format!("No attributes found for point.{}", i), + // )); + // } + // }; + // let elem = EncodedDocument::new(vec.reborrow(), &*set); + // f(elem, i); + // Ok(()) + // }); + // } + // Ok(()) + // } } impl SearchExt for EncodedDocumentAccessor diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs index 106845f98..0a6abcb9e 100644 --- a/diskann-label-filter/src/lib.rs +++ b/diskann-label-filter/src/lib.rs @@ -17,11 +17,11 @@ pub mod utils { pub mod jsonl_reader; } -pub mod inline_beta_search { - pub mod encoded_document_accessor; - pub mod inline_beta_filter; - pub mod predicate_evaluator; -} +// pub mod inline_beta_search { +// pub mod encoded_document_accessor; +// pub mod inline_beta_filter; +// pub mod predicate_evaluator; +// } // Persisent Index Traits pub mod traits { diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 69c254c86..da386be72 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -3185,78 +3185,78 @@ pub(crate) mod tests { // Flaky Provider Handling // 
///////////////////////////// - // This test uses a "Flaky" accessor that spuriously fails with non-critical errors to - // check that such errors are not propagated by DiskANN. - #[tokio::test] - async fn test_flaky_build() { - let parameters = InitParams { - l_build: 64, - max_degree: 16, - metric: Metric::L2, - batchsize: NonZeroUsize::new(1).unwrap(), - }; - - let start_point = StartPointStrategy::RandomSamples { - nsamples: ONE, - seed: 0xb4de0a1298a86eea, - }; - - // This is the two level index. - let (index, data) = init_from_file( - inmem::test::Flaky::new(9), - parameters, - SIFTSMALL, - 8, - start_point, - ) - .await; - - // There should be one more reachable node than points in the dataset to account for - // the start point. - let neighbor_accessor = &mut index.provider().neighbors(); - assert_eq!( - index - .count_reachable_nodes( - &index.provider().starting_points().unwrap(), - neighbor_accessor - ) - .await - .unwrap(), - data.nrows() + 1, - ); - - let top_k = 10; - let search_l = 32; - let mut ids = vec![0; top_k]; - let mut distances = vec![0.0; top_k]; - - // Here, we use elements of the dataset to search the dataset itself. - // - // We do this for each query, computing the expected ground truth and verifying - // that our simple graph search matches. - // - // Because this dataset is small, we can expect exact equality. - let ctx = &DefaultContext; - for (q, query) in data.row_iter().enumerate() { - let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); - let mut result_output_buffer = - search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); - // Full Precision Search. 
- index - .search( - graph_search, - &FullPrecision, - ctx, - query, - &mut result_output_buffer, - ) - .await - .unwrap(); - - assert_top_k_exactly_match(q, >, &ids, &distances, top_k); - } - } + // // This test uses a "Flaky" accessor that spuriously fails with non-critical errors to + // // check that such errors are not propagated by DiskANN. + // #[tokio::test] + // async fn test_flaky_build() { + // let parameters = InitParams { + // l_build: 64, + // max_degree: 16, + // metric: Metric::L2, + // batchsize: NonZeroUsize::new(1).unwrap(), + // }; + + // let start_point = StartPointStrategy::RandomSamples { + // nsamples: ONE, + // seed: 0xb4de0a1298a86eea, + // }; + + // // This is the two level index. + // let (index, data) = init_from_file( + // inmem::test::Flaky::new(9), + // parameters, + // SIFTSMALL, + // 8, + // start_point, + // ) + // .await; + + // // There should be one more reachable node than points in the dataset to account for + // // the start point. + // let neighbor_accessor = &mut index.provider().neighbors(); + // assert_eq!( + // index + // .count_reachable_nodes( + // &index.provider().starting_points().unwrap(), + // neighbor_accessor + // ) + // .await + // .unwrap(), + // data.nrows() + 1, + // ); + + // let top_k = 10; + // let search_l = 32; + // let mut ids = vec![0; top_k]; + // let mut distances = vec![0.0; top_k]; + + // // Here, we use elements of the dataset to search the dataset itself. + // // + // // We do this for each query, computing the expected ground truth and verifying + // // that our simple graph search matches. + // // + // // Because this dataset is small, we can expect exact equality. 
+ // let ctx = &DefaultContext; + // for (q, query) in data.row_iter().enumerate() { + // let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); + // let mut result_output_buffer = + // search_output_buffer::IdDistance::new(&mut ids, &mut distances); + // let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + // // Full Precision Search. + // index + // .search( + // graph_search, + // &FullPrecision, + // ctx, + // query, + // &mut result_output_buffer, + // ) + // .await + // .unwrap(); + + // assert_top_k_exactly_match(q, >, &ids, &distances, top_k); + // } + // } async fn create_retry_saturated_index( retry: NonZeroU32, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index e72f323ef..2b408c046 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -10,7 +10,7 @@ use diskann::{ ANNError, ANNResult, error::Infallible, graph::{ - SearchOutputBuffer, + AdjacencyList, SearchOutputBuffer, glue::{ self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, @@ -26,7 +26,7 @@ use diskann::{ }; use diskann_utils::future::AsyncFriendly; -use diskann_vector::{DistanceFunction, distance::Metric}; +use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; use crate::model::graph::provider::async_::{ FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, @@ -144,7 +144,7 @@ where type Id = u32; } -impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> +impl SearchExt<&[T]> for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, @@ -154,6 +154,28 @@ where fn starting_points(&self) -> impl Future>> { std::future::ready(self.provider.starting_points()) } + + fn start_point_distances( + &mut self, + 
computer: &Self::QueryComputer, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + async move { + for i in self.provider.starting_points()? { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = computer.evaluate_similarity(unsafe { + self.provider.base_vectors.get_vector_sync(i.into_usize()) + }); + + f(i, distance); + } + Ok(()) + } + } } impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> @@ -192,81 +214,81 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = &'a [T] - where - Self: 'a; + // /// This accessor returns raw slices. There *is* a chance of racing when the fast + // /// providers are used. We just have to live with it. + // type Element<'a> + // = &'a [T] + // where + // Self: 'a; /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = &'a [T]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. 
- fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - // Invoke the passed closure on the full-precision vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - *id, - ) - } - - std::future::ready(Ok(())) - } + // /// Choose to panic on an out-of-bounds access rather than propagate an error. + // type GetError = Panics; + + // /// Return the full-precision vector stored at index `i`. + // /// + // /// This function always completes synchronously. + // #[inline(always)] + // fn get_element( + // &mut self, + // id: Self::Id, + // ) -> impl Future, Self::GetError>> + Send { + // // SAFETY: We've decided to live with UB (undefined behavior) that can result from + // // potentially mixing unsynchronized reads and writes on the underlying memory. + // std::future::ready(Ok(unsafe { + // self.provider.base_vectors.get_vector_sync(id.into_usize()) + // })) + // } + + // /// Perform a bulk operation. + // /// + // /// This implementation uses prefetching. 
+ // fn on_elements_unordered( + // &mut self, + // itr: Itr, + // mut f: F, + // ) -> impl Future> + Send + // where + // Self: Sync, + // Itr: Iterator + Send, + // F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + // { + // // Reuse the internal buffer to collect the results and give us random access + // // capabilities. + // let id_buffer = &mut self.id_buffer; + // id_buffer.clear(); + // id_buffer.extend(itr); + + // let len = id_buffer.len(); + // let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // // Prefetch the first few vectors. + // for id in id_buffer.iter().take(lookahead) { + // self.provider.base_vectors.prefetch_hint(id.into_usize()); + // } + + // for (i, id) in id_buffer.iter().enumerate() { + // // Prefetch `lookahead` iterations ahead as long as it is safe. + // if lookahead > 0 && i + lookahead < len { + // self.provider + // .base_vectors + // .prefetch_hint(id_buffer[i + lookahead].into_usize()); + // } + + // // Invoke the passed closure on the full-precision vector. + // // + // // SAFETY: We're accepting the consequences of potential unsynchronized, + // // concurrent mutation. + // f( + // unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, + // *id, + // ) + // } + + // std::future::ready(Ok(())) + // } } impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> @@ -314,6 +336,61 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let mut neighbors = AdjacencyList::new(); + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut neighbors)?; + + // Reuse the internal buffer to collect the results and give us random access + // capabilities. 
+ let id_buffer = &mut self.id_buffer; + id_buffer.clear(); + id_buffer.extend(neighbors.iter().filter(|i| pred.eval_mut(i))); + + let len = id_buffer.len(); + let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.base_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .base_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + // Invoke the passed closure on the full-precision vector. + // + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let v = unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }; + let distance = computer.evaluate_similarity(v); + on_neighbors(distance, *id); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } //-------------------// diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index b3867ad85..cd895aa91 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -9,6 +9,7 @@ use diskann::default_post_processor; use diskann::{ ANNError, ANNResult, graph::{ + AdjacencyList, glue::{ self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy, @@ -22,7 +23,7 @@ use diskann::{ utils::{IntoUsize, VectorRepr}, }; use diskann_utils::future::AsyncFriendly; -use diskann_vector::distance::Metric; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; use crate::model::{ graph::provider::async_::{ @@ -109,8 +110,9 @@ impl HasId for QuantAccessor<'_, V, D, Ctx> { type Id = u32; } -impl SearchExt for 
QuantAccessor<'_, V, D, Ctx> +impl SearchExt<&[T]> for QuantAccessor<'_, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, @@ -118,6 +120,28 @@ where fn starting_points(&self) -> impl Future>> { std::future::ready(self.provider.starting_points()) } + + fn start_point_distances( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + async move { + for i in self.provider.starting_points()? { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = computer.evaluate_similarity(unsafe { + self.provider.aux_vectors.get_vector_sync(i.into_usize()) + }); + + f(i, distance); + } + Ok(()) + } + } } impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx> @@ -149,54 +173,54 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = &'a [u8] - where - Self: 'a; + // /// This accessor returns raw slices. There *is* a chance of racing when the fast + // /// providers are used. We just have to live with it. + // type Element<'a> + // = &'a [u8] + // where + // Self: 'a; /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = &'a [u8]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the quantized vector stored at index `i`. - /// - /// This function always completes synchronously. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB that can result from potentially mixing - // unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.aux_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. 
- fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - for i in itr { - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.aux_vectors.get_vector_sync(i.into_usize()) }, - i, - ) - } - std::future::ready(Ok(())) - } + // /// Choose to panic on an out-of-bounds access rather than propagate an error. + // type GetError = Panics; + + // /// Return the quantized vector stored at index `i`. + // /// + // /// This function always completes synchronously. + // fn get_element( + // &mut self, + // id: Self::Id, + // ) -> impl Future, Self::GetError>> + Send { + // // SAFETY: We've decided to live with UB that can result from potentially mixing + // // unsynchronized reads and writes on the underlying memory. + // std::future::ready(Ok(unsafe { + // self.provider.aux_vectors.get_vector_sync(id.into_usize()) + // })) + // } + + // /// Perform a bulk operation. + // fn on_elements_unordered( + // &mut self, + // itr: Itr, + // mut f: F, + // ) -> impl Future> + Send + // where + // Self: Sync, + // Itr: Iterator + Send, + // F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + // { + // for i in itr { + // // SAFETY: We're accepting the consequences of potential unsynchronized, + // // concurrent mutation. 
+ // f( + // unsafe { self.provider.aux_vectors.get_vector_sync(i.into_usize()) }, + // i, + // ) + // } + // std::future::ready(Ok(())) + // } } impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> @@ -240,6 +264,39 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let mut neighbors = AdjacencyList::new(); + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut neighbors)?; + for i in neighbors.iter().filter(|i| pred.eval_mut(i)) { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = computer.evaluate_similarity(unsafe { + self.provider.aux_vectors.get_vector_sync(i.into_usize()) + }); + + on_neighbors(distance, *i); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } //-------------------// @@ -322,33 +379,33 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { - /// The [`distances::pq::Hybrid`] is an enum consisting of either a full-precision - /// vector or a quantized vector. - /// - /// This accessor can return either. - type Element<'a> - = distances::pq::Hybrid<&'a [T], &'a [u8]> - where - Self: 'a; + // /// The [`distances::pq::Hybrid`] is an enum consisting of either a full-precision + // /// vector or a quantized vector. + // /// + // /// This accessor can return either. + // type Element<'a> + // = distances::pq::Hybrid<&'a [T], &'a [u8]> + // where + // Self: 'a; /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
- type GetError = Panics; - - /// The default behavior of `get_element` returns a full-precision vector. The - /// implementation of [`Fill`] is how the `max_fp_vecs_per_fill` is used. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB that can result from potentially mixing - // unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - distances::pq::Hybrid::Full(self.provider.base_vectors.get_vector_sync(id.into_usize())) - })) - } + // /// Choose to panic on an out-of-bounds access rather than propagate an error. + // type GetError = Panics; + + // /// The default behavior of `get_element` returns a full-precision vector. The + // /// implementation of [`Fill`] is how the `max_fp_vecs_per_fill` is used. + // fn get_element( + // &mut self, + // id: Self::Id, + // ) -> impl Future, Self::GetError>> + Send { + // // SAFETY: We've decided to live with UB that can result from potentially mixing + // // unsynchronized reads and writes on the underlying memory. 
+ // std::future::ready(Ok(unsafe { + // distances::pq::Hybrid::Full(self.provider.base_vectors.get_vector_sync(id.into_usize())) + // })) + // } } impl BuildDistanceComputer for HybridAccessor<'_, T, D, Ctx> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index a69344691..b9a878cec 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -9,6 +9,7 @@ use crate::storage::{StorageReadProvider, StorageWriteProvider}; use diskann::{ ANNError, ANNResult, default_post_processor, graph::{ + AdjacencyList, glue::{ self, DefaultPostProcessor, ExpandBeam, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, @@ -390,16 +391,35 @@ impl HasId for QuantAccessor<'_, NBITS, V, D, Ctx type Id = u32; } -impl SearchExt for QuantAccessor<'_, NBITS, V, D, Ctx> +impl SearchExt<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, Unsigned: Representation, + QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { fn starting_points(&self) -> impl Future>> { std::future::ready(self.provider.starting_points()) } + + fn start_point_distances( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + async move { + for i in self.provider.starting_points()? { + let vector = self.provider.aux_vectors.get_vector(i.into_usize())?; + f(i, computer.evaluate_similarity(vector)); + } + Ok(()) + } + } } impl<'a, const NBITS: usize, V, D, Ctx> QuantAccessor<'a, NBITS, V, D, Ctx> @@ -427,85 +447,85 @@ where Ctx: ExecutionContext, Unsigned: Representation, { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. 
- type Element<'a> - = CVRef<'a, NBITS> - where - Self: 'a; + // /// This accessor returns raw slices. There *is* a chance of racing when the fast + // /// providers are used. We just have to live with it. + // type Element<'a> + // = CVRef<'a, NBITS> + // where + // Self: 'a; /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = CVRef<'a, NBITS>; - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = ANNError; - - /// Return the quantized vector stored at index `i`. - /// - /// This function always completes synchronously. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB that can result from potentially mixing - // unsynchronized reads and writes on the underlying memory. - std::future::ready( - match self.provider.aux_vectors.get_vector(id.into_usize()) { - Ok(v) => Ok(v), - Err(err) => Err(err.into()), - }, - ) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.aux_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.aux_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. 
- if lookahead > 0 && i + lookahead < len { - self.provider - .aux_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - let vector = match self.provider.aux_vectors.get_vector(id.into_usize()) { - Ok(v) => v, - Err(e) => return std::future::ready(Err(e.into())), - }; - - // Invoke the passed closure on the vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f(vector, *id) - } - - std::future::ready(Ok(())) - } + // /// Choose to panic on an out-of-bounds access rather than propagate an error. + // type GetError = ANNError; + + // /// Return the quantized vector stored at index `i`. + // /// + // /// This function always completes synchronously. + // fn get_element( + // &mut self, + // id: Self::Id, + // ) -> impl Future, Self::GetError>> + Send { + // // SAFETY: We've decided to live with UB that can result from potentially mixing + // // unsynchronized reads and writes on the underlying memory. + // std::future::ready( + // match self.provider.aux_vectors.get_vector(id.into_usize()) { + // Ok(v) => Ok(v), + // Err(err) => Err(err.into()), + // }, + // ) + // } + + // /// Perform a bulk operation. + // /// + // /// This implementation uses prefetching. + // fn on_elements_unordered( + // &mut self, + // itr: Itr, + // mut f: F, + // ) -> impl Future> + Send + // where + // Self: Sync, + // Itr: Iterator + Send, + // F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + // { + // // Reuse the internal buffer to collect the results and give us random access + // // capabilities. + // let id_buffer = &mut self.id_buffer; + // id_buffer.clear(); + // id_buffer.extend(itr); + + // let len = id_buffer.len(); + // let lookahead = self.provider.aux_vectors.prefetch_lookahead(); + + // // Prefetch the first few vectors. 
+ // for id in id_buffer.iter().take(lookahead) { + // self.provider.aux_vectors.prefetch_hint(id.into_usize()); + // } + + // for (i, id) in id_buffer.iter().enumerate() { + // // Prefetch `lookahead` iterations ahead as long as it is safe. + // if lookahead > 0 && i + lookahead < len { + // self.provider + // .aux_vectors + // .prefetch_hint(id_buffer[i + lookahead].into_usize()); + // } + + // let vector = match self.provider.aux_vectors.get_vector(id.into_usize()) { + // Ok(v) => v, + // Err(e) => return std::future::ready(Err(e.into())), + // }; + + // // Invoke the passed closure on the vector. + // // + // // SAFETY: We're accepting the consequences of potential unsynchronized, + // // concurrent mutation. + // f(vector, *id) + // } + + // std::future::ready(Ok(())) + // } } impl<'a, const NBITS: usize, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, NBITS, V, D, Ctx> @@ -554,6 +574,59 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let mut neighbors = AdjacencyList::new(); + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut neighbors)?; + + // Reuse the internal buffer to collect the results and give us random access + // capabilities. + let id_buffer = &mut self.id_buffer; + id_buffer.clear(); + id_buffer.extend(neighbors.iter().filter(|i| pred.eval_mut(i))); + + let len = id_buffer.len(); + let lookahead = self.provider.aux_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. 
+ for id in id_buffer.iter().take(lookahead) { + self.provider.aux_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .aux_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; + + // Invoke the passed closure on the vector. + let distance = computer.evaluate_similarity(vector); + on_neighbors(distance, *id); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } impl BuildDistanceComputer for QuantAccessor<'_, NBITS, V, D, Ctx> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 11f489aa4..c6771bcbf 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -9,8 +9,8 @@ use std::{future::Future, sync::Mutex}; use diskann::{ ANNError, ANNErrorKind, ANNResult, default_post_processor, - error::IntoANNResult, graph::{ + AdjacencyList, glue::{ self, DefaultPostProcessor, ExpandBeam, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, @@ -29,7 +29,7 @@ use diskann_quantization::{ spherical, }; use diskann_utils::future::AsyncFriendly; -use diskann_vector::distance::Metric; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; use thiserror::Error; use super::{GetFullPrecision, PassThrough, Rerank}; @@ -329,8 +329,9 @@ impl HasId for QuantAccessor<'_, V, D, Ctx> { type Id = u32; } -impl SearchExt for QuantAccessor<'_, V, D, Ctx> +impl SearchExt<&[T]> for QuantAccessor<'_, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, @@ -338,6 +339,28 @@ where fn starting_points(&self) -> impl Future>> { 
std::future::ready(self.provider.starting_points()) } + + fn start_point_distances( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + async move { + for i in self.provider.starting_points()? { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = computer.evaluate_similarity( + self.provider.aux_vectors.get_vector(i.into_usize())? + ); + + f(i, distance); + } + Ok(()) + } + } } impl Accessor for QuantAccessor<'_, V, D, Ctx> @@ -346,81 +369,8 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = spherical::iface::Opaque<'a> - where - Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = spherical::iface::Opaque<'a>; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = ANNError; - - /// Return the quantized vector stored at index `i`. - /// - /// This function always completes synchronously. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB that can result from potentially mixing - // unsynchronized reads and writes on the underlying memory. - std::future::ready( - self.provider - .aux_vectors - .get_vector(id.into_usize()) - .into_ann_result(), - ) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. 
- let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.aux_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.aux_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .aux_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - let vector = match self.provider.aux_vectors.get_vector(id.into_usize()) { - Ok(v) => v, - Err(e) => return std::future::ready(Err(e.into())), - }; - - f(vector, *id) - } - - std::future::ready(Ok(())) - } } impl<'a, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, V, D, Ctx> @@ -465,6 +415,57 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let mut neighbors = AdjacencyList::new(); + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut neighbors)?; + + // Reuse the internal buffer to collect the results and give us random access + // capabilities. + let id_buffer = &mut self.id_buffer; + id_buffer.clear(); + id_buffer.extend(neighbors.iter().filter(|i| pred.eval_mut(i))); + + let len = id_buffer.len(); + let lookahead = self.provider.aux_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.aux_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. 
+ if lookahead > 0 && i + lookahead < len { + self.provider + .aux_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; + let distance = computer.evaluate_similarity(vector); + on_neighbors(distance, *id); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } #[derive(Debug, Error)] diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 2493a5ce2..84dabdd8e 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -1,327 +1,327 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{ - future::Future, - sync::{Arc, Mutex}, -}; - -use diskann::{ - ANNError, ANNResult, default_post_processor, - error::{RankedError, ToRanked, TransientError}, - graph::{ - glue::{ - CopyIds, DefaultPostProcessor, ExpandBeam, InsertStrategy, MultiInsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, - }, - workingset::map, - }, - neighbor::Neighbor, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - HasId, - }, - utils::IntoUsize, -}; -use diskann_utils::views::Matrix; - -use super::{DefaultProvider, DefaultQuant}; -use crate::model::graph::provider::async_::{ - SimpleNeighborProviderAsync, TableDeleteProviderAsync, inmem::FullPrecisionStore, -}; - -/// A full-precision accessor that spuriously fails for non-start points with a controllable -/// frequency. -/// -/// This is meant to test the non-critical error handling of index operations. 
-#[derive(Debug, Clone, Copy)] -pub struct Flaky { - fail_every: usize, -} - -impl Flaky { - pub(crate) fn new(fail_every: usize) -> Self { - Self { fail_every } - } -} - -#[derive(Debug)] -pub struct TestError { - is_transient: bool, - handled: bool, -} - -impl TestError { - fn transient() -> Self { - Self { - is_transient: true, - handled: false, - } - } -} - -impl std::fmt::Display for TestError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -impl TransientError for TestError { - fn acknowledge(mut self, _why: D) - where - D: std::fmt::Display, - { - self.handled = true; - } - - fn escalate(mut self, _why: D) -> Self - where - D: std::fmt::Display, - { - assert!(self.is_transient); - self.handled = true; - self.is_transient = false; - self - } -} - -impl Drop for TestError { - fn drop(&mut self) { - if self.is_transient { - assert!(self.handled, "dropping an unhandled transient error!"); - } - } -} - -impl From for ANNError { - fn from(value: TestError) -> Self { - assert!( - !value.is_transient, - "transient errors should not be converted!" 
- ); - ANNError::log_async_error(value) - } -} - -impl ToRanked for TestError { - type Transient = Self; - type Error = Self; - - fn to_ranked(self) -> RankedError { - if self.is_transient { - RankedError::Transient(self) - } else { - RankedError::Error(self) - } - } - - fn from_transient(transient: Self) -> Self { - assert!(transient.is_transient); - transient - } - - fn from_error(error: Self) -> Self { - assert!(!error.is_transient); - error - } -} - -type Tda = TableDeleteProviderAsync; -type TestProvider = DefaultProvider, DefaultQuant, Tda>; - -pub struct FlakyAccessor<'a> { - provider: &'a TestProvider, - fail_every: usize, - get_count: usize, -} - -type FullAccessor<'a> = super::FullAccessor<'a, f32, DefaultQuant, Tda, DefaultContext>; - -impl SearchExt for FlakyAccessor<'_> { - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a> FlakyAccessor<'a> { - fn new(provider: &'a TestProvider, fail_every: usize, get_count: usize) -> Self { - assert_ne!(get_count, 0); - Self { - provider, - get_count, - fail_every, - } - } - - fn as_full(&self) -> FullAccessor<'a> { - FullAccessor::new(self.provider) - } -} - -impl HasId for FlakyAccessor<'_> { - type Id = u32; -} - -impl Accessor for FlakyAccessor<'_> { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = &'a [f32] - where - Self: 'a; - - type ElementRef<'a> = &'a [f32]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = TestError; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // Do not fail when retrieving start points. 
- // - // NOTE: `is_not_start_point` takes a neighbor, but only looks at the `ID` portion, - // so we can use a dummy neighbor. - if self.provider.is_not_start_point()(&Neighbor::new(id, 0.0)) { - self.get_count -= 1; - if self.get_count == 0 { - self.get_count = self.fail_every; - return std::future::ready(Err(TestError::transient())); - } - } - - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } -} - -impl<'a> BuildDistanceComputer for FlakyAccessor<'a> { - type DistanceComputerError = as BuildDistanceComputer>::DistanceComputerError; - type DistanceComputer = as BuildDistanceComputer>::DistanceComputer; - - fn build_distance_computer( - &self, - ) -> Result { - self.as_full().build_distance_computer() - } -} - -impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { - type QueryComputerError = - as BuildQueryComputer<&'a [f32]>>::QueryComputerError; - type QueryComputer = as BuildQueryComputer<&'a [f32]>>::QueryComputer; - - fn build_query_computer( - &self, - from: &'a [f32], - ) -> Result { - self.as_full().build_query_computer(from) - } -} - -impl ExpandBeam<&[f32]> for FlakyAccessor<'_> {} - -impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { - type Delegate = &'a SimpleNeighborProviderAsync; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl<'x> SearchStrategy for Flaky { - type QueryComputer = as BuildQueryComputer<&'x [f32]>>::QueryComputer; - type SearchAccessor<'a> = FlakyAccessor<'a>; - type SearchAccessorError = ANNError; - - fn search_accessor<'a>( - &'a self, - provider: &'a TestProvider, - _context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - Ok(FlakyAccessor::new( - provider, - self.fail_every, - self.fail_every, - )) - } -} - -impl 
DefaultPostProcessor for Flaky { - default_post_processor!(CopyIds); -} - -const STATIC_PRUNE_THRESHOLD: usize = 5; -/// We need to tune the flakiness of the `Prune` accessor so that occasionally, the first -/// item retrieved is a failure. -static START_COUNT: Mutex = Mutex::new(STATIC_PRUNE_THRESHOLD); - -type WorkingSet = map::Map, map::Ref<[f32]>>; - -impl PruneStrategy for Flaky { - type DistanceComputer<'a> = as BuildDistanceComputer>::DistanceComputer; - type PruneAccessor<'a> = FlakyAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; - type WorkingSet = WorkingSet; - - fn prune_accessor<'a>( - &'a self, - provider: &'a TestProvider, - _context: &'a DefaultContext, - ) -> Result, Self::PruneAccessorError> { - let mut guard = START_COUNT.lock().unwrap(); - let start = *guard; - *guard -= 1; - if *guard == 0 { - *guard = STATIC_PRUNE_THRESHOLD; - } - - Ok(FlakyAccessor::new(provider, STATIC_PRUNE_THRESHOLD, start)) - } - - fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { - map::Builder::new(map::Capacity::Default).build(capacity) - } -} - -impl InsertStrategy for Flaky { - type PruneStrategy = Self; - - fn prune_strategy(&self) -> Self { - *self - } -} - -impl MultiInsertStrategy> for Flaky { - type Seed = map::Builder>; - type WorkingSet = WorkingSet; - type FinishError = diskann::error::Infallible; - type InsertStrategy = Self; - - fn insert_strategy(&self) -> Self::InsertStrategy { - *self - } - - fn finish( - &self, - _provider: &TestProvider, - _ctx: &DefaultContext, - batch: &Arc>, - ids: Itr, - ) -> impl std::future::Future> + Send - where - Itr: ExactSizeIterator + Send, - { - std::future::ready(Ok(map::Builder::new(map::Capacity::Default) - .with_overlay(map::Overlay::from_batch(batch.clone(), ids)))) - } -} +// /* +// * Copyright (c) Microsoft Corporation. +// * Licensed under the MIT license. 
+// */ +// +// use std::{ +// future::Future, +// sync::{Arc, Mutex}, +// }; +// +// use diskann::{ +// ANNError, ANNResult, default_post_processor, +// error::{RankedError, ToRanked, TransientError}, +// graph::{ +// glue::{ +// CopyIds, DefaultPostProcessor, ExpandBeam, InsertStrategy, MultiInsertStrategy, +// PruneStrategy, SearchExt, SearchStrategy, +// }, +// workingset::map, +// }, +// neighbor::Neighbor, +// provider::{ +// Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, +// HasId, +// }, +// utils::IntoUsize, +// }; +// use diskann_utils::views::Matrix; +// +// use super::{DefaultProvider, DefaultQuant}; +// use crate::model::graph::provider::async_::{ +// SimpleNeighborProviderAsync, TableDeleteProviderAsync, inmem::FullPrecisionStore, +// }; +// +// /// A full-precision accessor that spuriously fails for non-start points with a controllable +// /// frequency. +// /// +// /// This is meant to test the non-critical error handling of index operations. 
+// #[derive(Debug, Clone, Copy)] +// pub struct Flaky { +// fail_every: usize, +// } +// +// impl Flaky { +// pub(crate) fn new(fail_every: usize) -> Self { +// Self { fail_every } +// } +// } +// +// #[derive(Debug)] +// pub struct TestError { +// is_transient: bool, +// handled: bool, +// } +// +// impl TestError { +// fn transient() -> Self { +// Self { +// is_transient: true, +// handled: false, +// } +// } +// } +// +// impl std::fmt::Display for TestError { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// write!(f, "{:?}", self) +// } +// } +// +// impl TransientError for TestError { +// fn acknowledge(mut self, _why: D) +// where +// D: std::fmt::Display, +// { +// self.handled = true; +// } +// +// fn escalate(mut self, _why: D) -> Self +// where +// D: std::fmt::Display, +// { +// assert!(self.is_transient); +// self.handled = true; +// self.is_transient = false; +// self +// } +// } +// +// impl Drop for TestError { +// fn drop(&mut self) { +// if self.is_transient { +// assert!(self.handled, "dropping an unhandled transient error!"); +// } +// } +// } +// +// impl From for ANNError { +// fn from(value: TestError) -> Self { +// assert!( +// !value.is_transient, +// "transient errors should not be converted!" 
+// ); +// ANNError::log_async_error(value) +// } +// } +// +// impl ToRanked for TestError { +// type Transient = Self; +// type Error = Self; +// +// fn to_ranked(self) -> RankedError { +// if self.is_transient { +// RankedError::Transient(self) +// } else { +// RankedError::Error(self) +// } +// } +// +// fn from_transient(transient: Self) -> Self { +// assert!(transient.is_transient); +// transient +// } +// +// fn from_error(error: Self) -> Self { +// assert!(!error.is_transient); +// error +// } +// } +// +// type Tda = TableDeleteProviderAsync; +// type TestProvider = DefaultProvider, DefaultQuant, Tda>; +// +// pub struct FlakyAccessor<'a> { +// provider: &'a TestProvider, +// fail_every: usize, +// get_count: usize, +// } +// +// type FullAccessor<'a> = super::FullAccessor<'a, f32, DefaultQuant, Tda, DefaultContext>; +// +// impl SearchExt for FlakyAccessor<'_> { +// fn starting_points(&self) -> impl Future>> { +// std::future::ready(self.provider.starting_points()) +// } +// } +// +// impl<'a> FlakyAccessor<'a> { +// fn new(provider: &'a TestProvider, fail_every: usize, get_count: usize) -> Self { +// assert_ne!(get_count, 0); +// Self { +// provider, +// get_count, +// fail_every, +// } +// } +// +// fn as_full(&self) -> FullAccessor<'a> { +// FullAccessor::new(self.provider) +// } +// } +// +// impl HasId for FlakyAccessor<'_> { +// type Id = u32; +// } +// +// impl Accessor for FlakyAccessor<'_> { +// // /// This accessor returns raw slices. There *is* a chance of racing when the fast +// // /// providers are used. We just have to live with it. +// // type Element<'a> +// // = &'a [f32] +// // where +// // Self: 'a; +// +// type ElementRef<'a> = &'a [f32]; +// +// // /// Choose to panic on an out-of-bounds access rather than propagate an error. +// // type GetError = TestError; +// +// // /// Return the full-precision vector stored at index `i`. +// // /// +// // /// This function always completes synchronously. 
+// // #[inline(always)] +// // fn get_element( +// // &mut self, +// // id: Self::Id, +// // ) -> impl Future, Self::GetError>> + Send { +// // // Do not fail when retrieving start points. +// // // +// // // NOTE: `is_not_start_point` takes a neighbor, but only looks at the `ID` portion, +// // // so we can use a dummy neighbor. +// // if self.provider.is_not_start_point()(&Neighbor::new(id, 0.0)) { +// // self.get_count -= 1; +// // if self.get_count == 0 { +// // self.get_count = self.fail_every; +// // return std::future::ready(Err(TestError::transient())); +// // } +// // } +// +// // // SAFETY: We've decided to live with UB (undefined behavior) that can result from +// // // potentially mixing unsynchronized reads and writes on the underlying memory. +// // std::future::ready(Ok(unsafe { +// // self.provider.base_vectors.get_vector_sync(id.into_usize()) +// // })) +// // } +// } +// +// impl<'a> BuildDistanceComputer for FlakyAccessor<'a> { +// type DistanceComputerError = as BuildDistanceComputer>::DistanceComputerError; +// type DistanceComputer = as BuildDistanceComputer>::DistanceComputer; +// +// fn build_distance_computer( +// &self, +// ) -> Result { +// self.as_full().build_distance_computer() +// } +// } +// +// impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { +// type QueryComputerError = +// as BuildQueryComputer<&'a [f32]>>::QueryComputerError; +// type QueryComputer = as BuildQueryComputer<&'a [f32]>>::QueryComputer; +// +// fn build_query_computer( +// &self, +// from: &'a [f32], +// ) -> Result { +// self.as_full().build_query_computer(from) +// } +// } +// +// impl ExpandBeam<&[f32]> for FlakyAccessor<'_> {} +// +// impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { +// type Delegate = &'a SimpleNeighborProviderAsync; +// fn delegate_neighbor(&'a mut self) -> Self::Delegate { +// self.provider.neighbors() +// } +// } +// +// impl<'x> SearchStrategy for Flaky { +// type QueryComputer = as BuildQueryComputer<&'x 
[f32]>>::QueryComputer; +// type SearchAccessor<'a> = FlakyAccessor<'a>; +// type SearchAccessorError = ANNError; +// +// fn search_accessor<'a>( +// &'a self, +// provider: &'a TestProvider, +// _context: &'a DefaultContext, +// ) -> Result, Self::SearchAccessorError> { +// Ok(FlakyAccessor::new( +// provider, +// self.fail_every, +// self.fail_every, +// )) +// } +// } +// +// impl DefaultPostProcessor for Flaky { +// default_post_processor!(CopyIds); +// } +// +// const STATIC_PRUNE_THRESHOLD: usize = 5; +// /// We need to tune the flakiness of the `Prune` accessor so that occasionally, the first +// /// item retrieved is a failure. +// static START_COUNT: Mutex = Mutex::new(STATIC_PRUNE_THRESHOLD); +// +// type WorkingSet = map::Map, map::Ref<[f32]>>; +// +// impl PruneStrategy for Flaky { +// type DistanceComputer<'a> = as BuildDistanceComputer>::DistanceComputer; +// type PruneAccessor<'a> = FlakyAccessor<'a>; +// type PruneAccessorError = diskann::error::Infallible; +// type WorkingSet = WorkingSet; +// +// fn prune_accessor<'a>( +// &'a self, +// provider: &'a TestProvider, +// _context: &'a DefaultContext, +// ) -> Result, Self::PruneAccessorError> { +// let mut guard = START_COUNT.lock().unwrap(); +// let start = *guard; +// *guard -= 1; +// if *guard == 0 { +// *guard = STATIC_PRUNE_THRESHOLD; +// } +// +// Ok(FlakyAccessor::new(provider, STATIC_PRUNE_THRESHOLD, start)) +// } +// +// fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { +// map::Builder::new(map::Capacity::Default).build(capacity) +// } +// } +// +// impl InsertStrategy for Flaky { +// type PruneStrategy = Self; +// +// fn prune_strategy(&self) -> Self { +// *self +// } +// } +// +// impl MultiInsertStrategy> for Flaky { +// type Seed = map::Builder>; +// type WorkingSet = WorkingSet; +// type FinishError = diskann::error::Infallible; +// type InsertStrategy = Self; +// +// fn insert_strategy(&self) -> Self::InsertStrategy { +// *self +// } +// +// fn finish( +// &self, +// 
_provider: &TestProvider, +// _ctx: &DefaultContext, +// batch: &Arc>, +// ids: Itr, +// ) -> impl std::future::Future> + Send +// where +// Itr: ExactSizeIterator + Send, +// { +// std::future::ready(Ok(map::Builder::new(map::Capacity::Default) +// .with_overlay(map::Overlay::from_batch(batch.clone(), ids)))) +// } +// } diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index f0ffc0451..7ebf6b9a1 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -24,12 +24,10 @@ use diskann::{ index::QueryLabelProvider, }, neighbor::Neighbor, - provider::{Accessor, AsNeighbor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, utils::VectorId, }; -use diskann_utils::Reborrow; use diskann_vector::PreprocessedDistanceFunction; -use futures_util::FutureExt; /// A [`SearchStrategy`] type that composes the inner distance computer with beta filtering. /// @@ -158,39 +156,6 @@ where } } -///////////// -// Helpers // -///////////// - -/// The `Element` and `ElementRef` types used by the [`BetaAccessor`]. -#[derive(Debug, Clone, PartialEq)] -pub struct Pair { - id: I, - element: E, -} - -impl Pair { - fn new(id: I, element: E) -> Self { - Self { id, element } - } -} - -/// `Reborrow` is implemented in terms of a full `Reborrow` of `E` while leaving the id -/// untouched. -impl<'a, I, E> Reborrow<'a> for Pair -where - E: Reborrow<'a>, - I: Copy, -{ - type Target = Pair; - fn reborrow(&'a self) -> Self::Target { - Pair { - id: self.id, - element: self.element.reborrow(), - } - } -} - /// An [`Accessor`] that composes with an `Inner` accessor to provide beta-filtering. 
pub struct BetaAccessor where @@ -201,13 +166,35 @@ where beta: f32, } -impl SearchExt for BetaAccessor +/// The `Element` and `ElementRef` types used by the [`BetaAccessor`]. +#[derive(Debug, Clone, PartialEq)] +pub struct Pair { + id: I, + element: E, +} + +impl SearchExt for BetaAccessor where - Inner: SearchExt, + Inner: SearchExt, { fn starting_points(&self) -> impl Future>> + Send { self.inner.starting_points() } + + fn start_point_distances( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send { + self.inner.start_point_distances( + computer.inner(), + move |id, distance| { + f(id, computer.apply(id, distance)); + } + ) + } } impl<'a, Inner> DelegateNeighbor<'a> for BetaAccessor @@ -231,50 +218,7 @@ impl Accessor for BetaAccessor where Inner: Accessor, { - /// Modify `Element` to retain the vector ID. - type Element<'a> - = Pair> - where - Self: 'a; type ElementRef<'a> = Pair>; - - /// Use the same error type as `Inner`. - type GetError = Inner::GetError; - - /// Invoke `get_element` on the inner accessor and return a tuple consisting of the - /// retrieved element and `id`. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // The first `map` applies to `Future`. - // The second `map` applies to the `Result`. - self.inner - .get_element(id) - .map(move |result| result.map(move |v| Pair::new(id, v))) - } - - /// Method `on_elements_unordered` is implemented by invoking - /// `inner.on_elements_unordered` with a decorated version of `f`. 
- async fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> Result<(), Self::GetError> - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), - { - self.inner - .on_elements_unordered( - itr, - #[inline] - move |element, id| f(Pair::new(id, element), id), - ) - .await - } } impl BuildQueryComputer for BetaAccessor @@ -296,7 +240,28 @@ where } } -impl ExpandBeam for BetaAccessor where Inner: BuildQueryComputer + AsNeighbor {} +impl ExpandBeam for BetaAccessor +where + Inner: ExpandBeam, +{ + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + self.inner + .expand_beam(ids, computer.inner(), pred, move |distance, id| { + on_neighbors(computer.apply(id, distance), id) + }) + } +} /// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. pub struct BetaComputer { @@ -322,6 +287,15 @@ where pub fn inner(&self) -> &Inner { &self.inner } + + /// Apply the beta-filtering heuristic. + pub fn apply(&self, id: I, distance: f32) -> f32 { + if self.labels.is_match(id) { + distance * self.beta + } else { + distance + } + } } impl PreprocessedDistanceFunction, f32> for BetaComputer @@ -334,281 +308,274 @@ where /// If so, multiply the distance computed by `Inner` by `beta`. #[inline(always)] fn evaluate_similarity(&self, x: Pair) -> f32 { - // Inner distance computation. - let distance = self.inner.evaluate_similarity(x.element); - // Check beta. 
- if self.labels.is_match(x.id) { - distance * self.beta - } else { - distance - } + self.apply(x.id, self.inner.evaluate_similarity(x.element)) } } -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use diskann::{ - ANNError, ANNResult, always_escalate, - graph::AdjacencyList, - graph::glue::CopyIds, - provider::{DefaultContext, NeighborAccessor, NoopGuard}, - }; - use futures_util::future; - use thiserror::Error; - - use super::*; - - /// A very simple data provider. - struct SimpleProvider; - impl DataProvider for SimpleProvider { - type Context = DefaultContext; - type InternalId = u32; - type ExternalId = u64; - type Guard = NoopGuard; - - type Error = ANNError; - - fn to_internal_id(&self, _context: &DefaultContext, gid: &u64) -> ANNResult { - Ok((*gid).try_into()?) - } - - fn to_external_id(&self, _context: &DefaultContext, id: u32) -> ANNResult { - Ok(id.into()) - } - } - - /// An `Accessor` that doubles its input ID as its output element. - /// - /// This also tracks the number of calls made to `get_element` and - /// `on_elements_unordered` to ensure that `BetaFilter` correctly forwards these methods. - #[derive(Debug, Default, Clone, Copy)] - struct Doubler { - get_element: usize, - on_elements_unordered: usize, - } - - impl SearchExt for Doubler { - async fn starting_points(&self) -> ANNResult> { - Ok(vec![0]) - } - } - - impl Doubler { - fn reset(&mut self) { - *self = Self::default(); - } - } - - /// A simple error type to test error forwarding. 
- #[derive(Debug, Error)] - #[error("the value {0} is not allowed")] - pub struct NotAllowed(u32); - - impl From for ANNError { - #[inline(never)] - fn from(value: NotAllowed) -> Self { - ANNError::log_async_error(value) - } - } - - impl HasId for Doubler { - type Id = u32; - } - - impl NeighborAccessor for Doubler { - fn get_neighbors( - self, - _id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - neighbors.clear(); - future::ok(self) - } - } - - always_escalate!(NotAllowed); - - impl Accessor for Doubler { - type Element<'a> - = u64 - where - Self: 'a; - type ElementRef<'a> = u64; - - type GetError = NotAllowed; - - fn get_element( - &mut self, - id: u32, - ) -> impl std::future::Future, Self::GetError>> + Send - { - self.get_element += 1; - let is_err = (100..200).contains(&id); - - async move { - if is_err { - Err(NotAllowed(id)) - } else { - let id: u64 = id.into(); - Ok(2 * id) - } - } - } - - async fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> Result<(), Self::GetError> - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), - { - self.on_elements_unordered += 1; - for i in itr { - f(self.get_element(i).await?, i); - } - Ok(()) - } - } - - struct AddingComputer(u64); - impl PreprocessedDistanceFunction for AddingComputer { - fn evaluate_similarity(&self, x: u64) -> f32 { - (self.0 + x) as f32 - } - } - - impl BuildQueryComputer for Doubler { - type QueryComputer = AddingComputer; - type QueryComputerError = ANNError; - - fn build_query_computer( - &self, - from: u64, - ) -> Result { - Ok(AddingComputer(from)) - } - } - - impl ExpandBeam for Doubler {} - - #[derive(Debug)] - struct SimpleStrategy; - - impl SearchStrategy for SimpleStrategy { - type SearchAccessor<'a> = Doubler; - type QueryComputer = AddingComputer; - type SearchAccessorError = ANNError; - - fn search_accessor<'a>( - &'a self, - _provider: &'a SimpleProvider, - _context: &'a DefaultContext, - ) 
-> Result, Self::SearchAccessorError> { - Ok(Doubler::default()) - } - } - - impl glue::DefaultPostProcessor for SimpleStrategy { - diskann::default_post_processor!(CopyIds); - } - - /// A simple `QueryLabelProvider` that matches multiples of 3. - #[derive(Debug)] - struct ThreeFilter; - - impl QueryLabelProvider for ThreeFilter { - fn is_match(&self, id: u32) -> bool { - id.is_multiple_of(3) - } - } - - #[tokio::test] - async fn test_beta_filter() { - let provider = SimpleProvider; - let context = &DefaultContext; - let beta: f32 = 0.25; - - let strategy = BetaFilter::new(SimpleStrategy, Arc::new(ThreeFilter), beta); - - let mut accessor: BetaAccessor<_> = strategy.search_accessor(&provider, context).unwrap(); - assert_eq!(accessor.inner.get_element, 0); - assert_eq!(accessor.inner.on_elements_unordered, 0); - - // Test non-erroring path. - let v = accessor.get_element(1).await.unwrap(); - assert_eq!(v, Pair::new(1, 2)); - - let v = accessor.get_element(2).await.unwrap(); - assert_eq!(v, Pair::new(2, 4)); - - // Test erroring path. - assert!(accessor.get_element(100).await.is_err()); - assert!(accessor.get_element(101).await.is_err()); - - assert_eq!(accessor.inner.get_element, 4); - assert_eq!(accessor.inner.on_elements_unordered, 0); - accessor.inner.reset(); - - // On elements unordered. - { - let mut v = Vec::new(); - accessor - .on_elements_unordered([1, 2, 3, 4, 5].into_iter(), |element, id| { - v.push((element, id)); - }) - .await - .unwrap(); - - assert_eq!(accessor.inner.get_element, 5); - assert_eq!(accessor.inner.on_elements_unordered, 1); - assert_eq!( - v, - &[ - (Pair::new(1, 2), 1), - (Pair::new(2, 4), 2), - (Pair::new(3, 6), 3), - (Pair::new(4, 8), 4), - (Pair::new(5, 10), 5) - ] - ); - accessor.inner.reset(); - } - - // On-elements-unordered propagates errors. - assert!( - accessor - .on_elements_unordered([1, 2, 3, 100, 4].into_iter(), |_, _| {}) - .await - .is_err() - ); - - // Computation. 
- let query = 10; - let computer = accessor.build_query_computer(query).unwrap(); - - assert_eq!( - computer.evaluate_similarity(accessor.get_element(10).await.unwrap()), - (10 * 2 + query) as f32 - ); - assert_eq!( - computer.evaluate_similarity(accessor.get_element(11).await.unwrap()), - (11 * 2 + query) as f32 - ); - assert_eq!( - computer.evaluate_similarity(accessor.get_element(12).await.unwrap()), - beta * ((12 * 2 + query) as f32) - ); - - // Test dummy implementation of `get_neighbors` for code coverage. - let mut neighbors = AdjacencyList::new(); - accessor.get_neighbors(0, &mut neighbors).await.unwrap(); - assert_eq!(neighbors.len(), 0); - } -} +// /////////// +// // Tests // +// /////////// +// +// #[cfg(test)] +// mod tests { +// use diskann::{ +// ANNError, ANNResult, always_escalate, +// graph::AdjacencyList, +// graph::glue::CopyIds, +// provider::{DefaultContext, NeighborAccessor, NoopGuard}, +// }; +// use futures_util::future; +// use thiserror::Error; +// +// use super::*; +// +// /// A very simple data provider. +// struct SimpleProvider; +// impl DataProvider for SimpleProvider { +// type Context = DefaultContext; +// type InternalId = u32; +// type ExternalId = u64; +// type Guard = NoopGuard; +// +// type Error = ANNError; +// +// fn to_internal_id(&self, _context: &DefaultContext, gid: &u64) -> ANNResult { +// Ok((*gid).try_into()?) +// } +// +// fn to_external_id(&self, _context: &DefaultContext, id: u32) -> ANNResult { +// Ok(id.into()) +// } +// } +// +// /// An `Accessor` that doubles its input ID as its output element. +// /// +// /// This also tracks the number of calls made to `get_element` and +// /// `on_elements_unordered` to ensure that `BetaFilter` correctly forwards these methods. 
+// #[derive(Debug, Default, Clone, Copy)] +// struct Doubler { +// get_element: usize, +// on_elements_unordered: usize, +// } +// +// impl SearchExt for Doubler { +// async fn starting_points(&self) -> ANNResult> { +// Ok(vec![0]) +// } +// } +// +// impl Doubler { +// fn reset(&mut self) { +// *self = Self::default(); +// } +// } +// +// /// A simple error type to test error forwarding. +// #[derive(Debug, Error)] +// #[error("the value {0} is not allowed")] +// pub struct NotAllowed(u32); +// +// impl From for ANNError { +// #[inline(never)] +// fn from(value: NotAllowed) -> Self { +// ANNError::log_async_error(value) +// } +// } +// +// impl HasId for Doubler { +// type Id = u32; +// } +// +// impl NeighborAccessor for Doubler { +// fn get_neighbors( +// self, +// _id: Self::Id, +// neighbors: &mut AdjacencyList, +// ) -> impl Future> + Send { +// neighbors.clear(); +// future::ok(self) +// } +// } +// +// always_escalate!(NotAllowed); +// +// impl Accessor for Doubler { +// // type Element<'a> +// // = u64 +// // where +// // Self: 'a; +// type ElementRef<'a> = u64; +// +// // type GetError = NotAllowed; +// +// // fn get_element( +// // &mut self, +// // id: u32, +// // ) -> impl std::future::Future, Self::GetError>> + Send +// // { +// // self.get_element += 1; +// // let is_err = (100..200).contains(&id); +// +// // async move { +// // if is_err { +// // Err(NotAllowed(id)) +// // } else { +// // let id: u64 = id.into(); +// // Ok(2 * id) +// // } +// // } +// // } +// +// async fn on_elements_unordered( +// &mut self, +// itr: Itr, +// mut f: F, +// ) -> Result<(), Self::GetError> +// where +// Self: Sync, +// Itr: Iterator + Send, +// F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), +// { +// self.on_elements_unordered += 1; +// for i in itr { +// f(self.get_element(i).await?, i); +// } +// Ok(()) +// } +// } +// +// struct AddingComputer(u64); +// impl PreprocessedDistanceFunction for AddingComputer { +// fn evaluate_similarity(&self, x: u64) -> 
f32 { +// (self.0 + x) as f32 +// } +// } +// +// impl BuildQueryComputer for Doubler { +// type QueryComputer = AddingComputer; +// type QueryComputerError = ANNError; +// +// fn build_query_computer( +// &self, +// from: u64, +// ) -> Result { +// Ok(AddingComputer(from)) +// } +// } +// +// impl ExpandBeam for Doubler {} +// +// #[derive(Debug)] +// struct SimpleStrategy; +// +// impl SearchStrategy for SimpleStrategy { +// type SearchAccessor<'a> = Doubler; +// type QueryComputer = AddingComputer; +// type SearchAccessorError = ANNError; +// +// fn search_accessor<'a>( +// &'a self, +// _provider: &'a SimpleProvider, +// _context: &'a DefaultContext, +// ) -> Result, Self::SearchAccessorError> { +// Ok(Doubler::default()) +// } +// } +// +// impl glue::DefaultPostProcessor for SimpleStrategy { +// diskann::default_post_processor!(CopyIds); +// } +// +// /// A simple `QueryLabelProvider` that matches multiples of 3. +// #[derive(Debug)] +// struct ThreeFilter; +// +// impl QueryLabelProvider for ThreeFilter { +// fn is_match(&self, id: u32) -> bool { +// id.is_multiple_of(3) +// } +// } +// +// #[tokio::test] +// async fn test_beta_filter() { +// let provider = SimpleProvider; +// let context = &DefaultContext; +// let beta: f32 = 0.25; +// +// let strategy = BetaFilter::new(SimpleStrategy, Arc::new(ThreeFilter), beta); +// +// let mut accessor: BetaAccessor<_> = strategy.search_accessor(&provider, context).unwrap(); +// assert_eq!(accessor.inner.get_element, 0); +// assert_eq!(accessor.inner.on_elements_unordered, 0); +// +// // Test non-erroring path. +// let v = accessor.get_element(1).await.unwrap(); +// assert_eq!(v, Pair::new(1, 2)); +// +// let v = accessor.get_element(2).await.unwrap(); +// assert_eq!(v, Pair::new(2, 4)); +// +// // Test erroring path. 
+// assert!(accessor.get_element(100).await.is_err()); +// assert!(accessor.get_element(101).await.is_err()); +// +// assert_eq!(accessor.inner.get_element, 4); +// assert_eq!(accessor.inner.on_elements_unordered, 0); +// accessor.inner.reset(); +// +// // On elements unordered. +// { +// let mut v = Vec::new(); +// accessor +// .on_elements_unordered([1, 2, 3, 4, 5].into_iter(), |element, id| { +// v.push((element, id)); +// }) +// .await +// .unwrap(); +// +// assert_eq!(accessor.inner.get_element, 5); +// assert_eq!(accessor.inner.on_elements_unordered, 1); +// assert_eq!( +// v, +// &[ +// (Pair::new(1, 2), 1), +// (Pair::new(2, 4), 2), +// (Pair::new(3, 6), 3), +// (Pair::new(4, 8), 4), +// (Pair::new(5, 10), 5) +// ] +// ); +// accessor.inner.reset(); +// } +// +// // On-elements-unordered propagates errors. +// assert!( +// accessor +// .on_elements_unordered([1, 2, 3, 100, 4].into_iter(), |_, _| {}) +// .await +// .is_err() +// ); +// +// // Computation. +// let query = 10; +// let computer = accessor.build_query_computer(query).unwrap(); +// +// assert_eq!( +// computer.evaluate_similarity(accessor.get_element(10).await.unwrap()), +// (10 * 2 + query) as f32 +// ); +// assert_eq!( +// computer.evaluate_similarity(accessor.get_element(11).await.unwrap()), +// (11 * 2 + query) as f32 +// ); +// assert_eq!( +// computer.evaluate_similarity(accessor.get_element(12).await.unwrap()), +// beta * ((12 * 2 + query) as f32) +// ); +// +// // Test dummy implementation of `get_neighbors` for code coverage. 
+// let mut neighbors = AdjacencyList::new(); +// accessor.get_neighbors(0, &mut neighbors).await.unwrap(); +// assert_eq!(neighbors.len(), 0); +// } +// } diff --git a/diskann-providers/src/storage/index_storage.rs b/diskann-providers/src/storage/index_storage.rs index 21a1d9333..2e76beff3 100644 --- a/diskann-providers/src/storage/index_storage.rs +++ b/diskann-providers/src/storage/index_storage.rs @@ -346,12 +346,12 @@ mod tests { .unwrap(); assert_eq!(id_iter, reloaded.data_provider.iter()); - check_accessor_equal( - inmem::FullAccessor::new(index.provider()), - inmem::FullAccessor::new(reloaded.provider()), - id_iter.clone(), - ) - .await; + // check_accessor_equal( + // inmem::FullAccessor::new(index.provider()), + // inmem::FullAccessor::new(reloaded.provider()), + // id_iter.clone(), + // ) + // .await; check_graphs_equal( &index.provider().neighbor_provider, @@ -367,12 +367,12 @@ mod tests { .unwrap(); assert_eq!(id_iter, reloaded.data_provider.iter()); - check_accessor_equal( - inmem::FullAccessor::new(index.provider()), - inmem::FullAccessor::new(reloaded.provider()), - id_iter.clone(), - ) - .await; + // check_accessor_equal( + // inmem::FullAccessor::new(index.provider()), + // inmem::FullAccessor::new(reloaded.provider()), + // id_iter.clone(), + // ) + // .await; check_graphs_equal( &index.provider().neighbor_provider, @@ -389,19 +389,19 @@ mod tests { .unwrap(); assert_eq!(id_iter, reloaded.data_provider.iter()); - check_accessor_equal( - inmem::FullAccessor::new(index.provider()), - inmem::FullAccessor::new(reloaded.provider()), - index.data_provider.iter(), - ) - .await; - - check_accessor_equal( - inmem::product::QuantAccessor::new(index.provider()), - inmem::product::QuantAccessor::new(reloaded.provider()), - index.data_provider.iter(), - ) - .await; + // check_accessor_equal( + // inmem::FullAccessor::new(index.provider()), + // inmem::FullAccessor::new(reloaded.provider()), + // index.data_provider.iter(), + // ) + // .await; + + // 
check_accessor_equal( + // inmem::product::QuantAccessor::new(index.provider()), + // inmem::product::QuantAccessor::new(reloaded.provider()), + // index.data_provider.iter(), + // ) + // .await; check_graphs_equal( &index.provider().neighbor_provider, @@ -417,19 +417,19 @@ mod tests { .unwrap(); assert_eq!(id_iter, reloaded.data_provider.iter()); - check_accessor_equal( - inmem::FullAccessor::new(index.provider()), - inmem::FullAccessor::new(reloaded.provider()), - index.data_provider.iter(), - ) - .await; - - check_accessor_equal( - inmem::product::QuantAccessor::new(index.provider()), - inmem::product::QuantAccessor::new(reloaded.provider()), - index.data_provider.iter(), - ) - .await; + // check_accessor_equal( + // inmem::FullAccessor::new(index.provider()), + // inmem::FullAccessor::new(reloaded.provider()), + // index.data_provider.iter(), + // ) + // .await; + + // check_accessor_equal( + // inmem::product::QuantAccessor::new(index.provider()), + // inmem::product::QuantAccessor::new(reloaded.provider()), + // index.data_provider.iter(), + // ) + // .await; check_graphs_equal( &index.provider().neighbor_provider, @@ -439,22 +439,22 @@ mod tests { } } - async fn check_accessor_equal(mut left: A, mut right: B, itr: Itr) - where - A: for<'a> Accessor = &'a T>, - B: for<'a> Accessor = &'a T>, - T: PartialEq + std::fmt::Debug + ?Sized, - Itr: Iterator, - { - for i in itr { - assert_eq!( - left.get_element(i).await.unwrap().reborrow(), - right.get_element(i).await.unwrap().reborrow(), - "failed for index {}", - i - ); - } - } + // async fn check_accessor_equal(mut left: A, mut right: B, itr: Itr) + // where + // A: for<'a> Accessor = &'a T>, + // B: for<'a> Accessor = &'a T>, + // T: PartialEq + std::fmt::Debug + ?Sized, + // Itr: Iterator, + // { + // for i in itr { + // assert_eq!( + // left.get_element(i).await.unwrap().reborrow(), + // right.get_element(i).await.unwrap().reborrow(), + // "failed for index {}", + // i + // ); + // } + // } fn 
check_graphs_equal( left: &SimpleNeighborProviderAsync, diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index cb098b85c..0bc4e7c1c 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -79,28 +79,35 @@ use std::{future::Future, sync::Arc}; use diskann_utils::Reborrow; -use diskann_utils::future::AssertSend; use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction}; use crate::{ ANNError, ANNResult, - error::{ErrorExt, StandardError}, - graph::{AdjacencyList, SearchOutputBuffer, workingset}, + error::{StandardError}, + graph::{SearchOutputBuffer, workingset}, neighbor::Neighbor, provider::{ Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, - DataProvider, HasId, NeighborAccessor, + DataProvider, HasId, }, utils::VectorId, }; /// A trait to override search constraints such as early termination based on constraints /// by implementer. -pub trait SearchExt: Accessor { +pub trait SearchExt: BuildQueryComputer { /// Return a `Vec` containing the starting points. fn starting_points(&self) -> impl std::future::Future>> + Send; + fn start_point_distances( + &mut self, + computer: &Self::QueryComputer, + f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send; + /// Default is to never terminate early. fn terminate_early(&mut self) -> bool { false @@ -252,38 +259,18 @@ impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash:: /// ## Error Handling /// /// Transient errors yielded by `distances_unordered` are acknowledged and not escalated. 
-pub trait ExpandBeam: BuildQueryComputer + AsNeighbor + Sized { +pub trait ExpandBeam: BuildQueryComputer + AsNeighbor { fn expand_beam( &mut self, ids: Itr, computer: &Self::QueryComputer, - mut pred: P, - mut on_neighbors: F, + pred: P, + on_neighbors: F, ) -> impl std::future::Future> + Send where Itr: Iterator + Send, P: HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - async move { - let mut neighbors = AdjacencyList::new(); - for id in ids { - self.get_neighbors(id, &mut neighbors).send().await?; - neighbors.retain(|i| pred.eval(i)); - - self.distances_unordered(neighbors.iter().copied(), computer, |distance, id| { - if pred.eval_mut(&id) { - on_neighbors(distance, id); - } - }) - .send() - .await - .allow_transient("allowing transient error in beam expansion")?; - } - - Ok(()) - } - } + F: FnMut(f32, Self::Id) + Send; } /// A search strategy for query objects of type `T`. @@ -312,7 +299,7 @@ where /// graph search. The query will be provided to the accessor exactly once during search /// to construct the query computer. type SearchAccessor<'a>: ExpandBeam - + SearchExt; + + SearchExt; /// Construct and return the search accessor. fn search_accessor<'a>( @@ -471,7 +458,7 @@ pub struct FilterStartPoints; impl SearchPostProcessStep for FilterStartPoints where - A: BuildQueryComputer + SearchExt, + A: BuildQueryComputer + SearchExt, T: Copy + Send + Sync, { /// A this level, sub-errors are converted into [`ANNError`] to provide additional @@ -787,7 +774,7 @@ where /// /// Lifting the accessor all the way to the trait level makes the caching provider possible. type DeleteSearchAccessor<'a>: ExpandBeam, Id = Provider::InternalId> - + SearchExt; + + SearchExt>; /// The processor used during the delete-search phase. 
type SearchPostProcessor: for<'a> SearchPostProcess, Self::DeleteElement<'a>> @@ -841,237 +828,237 @@ where // Tests // /////////// -#[cfg(test)] -mod tests { - use std::sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, - }; - - use diskann_vector::PreprocessedDistanceFunction; - use futures_util::future; - - use super::*; - use crate::{ - ANNResult, neighbor, - provider::{DelegateNeighbor, ExecutionContext, HasId, NeighborAccessor}, - }; - - // A really simple provider that just holds floats and uses the absolute value for its - // distances. - struct SimpleProvider { - items: Vec, - } - - #[derive(Default, Clone)] - struct CountGetVector { - count: Arc, - } - impl ExecutionContext for CountGetVector {} - - impl CountGetVector { - fn count(&self) -> usize { - self.count.load(Ordering::Relaxed) - } - - fn clear(&self) { - self.count.store(0, Ordering::Relaxed) - } - } - - impl DataProvider for SimpleProvider { - type Context = CountGetVector; - type InternalId = u32; - type ExternalId = u32; - type Error = ANNError; - type Guard = crate::provider::NoopGuard; - - /// Translate an external id to its corresponding internal id. - fn to_internal_id( - &self, - _context: &CountGetVector, - gid: &Self::ExternalId, - ) -> Result { - Ok(*gid) - } - - /// Translate an internal id to its corresponding external id. 
- fn to_external_id( - &self, - _context: &CountGetVector, - id: Self::InternalId, - ) -> Result { - Ok(id) - } - } - - #[derive(Clone, Copy)] - struct Retriever<'a> { - provider: &'a SimpleProvider, - count: &'a CountGetVector, - } - - impl SearchExt for Retriever<'_> { - async fn starting_points(&self) -> ANNResult> { - Ok(vec![0]) - } - } - - impl<'a> Retriever<'a> { - fn new(provider: &'a SimpleProvider, count: &'a CountGetVector) -> Self { - Self { provider, count } - } - } - - impl HasId for Retriever<'_> { - type Id = u32; - } - - impl Accessor for Retriever<'_> { - type Element<'a> - = f32 - where - Self: 'a; - type ElementRef<'a> = f32; - - type GetError = ANNError; - fn get_element( - &mut self, - id: Self::Id, - ) -> impl std::future::Future, Self::GetError>> + Send - { - let result = match self.provider.items.get(id as usize) { - Some(v) => { - self.count.count.fetch_add(1, Ordering::Relaxed); - Ok(*v) - } - None => panic!("invalid id: {}", id), - }; - async move { result } - } - } - - impl NeighborAccessor for Retriever<'_> { - fn get_neighbors( - self, - _id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - neighbors.clear(); - future::ok(self) - } - } - - struct QueryComputer; - impl PreprocessedDistanceFunction for QueryComputer { - fn evaluate_similarity(&self, _changing: f32) -> f32 { - panic!("this method should not be called") - } - } - - impl BuildQueryComputer for Retriever<'_> { - type QueryComputerError = ANNError; - type QueryComputer = QueryComputer; - fn build_query_computer(&self, _from: f32) -> Result { - Ok(QueryComputer) - } - } - - impl ExpandBeam for Retriever<'_> {} - - // This strategy explicitly does not define `post_process` so we can test the provided - // implementation. 
- struct Strategy; - - impl SearchStrategy for Strategy { - type QueryComputer = QueryComputer; - type SearchAccessorError = ANNError; - type SearchAccessor<'a> = Retriever<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a SimpleProvider, - context: &'a CountGetVector, - ) -> Result, Self::SearchAccessorError> { - Ok(Retriever::new(provider, context)) - } - } - - impl DefaultPostProcessor for Strategy { - default_post_processor!(CopyIds); - } - - #[tokio::test(flavor = "current_thread")] - async fn test_default_post_process() { - let ctx = CountGetVector::default(); - let strategy = Strategy; - - let num_points: usize = 100; - let provider = SimpleProvider { - items: (0..num_points).map(|i| i as f32).collect(), - }; - - assert_eq!(provider.to_internal_id(&ctx, &10).unwrap(), 10); - assert_eq!(provider.to_external_id(&ctx, 10).unwrap(), 10); - - let mut accessor = strategy.search_accessor(&provider, &ctx).unwrap(); - assert_eq!(accessor.starting_points().await.unwrap().as_slice(), &[0]); - for i in 0..num_points { - assert_eq!(accessor.get_element(i as u32).await.unwrap(), i as f32); - } - - // Check dummy get_neighbors implmeentation for code coverage - let mut neighbors = AdjacencyList::new(); - accessor - .delegate_neighbor() - .get_neighbors(0, &mut neighbors) - .await - .unwrap(); - assert_eq!(neighbors.len(), 0); - - // Check that the correct number of reads were emitted. 
- assert_eq!(ctx.count(), num_points); - ctx.clear(); - - let query = 11.5; - let computer = accessor.build_query_computer(query).unwrap(); - - for input_len in 0..10 { - let input: Vec<_> = (0..input_len) - .map(|i| Neighbor::::new(i as u32, i as f32)) - .collect(); - for output_len in 0..10 { - let mut output = vec![Neighbor::::default(); output_len]; - - let count = strategy - .default_post_processor() - .post_process( - &mut accessor, - query, - &computer, - input.iter().copied(), - &mut neighbor::BackInserter::new(output.as_mut_slice()), - ) - .await - .unwrap(); - - assert_eq!(count, input_len.min(output_len)); - - // Check that the in-range values were properly copied. - for (i, n) in output.iter().take(count).enumerate() { - assert_eq!(i, n.id as usize); - assert_eq!(i as f32, n.distance); - } - - // Check that out-of-range values were untouched. - for n in output.iter().skip(count) { - assert_eq!(n.id, 0); - assert_eq!(n.distance, 0.0); - } - } - } - - // Ensure that no reads were emitted. - assert_eq!(ctx.count(), 0); - } -} +// #[cfg(test)] +// mod tests { +// use std::sync::{ +// Arc, +// atomic::{AtomicUsize, Ordering}, +// }; +// +// use diskann_vector::PreprocessedDistanceFunction; +// use futures_util::future; +// +// use super::*; +// use crate::{ +// ANNResult, neighbor, +// provider::{DelegateNeighbor, ExecutionContext, HasId, NeighborAccessor}, +// }; +// +// // A really simple provider that just holds floats and uses the absolute value for its +// // distances. 
+// struct SimpleProvider { +// items: Vec, +// } +// +// #[derive(Default, Clone)] +// struct CountGetVector { +// count: Arc, +// } +// impl ExecutionContext for CountGetVector {} +// +// impl CountGetVector { +// fn count(&self) -> usize { +// self.count.load(Ordering::Relaxed) +// } +// +// fn clear(&self) { +// self.count.store(0, Ordering::Relaxed) +// } +// } +// +// impl DataProvider for SimpleProvider { +// type Context = CountGetVector; +// type InternalId = u32; +// type ExternalId = u32; +// type Error = ANNError; +// type Guard = crate::provider::NoopGuard; +// +// /// Translate an external id to its corresponding internal id. +// fn to_internal_id( +// &self, +// _context: &CountGetVector, +// gid: &Self::ExternalId, +// ) -> Result { +// Ok(*gid) +// } +// +// /// Translate an internal id to its corresponding external id. +// fn to_external_id( +// &self, +// _context: &CountGetVector, +// id: Self::InternalId, +// ) -> Result { +// Ok(id) +// } +// } +// +// #[derive(Clone, Copy)] +// struct Retriever<'a> { +// provider: &'a SimpleProvider, +// count: &'a CountGetVector, +// } +// +// impl SearchExt for Retriever<'_> { +// async fn starting_points(&self) -> ANNResult> { +// Ok(vec![0]) +// } +// } +// +// impl<'a> Retriever<'a> { +// fn new(provider: &'a SimpleProvider, count: &'a CountGetVector) -> Self { +// Self { provider, count } +// } +// } +// +// impl HasId for Retriever<'_> { +// type Id = u32; +// } +// +// impl Accessor for Retriever<'_> { +// // type Element<'a> +// // = f32 +// // where +// // Self: 'a; +// type ElementRef<'a> = f32; +// +// // type GetError = ANNError; +// // fn get_element( +// // &mut self, +// // id: Self::Id, +// // ) -> impl std::future::Future, Self::GetError>> + Send +// // { +// // let result = match self.provider.items.get(id as usize) { +// // Some(v) => { +// // self.count.count.fetch_add(1, Ordering::Relaxed); +// // Ok(*v) +// // } +// // None => panic!("invalid id: {}", id), +// // }; +// // async move { 
result } +// // } +// } +// +// impl NeighborAccessor for Retriever<'_> { +// fn get_neighbors( +// self, +// _id: Self::Id, +// neighbors: &mut AdjacencyList, +// ) -> impl Future> + Send { +// neighbors.clear(); +// future::ok(self) +// } +// } +// +// struct QueryComputer; +// impl PreprocessedDistanceFunction for QueryComputer { +// fn evaluate_similarity(&self, _changing: f32) -> f32 { +// panic!("this method should not be called") +// } +// } +// +// impl BuildQueryComputer for Retriever<'_> { +// type QueryComputerError = ANNError; +// type QueryComputer = QueryComputer; +// fn build_query_computer(&self, _from: f32) -> Result { +// Ok(QueryComputer) +// } +// } +// +// impl ExpandBeam for Retriever<'_> {} +// +// // This strategy explicitly does not define `post_process` so we can test the provided +// // implementation. +// struct Strategy; +// +// impl SearchStrategy for Strategy { +// type QueryComputer = QueryComputer; +// type SearchAccessorError = ANNError; +// type SearchAccessor<'a> = Retriever<'a>; +// +// fn search_accessor<'a>( +// &'a self, +// provider: &'a SimpleProvider, +// context: &'a CountGetVector, +// ) -> Result, Self::SearchAccessorError> { +// Ok(Retriever::new(provider, context)) +// } +// } +// +// impl DefaultPostProcessor for Strategy { +// default_post_processor!(CopyIds); +// } +// +// #[tokio::test(flavor = "current_thread")] +// async fn test_default_post_process() { +// let ctx = CountGetVector::default(); +// let strategy = Strategy; +// +// let num_points: usize = 100; +// let provider = SimpleProvider { +// items: (0..num_points).map(|i| i as f32).collect(), +// }; +// +// assert_eq!(provider.to_internal_id(&ctx, &10).unwrap(), 10); +// assert_eq!(provider.to_external_id(&ctx, 10).unwrap(), 10); +// +// let mut accessor = strategy.search_accessor(&provider, &ctx).unwrap(); +// assert_eq!(accessor.starting_points().await.unwrap().as_slice(), &[0]); +// for i in 0..num_points { +// assert_eq!(accessor.get_element(i as 
u32).await.unwrap(), i as f32); +// } +// +// // Check dummy get_neighbors implmeentation for code coverage +// let mut neighbors = AdjacencyList::new(); +// accessor +// .delegate_neighbor() +// .get_neighbors(0, &mut neighbors) +// .await +// .unwrap(); +// assert_eq!(neighbors.len(), 0); +// +// // Check that the correct number of reads were emitted. +// assert_eq!(ctx.count(), num_points); +// ctx.clear(); +// +// let query = 11.5; +// let computer = accessor.build_query_computer(query).unwrap(); +// +// for input_len in 0..10 { +// let input: Vec<_> = (0..input_len) +// .map(|i| Neighbor::::new(i as u32, i as f32)) +// .collect(); +// for output_len in 0..10 { +// let mut output = vec![Neighbor::::default(); output_len]; +// +// let count = strategy +// .default_post_processor() +// .post_process( +// &mut accessor, +// query, +// &computer, +// input.iter().copied(), +// &mut neighbor::BackInserter::new(output.as_mut_slice()), +// ) +// .await +// .unwrap(); +// +// assert_eq!(count, input_len.min(output_len)); +// +// // Check that the in-range values were properly copied. +// for (i, n) in output.iter().take(count).enumerate() { +// assert_eq!(i, n.id as usize); +// assert_eq!(i as f32, n.distance); +// } +// +// // Check that out-of-range values were untouched. +// for n in output.iter().skip(count) { +// assert_eq!(n.id, 0); +// assert_eq!(n.distance, 0.0); +// } +// } +// } +// +// // Ensure that no reads were emitted. 
+// assert_eq!(ctx.count(), 0); +// } +// } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index d696135a0..596139394 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -16,7 +16,7 @@ use diskann_utils::{ Reborrow, future::{AssertSend, SendFuture, boxit}, }; -use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction}; +use diskann_vector::{DistanceFunction}; use futures_util::FutureExt; use hashbrown::HashSet; use thiserror::Error; @@ -25,12 +25,11 @@ use tokio::task::JoinSet; use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, Search, glue::{ - self, Batch, ExpandBeam, IdIterator, InplaceDeleteStrategy, InsertStrategy, + self, Batch, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ - Knn, record::{NoopSearchRecord, SearchRecord, VisitedSearchRecord}, scratch::{self, PriorityQueueConfiguration, SearchScratch, SearchScratchParams}, }, @@ -346,7 +345,6 @@ where self.search_internal( None, // beam_width - &start_ids, &mut accessor, &computer, &mut scratch, @@ -475,7 +473,6 @@ where self.search_internal( None, // beam_width - &start_ids, &mut accessor, &computer, &mut scratch, @@ -1291,7 +1288,6 @@ where self.search_internal( None, // beam_width - &start_ids, &mut search_accessor, &computer, &mut scratch, @@ -2018,35 +2014,44 @@ where pub(crate) fn search_internal( &self, beam_width: Option, - start_ids: &[DP::InternalId], accessor: &mut A, computer: &A::QueryComputer, scratch: &mut SearchScratch, search_record: &mut SR, ) -> impl SendFuture> where - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, SR: SearchRecord + ?Sized, Q: NeighborQueue, { async move { let beam_width = beam_width.unwrap_or(1); - // paged search can call search_internal multiple times, we only need to initialize - // state if not already initialized. 
if scratch.visited.is_empty() { - for id in start_ids { - scratch.visited.insert(*id); - let element = accessor - .get_element(*id) - .await - .escalate("start point retrieval must succeed")?; - let dist = computer.evaluate_similarity(element.reborrow()); - scratch.best.insert(Neighbor::new(*id, dist)); - scratch.cmps += 1; - } + accessor + .start_point_distances(computer, |id, dist| { + scratch.visited.insert(id); + scratch.best.insert(Neighbor::new(id, dist)); + scratch.cmps += 1; + }) + .await?; } + // // paged search can call search_internal multiple times, we only need to initialize + // // state if not already initialized. + // if scratch.visited.is_empty() { + // for id in start_ids { + // scratch.visited.insert(*id); + // let element = accessor + // .get_element(*id) + // .await + // .escalate("start point retrieval must succeed")?; + // let dist = computer.evaluate_similarity(element.reborrow()); + // scratch.best.insert(Neighbor::new(*id, dist)); + // scratch.cmps += 1; + // } + // } + let mut neighbors = Vec::with_capacity(self.max_degree_with_slack()); while scratch.best.has_notvisited_node() && !accessor.terminate_early() { scratch.beam_nodes.clear(); @@ -2183,98 +2188,99 @@ where search_params.search(self, strategy, processor, context, query, output) } - /// Performs a brute-force flat search over the points matching a provided filter function. - /// - /// This method executes a linear scan through all points in the index, applying the provided - /// `vector_filter` to select candidate points. It computes the similarity between the query - /// vector and each candidate, returning the top results according to the provided search parameters. - /// - /// # Arguments - /// - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. 
- /// * `vector_filter` - A predicate function used to filter candidate vectors based on their external IDs. - /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`). - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. - /// - /// # Returns - /// - /// Returns search statistics including the number of distance computations performed. - /// - /// # Errors - /// - /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. - /// - /// # Notes - /// - /// This method is computationally expensive for large datasets, as it does not leverage the graph structure - /// and instead performs a linear scan of all filtered points. - pub async fn flat_search<'a, S, T, O, OB, I>( - &'a self, - strategy: &'a S, - context: &'a DP::Context, - query: T, - vector_filter: &(dyn Fn(&DP::ExternalId) -> bool + Send + Sync), - search_params: &Knn, - output: &mut OB, - ) -> ANNResult - where - T: Copy + Send, - S: glue::DefaultSearchStrategy: IdIterator>, - I: Iterator::InternalId>, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send, - { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - - let mut scratch = { - let num_start_points = accessor.starting_points().await?.len(); - self.search_scratch(search_params.l_value().get(), num_start_points) - }; - - let id_iterator = accessor.id_iterator().await?; - for id in id_iterator { - let external_id = self - .data_provider - .to_external_id(context, id) - .escalate("external id should be found")?; - - if vector_filter(&external_id) { - scratch.visited.insert(id); - let element = accessor - .get_element(id) - .await - .escalate("matched point retrieval must succeed")?; - let dist = computer.evaluate_similarity(element.reborrow()); - scratch.best.insert(Neighbor::new(id, 
dist)); - scratch.cmps += 1; - } - } - - let result_count = strategy - .default_post_processor() - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(search_params.l_value().get()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(SearchStats { - cmps: scratch.cmps, - hops: scratch.hops, - result_count: result_count as u32, - range_search_second_round: false, - }) - } + // /// Performs a brute-force flat search over the points matching a provided filter function. + // /// + // /// This method executes a linear scan through all points in the index, applying the provided + // /// `vector_filter` to select candidate points. It computes the similarity between the query + // /// vector and each candidate, returning the top results according to the provided search parameters. + // /// + // /// # Arguments + // /// + // /// * `strategy` - The search strategy to use for accessing and processing elements. + // /// * `context` - The context to pass through to providers. + // /// * `query` - The query vector for which nearest neighbors are sought. + // /// * `vector_filter` - A predicate function used to filter candidate vectors based on their external IDs. + // /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`). + // /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. + // /// + // /// # Returns + // /// + // /// Returns search statistics including the number of distance computations performed. + // /// + // /// # Errors + // /// + // /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. + // /// + // /// # Notes + // /// + // /// This method is computationally expensive for large datasets, as it does not leverage the graph structure + // /// and instead performs a linear scan of all filtered points. 
+ // pub async fn flat_search<'a, S, T, O, OB, I>( + // &'a self, + // strategy: &'a S, + // context: &'a DP::Context, + // query: T, + // vector_filter: &(dyn Fn(&DP::ExternalId) -> bool + Send + Sync), + // search_params: &Knn, + // output: &mut OB, + // ) -> ANNResult + // where + // T: Copy + Send, + // S: glue::DefaultSearchStrategy: IdIterator>, + // I: Iterator::InternalId>, + // O: Send, + // OB: search_output_buffer::SearchOutputBuffer + Send, + // { + // let mut accessor = strategy + // .search_accessor(&self.data_provider, context) + // .into_ann_result()?; + // let computer = accessor.build_query_computer(query).into_ann_result()?; + + // let mut scratch = { + // let num_start_points = accessor.starting_points().await?.len(); + // self.search_scratch(search_params.l_value().get(), num_start_points) + // }; + + // let id_iterator = accessor.id_iterator().await?; + // for id in id_iterator { + // let external_id = self + // .data_provider + // .to_external_id(context, id) + // .escalate("external id should be found")?; + + // if vector_filter(&external_id) + // { + // scratch.visited.insert(id); + // let element = accessor + // .get_element(id) + // .await + // .escalate("matched point retrieval must succeed")?; + // let dist = computer.evaluate_similarity(element.reborrow()); + // scratch.best.insert(Neighbor::new(id, dist)); + // scratch.cmps += 1; + // } + // } + + // let result_count = strategy + // .default_post_processor() + // .post_process( + // &mut accessor, + // query, + // &computer, + // scratch.best.iter().take(search_params.l_value().get()), + // output, + // ) + // .send() + // .await + // .into_ann_result()?; + + // Ok(SearchStats { + // cmps: scratch.cmps, + // hops: scratch.hops, + // result_count: result_count as u32, + // range_search_second_round: false, + // }) + // } ////////////////// // Paged Search // @@ -2427,7 +2433,6 @@ where let start_ids = accessor.starting_points().await?; self.search_internal( None, // beam_width - 
&start_ids, &mut accessor, &search_state.extra.1, &mut search_state.scratch, diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index c05594ce0..9468e4204 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -202,7 +202,6 @@ where let stats = index .search_internal( Some(self.beam_width.get()), - &start_ids, &mut accessor, &computer, &mut scratch, @@ -278,7 +277,6 @@ where let stats = index .search_internal( Some(self.inner.beam_width.get()), - &start_ids, &mut accessor, &computer, &mut scratch, diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index aba0f44c5..3eb6ef923 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -5,15 +5,13 @@ //! Label-filtered search using multi-hop expansion. -use diskann_utils::Reborrow; use diskann_utils::future::SendFuture; -use diskann_vector::PreprocessedDistanceFunction; use hashbrown::HashSet; use super::{Knn, Search, record::SearchRecord, scratch::SearchScratch}; use crate::{ ANNResult, - error::{ErrorExt, IntoANNResult}, + error::{IntoANNResult}, graph::{ glue::{ self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, @@ -182,7 +180,7 @@ pub(crate) async fn multihop_search_internal( ) -> ANNResult where I: VectorId, - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, SR: SearchRecord + ?Sized, { let beam_width = search_params.beam_width().get(); @@ -194,21 +192,13 @@ where range_search_second_round: false, }; - // Initialize search state if not already initialized. 
- // This allows paged search to call multihop_search_internal multiple times - if scratch.visited.is_empty() { - let start_ids = accessor.starting_points().await?; - - for id in start_ids { + accessor + .start_point_distances(computer, |id, dist| { scratch.visited.insert(id); - let element = accessor - .get_element(id) - .await - .escalate("start point retrieval must succeed")?; - let dist = computer.evaluate_similarity(element.reborrow()); scratch.best.insert(Neighbor::new(id, dist)); - } - } + scratch.cmps += 1; + }) + .await?; // Pre-allocate with good capacity to avoid repeated allocations let mut one_hop_neighbors = Vec::with_capacity(max_degree_with_slack); diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index b92fa1384..400fd3f0c 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -191,7 +191,6 @@ where let initial_stats = index .search_internal( self.beam_width(), - &start_ids, &mut accessor, &computer, &mut scratch, @@ -332,7 +331,7 @@ pub(crate) async fn range_search_internal( ) -> ANNResult where I: crate::utils::VectorId, - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, { let beam_width = search_params.beam_width().unwrap_or(1); diff --git a/diskann/src/graph/test/cases/index.rs b/diskann/src/graph/test/cases/index.rs index 5d91de7bb..1b195fe4a 100644 --- a/diskann/src/graph/test/cases/index.rs +++ b/diskann/src/graph/test/cases/index.rs @@ -282,83 +282,83 @@ async fn test_drop_deleted_neighbors_noop() { assert_eq!(result, graph::ConsolidateKind::Complete); } -#[tokio::test(flavor = "current_thread")] -async fn test_flat_search_basic() { - use crate::graph::search::Knn; - use crate::graph::search_output_buffer::IdDistance; - - let adjacency_list = generate_2d_square_adjacency_list(); - let index = setup_2d_square(adjacency_list, 4); - let strategy = test_provider::Strategy::new(); - let ctx = test_provider::Context::new(); - - // Query near 
origin — node 0 at (0,0) is closest. - // l_value must cover all 5 points (4 data + 1 start) so the working set - // doesn't drop any before the post-processor runs. - let query = [0.1_f32, 0.1]; - let params = Knn::new(4, 5, None).unwrap(); - - let mut ids = [0u32; 4]; - let mut distances = [0.0f32; 4]; - let mut output = IdDistance::new(&mut ids, &mut distances); - - let stats = index - .flat_search( - &strategy, - &ctx, - query.as_slice(), - &|_| true, - &params, - &mut output, - ) - .await - .unwrap(); - - // FilterStartPoints removes the start node, leaving 4 data nodes. - assert_eq!(stats.result_count, 4); - let results: std::collections::HashSet<u32> = - ids[..stats.result_count as usize].iter().copied().collect(); - for id in 0..4u32 { - assert!(results.contains(&id), "data node {id} should be in results"); - } -} - -#[tokio::test(flavor = "current_thread")] -async fn test_flat_search_with_filter() { - use crate::graph::search::Knn; - use crate::graph::search_output_buffer::IdDistance; - - let adjacency_list = generate_2d_square_adjacency_list(); - let index = setup_2d_square(adjacency_list, 4); - let strategy = test_provider::Strategy::new(); - let ctx = test_provider::Context::new(); - - // Query near origin, but filter out node 0. - let query = [0.1_f32, 0.1]; - let params = Knn::new(2, 4, None).unwrap(); - - let mut ids = [0u32; 2]; - let mut distances = [0.0f32; 2]; - let mut output = IdDistance::new(&mut ids, &mut distances); - - let stats = index - .flat_search( - &strategy, - &ctx, - query.as_slice(), - &|ext_id: &u32| *ext_id != 0, - &params, - &mut output, - ) - .await - .unwrap(); - - assert_eq!(stats.result_count, 2); - assert!( - !ids[..stats.result_count as usize].contains(&0), - "node 0 should be filtered out" - ); - // Nodes 1, 2, 3 remain — closest two to (0.1, 0.1) are 1 (1,0) and 2 (0,1).
- assert!(ids.contains(&1), "node 1 should be present"); - assert!(ids.contains(&2), "node 2 should be present"); -} +// #[tokio::test(flavor = "current_thread")] +// async fn test_flat_search_basic() { +// use crate::graph::search::Knn; +// use crate::graph::search_output_buffer::IdDistance; +// +// let adjacency_list = generate_2d_square_adjacency_list(); +// let index = setup_2d_square(adjacency_list, 4); +// let strategy = test_provider::Strategy::new(); +// let ctx = test_provider::Context::new(); +// +// // Query near origin — node 0 at (0,0) is closest. +// // l_value must cover all 5 points (4 data + 1 start) so the working set +// // doesn't drop any before the post-processor runs. +// let query = [0.1_f32, 0.1]; +// let params = Knn::new(4, 5, None).unwrap(); +// +// let mut ids = [0u32; 4]; +// let mut distances = [0.0f32; 4]; +// let mut output = IdDistance::new(&mut ids, &mut distances); +// +// let stats = index +// .flat_search( +// &strategy, +// &ctx, +// query.as_slice(), +// &|_| true, +// &params, +// &mut output, +// ) +// .await +// .unwrap(); +// +// // FilterStartPoints removes the start node, leaving 4 data nodes. +// assert_eq!(stats.result_count, 4); +// let results: std::collections::HashSet<u32> = +// ids[..stats.result_count as usize].iter().copied().collect(); +// for id in 0..4u32 { +// assert!(results.contains(&id), "data node {id} should be in results"); +// } +// } +// +// #[tokio::test(flavor = "current_thread")] +// async fn test_flat_search_with_filter() { +// use crate::graph::search::Knn; +// use crate::graph::search_output_buffer::IdDistance; +// +// let adjacency_list = generate_2d_square_adjacency_list(); +// let index = setup_2d_square(adjacency_list, 4); +// let strategy = test_provider::Strategy::new(); +// let ctx = test_provider::Context::new(); +// +// // Query near origin, but filter out node 0.
+// let query = [0.1_f32, 0.1]; +// let params = Knn::new(2, 4, None).unwrap(); +// +// let mut ids = [0u32; 2]; +// let mut distances = [0.0f32; 2]; +// let mut output = IdDistance::new(&mut ids, &mut distances); +// +// let stats = index +// .flat_search( +// &strategy, +// &ctx, +// query.as_slice(), +// &|ext_id: &u32| *ext_id != 0, +// &params, +// &mut output, +// ) +// .await +// .unwrap(); +// +// assert_eq!(stats.result_count, 2); +// assert!( +// !ids[..stats.result_count as usize].contains(&0), +// "node 0 should be filtered out" +// ); +// // Nodes 1, 2, 3 remain — closest two to (0.1, 0.1) are 1 (1,0) and 2 (0,1). +// assert!(ids.contains(&1), "node 1 should be present"); +// assert!(ids.contains(&2), "node 2 should be present"); +// } diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 051ec6fc7..b1c548cda 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -13,13 +13,13 @@ use std::{ }; use dashmap::{DashMap, mapref::entry::Entry}; -use diskann_utils::views::Matrix; -use diskann_vector::distance::Metric; +use diskann_utils::{future::SendFuture, views::Matrix}; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; use thiserror::Error; use crate::{ ANNError, ANNResult, default_post_processor, - error::{Infallible, RankedError, StandardError, ToRanked, TransientError, message}, + error::{ErrorExt, Infallible, RankedError, StandardError, ToRanked, TransientError, message}, graph::{AdjacencyList, SearchOutputBuffer, glue, test::synthetic, workingset}, internal::counter::{Counter, LocalCounter}, neighbor::Neighbor, @@ -1034,21 +1034,8 @@ impl<'a> Accessor<'a> { transient_ids, } } -} - -impl provider::HasId for Accessor<'_> { - type Id = u32; -} - -impl provider::Accessor for Accessor<'_> { - type Element<'a> - = &'a [f32] - where - Self: 'a; - type ElementRef<'a> = &'a [f32]; - type GetError = AccessError; - async fn get_element(&mut self, id: u32) -> 
Result<&[f32], AccessError> { + pub fn get(&mut self, id: u32) -> Result<&[f32], AccessError> { match self.provider.terms.get(&id) { Some(term) => { if let Some(transient) = &self.transient_ids @@ -1066,6 +1053,36 @@ impl provider::Accessor for Accessor<'_> { } } +impl provider::HasId for Accessor<'_> { + type Id = u32; +} + +impl provider::Accessor for Accessor<'_> { + // type Element<'a> + // = &'a [f32] + // where + // Self: 'a; + type ElementRef<'a> = &'a [f32]; + // type GetError = AccessError; + + // async fn get_element(&mut self, id: u32) -> Result<&[f32], AccessError> { + // match self.provider.terms.get(&id) { + // Some(term) => { + // if let Some(transient) = &self.transient_ids + // && transient.contains(&id) + // { + // return Err(AccessError::Transient(TransientAccessError::new(id))); + // } + + // self.get_vector.increment(); + // self.buffer.copy_from_slice(&term.data); + // Ok(&*self.buffer) + // } + // None => Err(AccessError::InvalidId(AccessedInvalidId(id))), + // } + // } +} + impl<'a> provider::DelegateNeighbor<'a> for Accessor<'_> { type Delegate = NeighborAccessor<'a>; fn delegate_neighbor(&'a mut self) -> Self::Delegate { @@ -1103,13 +1120,99 @@ impl provider::BuildDistanceComputer for Accessor<'_> { // Glue // //------// -impl glue::SearchExt for Accessor<'_> { +impl glue::SearchExt<&[f32]> for Accessor<'_> { fn starting_points(&self) -> impl Future>> + Send { futures_util::future::ok(self.provider.config.start_points.keys().copied().collect()) } + + fn start_point_distances( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + async move { + for &i in self.provider.config.start_points.keys() { + f( + i, + computer.evaluate_similarity(self.get(i).escalate("start points must exist")?), + ) + } + Ok(()) + } + } +} + +impl glue::ExpandBeam<&[f32]> for Accessor<'_> { + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut 
pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + async move { + let mut neighbors = AdjacencyList::new(); + for id in ids { + self.provider.get_neighbors(id, &mut neighbors)?; + for &n in neighbors.iter().filter(|i| pred.eval_mut(i)) { + if let Some(buf) = self.get(n).allow_transient("transient failures allowed")? { + on_neighbors(computer.evaluate_similarity(buf), n) + } + } + } + Ok(()) + } + } } -impl glue::ExpandBeam<&[f32]> for Accessor<'_> {} +type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; +type View<'a> = workingset::map::View<'a, u32, Box<[f32]>, workingset::map::Ref<[f32]>>; + +impl workingset::Fill for Accessor<'_> { + type Error = ANNError; + type View<'a> + = View<'a> + where + Self: 'a, + WorkingSet: 'a; + + fn fill<'a, Itr>( + &'a mut self, + set: &'a mut WorkingSet, + itr: Itr, + ) -> impl SendFuture, Self::Error>> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + async { + use workingset::map::Entry; + + set.prepare(itr.clone()); + for i in itr { + match set.entry(i) { + Entry::Seeded(_) | Entry::Occupied(_) => { /* nothing to do */ } + Entry::Vacant(vacant) => { + if let Some(buf) = + self.get(i).allow_transient("transient failures allowed")? + { + vacant.insert(buf.into()); + } + } + } + } + Ok(set.view()) + } + } +} impl glue::IdIterator> for Accessor<'_> { async fn id_iterator(&mut self) -> Result, ANNError> { @@ -1648,7 +1751,7 @@ mod tests { let mut accessor = super::Accessor::new(&provider); let id = 5; - assert!(rt.block_on(accessor.get_element(5)).is_err()); + assert!(accessor.get(5).is_err()); // Setting with the wrong dimension is an error. 
{ @@ -1658,7 +1761,7 @@ mod tests { .unwrap_err(); let msg = err.to_string(); assert_message_contains!(msg, "wrong dim"); - assert!(rt.block_on(accessor.get_element(id)).is_err()); + assert!(accessor.get(id).is_err()); } // Setting with the correct dimension is successful. @@ -1669,7 +1772,7 @@ mod tests { .unwrap(); rt.block_on(guard.complete()); - let element = rt.block_on(accessor.get_element(id)).unwrap(); + let element = accessor.get(id).unwrap(); assert_eq!(v, element); } diff --git a/diskann/src/graph/workingset/map.rs b/diskann/src/graph/workingset/map.rs index eaa11c089..7cbf8f848 100644 --- a/diskann/src/graph/workingset/map.rs +++ b/diskann/src/graph/workingset/map.rs @@ -129,16 +129,14 @@ use std::{fmt::Debug, hash::Hash, sync::Arc}; -use diskann_utils::{Reborrow, future::SendFuture}; +use diskann_utils::{Reborrow}; use hashbrown::hash_map; use crate::{ - error::{RankedError, ToRanked, TransientError}, graph::glue, - provider::Accessor, }; -use super::{AsWorkingSet, Fill}; +use super::{AsWorkingSet}; ///////// // Map // @@ -284,72 +282,6 @@ where } } -impl Map -where - K: Copy + Hash + Eq + Send + Sync, - V: Send + Sync, - P: Projection, -{ - /// Fill using `accessor.get_element` for each missing key. - /// - /// Calls [`prepare`](Self::prepare) to pin and evict entries, then fetches any - /// keys still missing. For incremental fills that skip preparation, use - /// [`fill_with`](Self::fill_with) directly. - /// - /// Transient errors are acknowledged and skipped. Only critical errors are propagated. - pub fn fill<'a, A, Itr>( - &'a mut self, - accessor: &'a mut A, - itr: Itr, - ) -> impl SendFuture, ::Error>> - where - A: for<'b> Accessor: Into>, - Itr: ExactSizeIterator + Clone + Send + Sync, - { - self.prepare(itr.clone()); - self.fill_with(accessor, itr, |element| element.into()) - } - - /// Fill using a projection from `Accessor::Element` to `V`. - /// - /// Transient errors are acknowledged and skipped. Only critical errors are propagated. 
- pub fn fill_with<'a, A, Itr, F>( - &'a mut self, - accessor: &'a mut A, - itr: Itr, - f: F, - ) -> impl SendFuture, ::Error>> - where - A: Accessor, - Itr: ExactSizeIterator + Send + Sync, - F: Fn(A::ElementRef<'_>) -> V + Send, - { - async move { - for i in itr { - match self.entry(i) { - Entry::Seeded(_) | Entry::Occupied(_) => { /* nothing to do */ } - Entry::Vacant(vacant) => match accessor.get_element(i).await { - Ok(element) => { - vacant.insert(f(element.reborrow())); - } - Err(local_error) => match local_error.to_ranked() { - RankedError::Transient(transient) => { - transient.acknowledge( - "transient error during fill; element will be absent from working set", - ); - } - RankedError::Error(critical) => { - return Err(critical); - } - }, - }, - } - } - Ok(self.view()) - } - } -} - impl Map where K: Hash + Eq, @@ -526,36 +458,36 @@ impl VacantEntry<'_, K, V> { } } -/// Blanket implementation of [`Fill`] for [`Map`]-backed working sets. -/// -/// This covers the common case where the accessor's `Extended` type is stored directly -/// in the map. Accessors that need custom fill logic (e.g. hybrid full-precision/quantized) -/// should use a different `State` type and provide their own `Fill` impl. -impl Fill> for A -where - P: Projection, - A: for<'a> Accessor = P::ElementRef<'a>>, - V: Project

+ Send + Sync + 'static, - for<'a> A::ElementRef<'a>: Into, -{ - type Error = ::Error; - type View<'a> - = View<'a, A::Id, V, P> - where - Self: 'a; - - fn fill<'a, Itr>( - &'a mut self, - map: &'a mut Map, - itr: Itr, - ) -> impl SendFuture, Self::Error>> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - map.fill(self, itr) - } -} +// /// Blanket implementation of [`Fill`] for [`Map`]-backed working sets. +// /// +// /// This covers the common case where the accessor's `Extended` type is stored directly +// /// in the map. Accessors that need custom fill logic (e.g. hybrid full-precision/quantized) +// /// should use a different `State` type and provide their own `Fill` impl. +// impl Fill> for A +// where +// P: Projection, +// A: for<'a> Accessor = P::ElementRef<'a>>, +// V: Project

+ Send + Sync + 'static, +// for<'a> A::ElementRef<'a>: Into, +// { +// type Error = ::Error; +// type View<'a> +// = View<'a, A::Id, V, P> +// where +// Self: 'a; +// +// fn fill<'a, Itr>( +// &'a mut self, +// map: &'a mut Map, +// itr: Itr, +// ) -> impl SendFuture, Self::Error>> +// where +// Itr: ExactSizeIterator + Clone + Send + Sync, +// Self: 'a, +// { +// map.fill(self, itr) +// } +// } ///////////////// // Projections // @@ -1933,201 +1865,201 @@ mod tests { let _ = ar.clone(); } - //----------------------------------// - // Fill / fill_with (async, Tier 1) // - //----------------------------------// - - /// Create a grid-backed provider (1-D, size 5). - /// - /// IDs 0–4 have vectors `[0.0]` .. `[4.0]`, start point is `u32::MAX`. - fn fill_provider() -> crate::graph::test::provider::Provider { - use crate::graph::test::synthetic::Grid; - crate::graph::test::provider::Provider::grid(Grid::One, 5).unwrap() - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_happy_path() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - let current = map.generation; - let view = map - .fill(&mut accessor, [0u32, 1, 2].into_iter()) - .await - .unwrap(); - - assert_eq!(view.get(0).unwrap(), &[0.0]); - assert_eq!(view.get(1).unwrap(), &[1.0]); - assert_eq!(view.get(2).unwrap(), &[2.0]); - assert!(view.get(99).is_none()); - - assert_eq!( - map.generation, - current + 1, - "`fill` should bump the generation" - ); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_clears_previous_entries() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: Map, Ref<[f32]>> = Builder::new(Capacity::None).build(0); - - // First fill with IDs 0 and 1. 
- let view = map - .fill(&mut accessor, [0u32, 1].into_iter()) - .await - .unwrap(); - - assert!(view.get(0).is_some()); - assert!(view.get(1).is_some()); - assert!(view.get(2).is_none()); - - // Second fill with only ID 2 — previous entries should be cleared. - let view = map.fill(&mut accessor, [2u32].into_iter()).await.unwrap(); - assert!(view.get(0).is_none(), "fill should have cleared id 0"); - assert!(view.get(1).is_none(), "fill should have cleared id 1"); - assert_eq!(view.get(2).unwrap(), &[2.0]); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_with_preserves_entries() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // Populate with ID 0. - let view = map - .fill_with(&mut accessor, [0u32].into_iter(), |e| e.into()) - .await - .unwrap(); - assert!(view.get(0).is_some()); - - // fill_with with ID 1 — should NOT clear ID 0. - let view = map - .fill_with(&mut accessor, [1u32].into_iter(), |e| e.into()) - .await - .unwrap(); - assert_eq!( - view.get(0).unwrap(), - &[0.0], - "fill_with should preserve id 0" - ); - assert_eq!(view.get(1).unwrap(), &[1.0]); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_skips_transient_errors() { - let provider = fill_provider(); - - let mut accessor = - TestAccessor::flaky(&provider, Cow::Owned(std::collections::HashSet::from([1]))); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // ID 1 is transient — should be skipped, not propagated. 
- let view = map - .fill(&mut accessor, [0u32, 1, 2].into_iter()) - .await - .unwrap(); - assert_eq!(view.get(0).unwrap(), &[0.0]); - assert!(view.get(1).is_none(), "transient ID should be absent"); - assert_eq!(view.get(2).unwrap(), &[2.0]); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_propagates_critical_errors() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // ID 99 doesn't exist — critical InvalidId error. - let err = map - .fill(&mut accessor, [0u32, 99].into_iter()) - .await - .unwrap_err(); - let msg = err.to_string(); - assert!( - msg.contains("99"), - "error should mention the invalid id: {msg}" - ); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_with_skips_occupied_entries() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // Pre-insert a sentinel value for ID 0. - map.insert(0, Box::new([99.0])); - - // fill_with should skip the occupied entry. - let view = map - .fill_with(&mut accessor, [0u32, 1].into_iter(), |e| e.into()) - .await - .unwrap(); - - // ID 0 retains its pre-inserted sentinel. - assert_eq!( - view.get(0).unwrap(), - &[99.0], - "occupied entry should be preserved" - ); - assert_eq!(view.get(1).unwrap(), &[1.0]); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_with_skips_seeded_entries() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - - // Seed with a batch containing a different value for ID 0. - let batch = Arc::new(Matrix::try_from(Box::new([99.0, 88.0]), 2, 1).unwrap()); - let overlay = Overlay::>::from_batch(batch, [0u32, 1].into_iter()); - let mut map = seeded_map(overlay, Capacity::Unbounded); - - // fill_with requests IDs 0 and 2. ID 0 is seeded → skip, ID 2 is filled. 
- let view = map - .fill_with(&mut accessor, [0u32, 2].into_iter(), |e| e.into()) - .await - .unwrap(); - - // ID 0 comes from the seed (batch row 0 = [99.0]), NOT the accessor. - assert_eq!(view.get(0).unwrap(), &[99.0]); - // ID 2 was filled from the accessor. - assert_eq!(view.get(2).unwrap(), &[2.0]); - // Verify ID 0 is NOT in the fill layer. - assert!( - map.get(&0).is_none(), - "ID 0 should only be in the seed, not the fill layer" - ); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_empty_iterator() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - let view = map.fill(&mut accessor, std::iter::empty()).await.unwrap(); - assert!(view.get(0).is_none()); - } - - #[tokio::test(flavor = "current_thread")] - async fn blanket_fill_trait() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // Exercise the blanket Fill impl. - let view = <_ as Fill>::fill(&mut accessor, &mut map, [0u32, 1, 2].into_iter()) - .await - .unwrap(); - - assert_eq!(view.get(0).unwrap(), &[0.0]); - assert_eq!(view.get(1).unwrap(), &[1.0]); - assert_eq!(view.get(2).unwrap(), &[2.0]); - } + // //----------------------------------// + // // Fill / fill_with (async, Tier 1) // + // //----------------------------------// + + // /// Create a grid-backed provider (1-D, size 5). + // /// + // /// IDs 0–4 have vectors `[0.0]` .. `[4.0]`, start point is `u32::MAX`. 
+ // fn fill_provider() -> crate::graph::test::provider::Provider { + // use crate::graph::test::synthetic::Grid; + // crate::graph::test::provider::Provider::grid(Grid::One, 5).unwrap() + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn fill_happy_path() { + // let provider = fill_provider(); + // let mut accessor = TestAccessor::new(&provider); + // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); + + // let current = map.generation; + // let view = map + // .fill(&mut accessor, [0u32, 1, 2].into_iter()) + // .await + // .unwrap(); + + // assert_eq!(view.get(0).unwrap(), &[0.0]); + // assert_eq!(view.get(1).unwrap(), &[1.0]); + // assert_eq!(view.get(2).unwrap(), &[2.0]); + // assert!(view.get(99).is_none()); + + // assert_eq!( + // map.generation, + // current + 1, + // "`fill` should bump the generation" + // ); + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn fill_clears_previous_entries() { + // let provider = fill_provider(); + // let mut accessor = TestAccessor::new(&provider); + // let mut map: Map, Ref<[f32]>> = Builder::new(Capacity::None).build(0); + + // // First fill with IDs 0 and 1. + // let view = map + // .fill(&mut accessor, [0u32, 1].into_iter()) + // .await + // .unwrap(); + + // assert!(view.get(0).is_some()); + // assert!(view.get(1).is_some()); + // assert!(view.get(2).is_none()); + + // // Second fill with only ID 2 — previous entries should be cleared. 
+ // let view = map.fill(&mut accessor, [2u32].into_iter()).await.unwrap(); + // assert!(view.get(0).is_none(), "fill should have cleared id 0"); + // assert!(view.get(1).is_none(), "fill should have cleared id 1"); + // assert_eq!(view.get(2).unwrap(), &[2.0]); + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn fill_with_preserves_entries() { + // let provider = fill_provider(); + // let mut accessor = TestAccessor::new(&provider); + // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); + + // // Populate with ID 0. + // let view = map + // .fill_with(&mut accessor, [0u32].into_iter(), |e| e.into()) + // .await + // .unwrap(); + // assert!(view.get(0).is_some()); + + // // fill_with with ID 1 — should NOT clear ID 0. + // let view = map + // .fill_with(&mut accessor, [1u32].into_iter(), |e| e.into()) + // .await + // .unwrap(); + // assert_eq!( + // view.get(0).unwrap(), + // &[0.0], + // "fill_with should preserve id 0" + // ); + // assert_eq!(view.get(1).unwrap(), &[1.0]); + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn fill_skips_transient_errors() { + // let provider = fill_provider(); + + // let mut accessor = + // TestAccessor::flaky(&provider, Cow::Owned(std::collections::HashSet::from([1]))); + // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); + + // // ID 1 is transient — should be skipped, not propagated. 
+ // let view = map + // .fill(&mut accessor, [0u32, 1, 2].into_iter()) + // .await + // .unwrap(); + // assert_eq!(view.get(0).unwrap(), &[0.0]); + // assert!(view.get(1).is_none(), "transient ID should be absent"); + // assert_eq!(view.get(2).unwrap(), &[2.0]); + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn fill_propagates_critical_errors() { + // let provider = fill_provider(); + // let mut accessor = TestAccessor::new(&provider); + // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); + + // // ID 99 doesn't exist — critical InvalidId error. + // let err = map + // .fill(&mut accessor, [0u32, 99].into_iter()) + // .await + // .unwrap_err(); + // let msg = err.to_string(); + // assert!( + // msg.contains("99"), + // "error should mention the invalid id: {msg}" + // ); + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn fill_with_skips_occupied_entries() { + // let provider = fill_provider(); + // let mut accessor = TestAccessor::new(&provider); + // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); + + // // Pre-insert a sentinel value for ID 0. + // map.insert(0, Box::new([99.0])); + + // // fill_with should skip the occupied entry. + // let view = map + // .fill_with(&mut accessor, [0u32, 1].into_iter(), |e| e.into()) + // .await + // .unwrap(); + + // // ID 0 retains its pre-inserted sentinel. + // assert_eq!( + // view.get(0).unwrap(), + // &[99.0], + // "occupied entry should be preserved" + // ); + // assert_eq!(view.get(1).unwrap(), &[1.0]); + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn fill_with_skips_seeded_entries() { + // let provider = fill_provider(); + // let mut accessor = TestAccessor::new(&provider); + + // // Seed with a batch containing a different value for ID 0. 
+ // let batch = Arc::new(Matrix::try_from(Box::new([99.0, 88.0]), 2, 1).unwrap()); + // let overlay = Overlay::>::from_batch(batch, [0u32, 1].into_iter()); + // let mut map = seeded_map(overlay, Capacity::Unbounded); + + // // fill_with requests IDs 0 and 2. ID 0 is seeded → skip, ID 2 is filled. + // let view = map + // .fill_with(&mut accessor, [0u32, 2].into_iter(), |e| e.into()) + // .await + // .unwrap(); + + // // ID 0 comes from the seed (batch row 0 = [99.0]), NOT the accessor. + // assert_eq!(view.get(0).unwrap(), &[99.0]); + // // ID 2 was filled from the accessor. + // assert_eq!(view.get(2).unwrap(), &[2.0]); + // // Verify ID 0 is NOT in the fill layer. + // assert!( + // map.get(&0).is_none(), + // "ID 0 should only be in the seed, not the fill layer" + // ); + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn fill_empty_iterator() { + // let provider = fill_provider(); + // let mut accessor = TestAccessor::new(&provider); + // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); + + // let view = map.fill(&mut accessor, std::iter::empty()).await.unwrap(); + // assert!(view.get(0).is_none()); + // } + + // #[tokio::test(flavor = "current_thread")] + // async fn blanket_fill_trait() { + // let provider = fill_provider(); + // let mut accessor = TestAccessor::new(&provider); + // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); + + // // Exercise the blanket Fill impl. 
+ // let view = <_ as Fill>::fill(&mut accessor, &mut map, [0u32, 1, 2].into_iter()) + // .await + // .unwrap(); + + // assert_eq!(view.get(0).unwrap(), &[0.0]); + // assert_eq!(view.get(1).unwrap(), &[1.0]); + // assert_eq!(view.get(2).unwrap(), &[2.0]); + // } } diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index f55af356c..968671373 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -97,8 +97,7 @@ use std::ops::Deref; -use diskann_utils::Reborrow; -use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction}; +use diskann_vector::{DistanceFunction}; use sealed::{BoundTo, Sealed}; use crate::{ANNError, ANNResult, error::ToRanked, graph::AdjacencyList, utils::VectorId}; @@ -427,51 +426,51 @@ pub trait Accessor: HasId + Send + Sync { /// requirement on `Self`. type ElementRef<'a>; - /// The concrete type of the data element associated with this accessor. - /// - /// For distance computations, this should be cheaply convertible via [`Reborrow`] to - /// `Self::ElementRef`. - type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync - where - Self: 'a; - - /// The error (if any) returned by [`Self::get_element`]. - type GetError: ToRanked + std::fmt::Debug + Send + Sync + 'static; - - /// Return the value associated with the key `id`. - /// - /// It is expected that index algorithms will only invoke `get_element` on valid IDs, - /// that can be derived from [`SetElement::set_element`] or by some other means. - /// - /// Implementations are suggested to return an error if this invariant is broken, but - /// may also panic if that is an acceptable error mode. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl std::future::Future, Self::GetError>> + Send; - - /// A bulk interface for invoking [`Self::get_element`] on each item in an iterator and - /// invoking the closure with the reborrowed element. 
- /// - /// Algorithms are encouraged to use this interface if appropriate as accessor - /// implementations may specialize the implementation for better performance. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl std::future::Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), - { - async move { - for i in itr { - f(self.get_element(i).await?.reborrow(), i); - } - Ok(()) - } - } + // /// The concrete type of the data element associated with this accessor. + // /// + // /// For distance computations, this should be cheaply convertible via [`Reborrow`] to + // /// `Self::ElementRef`. + // type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync + // where + // Self: 'a; + + // /// The error (if any) returned by [`Self::get_element`]. + // type GetError: ToRanked + std::fmt::Debug + Send + Sync + 'static; + + // /// Return the value associated with the key `id`. + // /// + // /// It is expected that index algorithms will only invoke `get_element` on valid IDs, + // /// that can be derived from [`SetElement::set_element`] or by some other means. + // /// + // /// Implementations are suggested to return an error if this invariant is broken, but + // /// may also panic if that is an acceptable error mode. + // fn get_element( + // &mut self, + // id: Self::Id, + // ) -> impl std::future::Future, Self::GetError>> + Send; + + // /// A bulk interface for invoking [`Self::get_element`] on each item in an iterator and + // /// invoking the closure with the reborrowed element. + // /// + // /// Algorithms are encouraged to use this interface if appropriate as accessor + // /// implementations may specialize the implementation for better performance. 
+ // fn on_elements_unordered( + // &mut self, + // itr: Itr, + // mut f: F, + // ) -> impl std::future::Future> + Send + // where + // Self: Sync, + // Itr: Iterator + Send, + // F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), + // { + // async move { + // for i in itr { + // f(self.get_element(i).await?.reborrow(), i); + // } + // Ok(()) + // } + // } } /// A specialized [`Accessor`] that provides random-access distance computations. @@ -506,10 +505,7 @@ pub trait BuildQueryComputer: Accessor { /// The concrete type of the distance computer, which must be applicable for all /// elements yielded by the [`Accessor`]. - type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> - + Send - + Sync - + 'static; + type QueryComputer: Send + Sync + 'static; /// Build the query computer for this accessor. /// @@ -520,26 +516,26 @@ pub trait BuildQueryComputer: Accessor { from: T, ) -> Result; - /// Compute the distances for the elements in the iterator `itr` using the - /// `computer` and apply the closure `f` to each distance and ID. The default - /// implementation uses on_elements_unordered to iterate over the elements - /// and compute the distances using `computer` parameter. - fn distances_unordered( - &mut self, - vec_id_itr: Itr, - computer: &Self::QueryComputer, - mut f: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - F: Send + FnMut(f32, Self::Id), - { - self.on_elements_unordered(vec_id_itr, move |element, i| { - // Default is to use the computer to evaluate the similarity. - let distance = computer.evaluate_similarity(element); - f(distance, i); - }) - } + // /// Compute the distances for the elements in the iterator `itr` using the + // /// `computer` and apply the closure `f` to each distance and ID. The default + // /// implementation uses on_elements_unordered to iterate over the elements + // /// and compute the distances using `computer` parameter. 
+ // fn distances_unordered( + // &mut self, + // vec_id_itr: Itr, + // computer: &Self::QueryComputer, + // mut f: F, + // ) -> impl std::future::Future> + Send + // where + // Itr: Iterator + Send, + // F: Send + FnMut(f32, Self::Id), + // { + // self.on_elements_unordered(vec_id_itr, move |element, i| { + // // Default is to use the computer to evaluate the similarity. + // let distance = computer.evaluate_similarity(element); + // f(distance, i); + // }) + // } } ///////////////////////// @@ -1047,48 +1043,48 @@ mod tests { type Id = u32; } impl Accessor for FloatAccessor<'_> { - type Element<'a> - = f32 - where - Self: 'a; + // type Element<'a> + // = f32 + // where + // Self: 'a; type ElementRef<'a> = f32; - type GetError = Missing; - - fn get_element( - &mut self, - id: u32, - ) -> impl Future, Self::GetError>> + Send { - let guard = self.0.data.lock().unwrap(); - let v = match guard.get(&id) { - None => Err(Missing), - Some(v) => Ok(v.0), - }; - std::future::ready(v) - } - - // Implement `on_elements_unordered` by only acquiring the lock once. - // - // Real implementations will need to take care to avoid deadlocks. - async fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> Result<(), Self::GetError> - where - Self: Sync, - Itr: Iterator, - F: Send + FnMut(f32, u32), - { - let guard = self.0.data.lock().unwrap(); - for i in itr { - match guard.get(&i) { - None => return Err(Missing), - Some(v) => f(v.0, i), - } - } - Ok(()) - } + // type GetError = Missing; + + // fn get_element( + // &mut self, + // id: u32, + // ) -> impl Future, Self::GetError>> + Send { + // let guard = self.0.data.lock().unwrap(); + // let v = match guard.get(&id) { + // None => Err(Missing), + // Some(v) => Ok(v.0), + // }; + // std::future::ready(v) + // } + + // // Implement `on_elements_unordered` by only acquiring the lock once. + // // + // // Real implementations will need to take care to avoid deadlocks. 
+ // async fn on_elements_unordered( + // &mut self, + // itr: Itr, + // mut f: F, + // ) -> Result<(), Self::GetError> + // where + // Self: Sync, + // Itr: Iterator, + // F: Send + FnMut(f32, u32), + // { + // let guard = self.0.data.lock().unwrap(); + // for i in itr { + // match guard.get(&i) { + // None => return Err(Missing), + // Some(v) => f(v.0, i), + // } + // } + // Ok(()) + // } } // An accessor for the `String` portion of the data stored in the SimpleProvider. @@ -1113,100 +1109,100 @@ mod tests { type Id = u32; } impl Accessor for StringAccessor<'_> { - type Element<'a> - = &'a str - where - Self: 'a; + // type Element<'a> + // = &'a str + // where + // Self: 'a; type ElementRef<'a> = &'a str; - type GetError = Missing; - - fn get_element( - &mut self, - id: u32, - ) -> impl Future, Self::GetError>> + Send { - let guard = self.provider.data.lock().unwrap(); - let v = match guard.get(&id) { - None => Err(Missing), - Some(v) => { - self.buf.clone_from(&v.1); - Ok(&*self.buf) - } - }; - std::future::ready(v) - } + // type GetError = Missing; + + // fn get_element( + // &mut self, + // id: u32, + // ) -> impl Future, Self::GetError>> + Send { + // let guard = self.provider.data.lock().unwrap(); + // let v = match guard.get(&id) { + // None => Err(Missing), + // Some(v) => { + // self.buf.clone_from(&v.1); + // Ok(&*self.buf) + // } + // }; + // std::future::ready(v) + // } } - #[tokio::test] - async fn test_default_implementations() { - let provider = SimpleProvider::new(-1.0, "hello".to_string()); - { - let mut data = provider.data.lock().unwrap(); - data.insert(0, (0.0, "world".to_string())); - data.insert(1, (1.0, "foo".to_string())); - data.insert(2, (2.0, "bar".to_string())); - } - - // Float accessor - { - let mut accessor = FloatAccessor(&provider); - assert_eq!(accessor.get_element(0).await.unwrap(), 0.0); - assert_eq!(accessor.get_element(1).await.unwrap(), 1.0); - assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), -1.0); - - let mut v = 
Vec::new(); - accessor - .on_elements_unordered([2, 1, 0].into_iter(), |element, id| v.push((element, id))) - .await - .unwrap(); - - assert_eq!(&v, &[(2.0, 2), (1.0, 1), (0.0, 0)]); - - // Test error propagation. - // Trying to access element 3 will result in an error, which should be propagated - // up. - let err = accessor - .on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| { - v.push((element, id)) - }) - .await - .unwrap_err(); - assert_eq!(err, Missing); - } - - // String accessor - { - let mut accessor = StringAccessor::new(&provider); - assert_eq!(accessor.get_element(0).await.unwrap(), "world"); - assert_eq!(accessor.get_element(1).await.unwrap(), "foo"); - assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), "hello"); - - // This method tests the provided implementation of `on_elements_unordered`. - let expected = [("bar", 2), ("foo", 1), ("world", 0)]; - - let mut expected_iter = expected.into_iter(); - accessor - .on_elements_unordered([2, 1, 0].into_iter(), |element, id| { - assert_eq!((element, id), expected_iter.next().unwrap()); - }) - .await - .unwrap(); - assert!(expected_iter.next().is_none()); - - // Test error propagation. - // Trying to access element 3 will result in an error, which should be propagated - // up. 
- let mut expected_iter = expected.into_iter(); - let err = accessor - .on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| { - assert_eq!((element, id), expected_iter.next().unwrap()); - }) - .await - .unwrap_err(); - assert_eq!(err, Missing); - assert!(expected_iter.next().is_none()); - } - } + // #[tokio::test] + // async fn test_default_implementations() { + // let provider = SimpleProvider::new(-1.0, "hello".to_string()); + // { + // let mut data = provider.data.lock().unwrap(); + // data.insert(0, (0.0, "world".to_string())); + // data.insert(1, (1.0, "foo".to_string())); + // data.insert(2, (2.0, "bar".to_string())); + // } + + // // Float accessor + // { + // let mut accessor = FloatAccessor(&provider); + // assert_eq!(accessor.get_element(0).await.unwrap(), 0.0); + // assert_eq!(accessor.get_element(1).await.unwrap(), 1.0); + // assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), -1.0); + + // let mut v = Vec::new(); + // accessor + // .on_elements_unordered([2, 1, 0].into_iter(), |element, id| v.push((element, id))) + // .await + // .unwrap(); + + // assert_eq!(&v, &[(2.0, 2), (1.0, 1), (0.0, 0)]); + + // // Test error propagation. + // // Trying to access element 3 will result in an error, which should be propagated + // // up. + // let err = accessor + // .on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| { + // v.push((element, id)) + // }) + // .await + // .unwrap_err(); + // assert_eq!(err, Missing); + // } + + // // String accessor + // { + // let mut accessor = StringAccessor::new(&provider); + // assert_eq!(accessor.get_element(0).await.unwrap(), "world"); + // assert_eq!(accessor.get_element(1).await.unwrap(), "foo"); + // assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), "hello"); + + // // This method tests the provided implementation of `on_elements_unordered`. 
+ // let expected = [("bar", 2), ("foo", 1), ("world", 0)]; + + // let mut expected_iter = expected.into_iter(); + // accessor + // .on_elements_unordered([2, 1, 0].into_iter(), |element, id| { + // assert_eq!((element, id), expected_iter.next().unwrap()); + // }) + // .await + // .unwrap(); + // assert!(expected_iter.next().is_none()); + + // // Test error propagation. + // // Trying to access element 3 will result in an error, which should be propagated + // // up. + // let mut expected_iter = expected.into_iter(); + // let err = accessor + // .on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| { + // assert_eq!((element, id), expected_iter.next().unwrap()); + // }) + // .await + // .unwrap_err(); + // assert_eq!(err, Missing); + // assert!(expected_iter.next().is_none()); + // } + // } ///////////////////////////////// // Supported Accessor Patterns // @@ -1272,16 +1268,16 @@ mod tests { common_test_accessor!(Allocating<'_>); impl Accessor for Allocating<'_> { - type Element<'a> - = Box<[u8]> - where - Self: 'a; + // type Element<'a> + // = Box<[u8]> + // where + // Self: 'a; type ElementRef<'a> = &'a [u8]; - type GetError = Infallible; + // type GetError = Infallible; - async fn get_element(&mut self, _: u32) -> Result, Infallible> { - Ok(self.store.data.clone()) - } + // async fn get_element(&mut self, _: u32) -> Result, Infallible> { + // Ok(self.store.data.clone()) + // } } // An accessor that forwards - returning references directly into the underlying @@ -1299,18 +1295,7 @@ mod tests { common_test_accessor!(Forwarding<'_>); impl<'provider> Accessor for Forwarding<'provider> { - // NOTE: The lifetime of `Element` is `'provider` - not `'a`. This is what makes - // it a forwarding accessor. 
- type Element<'a> - = &'provider [u8] - where - Self: 'a; type ElementRef<'a> = &'a [u8]; - type GetError = Infallible; - - async fn get_element(&mut self, _: u32) -> Result<&'provider [u8], Infallible> { - Ok(&*self.store.data) - } } // An accessor that returns a non-reference type with a lifetime. @@ -1343,16 +1328,7 @@ mod tests { common_test_accessor!(Wrapping<'_>); impl Accessor for Wrapping<'_> { - type Element<'a> - = Wrapped<'a> - where - Self: 'a; type ElementRef<'a> = &'a [u8]; - type GetError = Infallible; - - async fn get_element(&mut self, _: u32) -> Result, Infallible> { - Ok(Wrapped(&self.store.data)) - } } // An accessor that shares local state. @@ -1374,56 +1350,6 @@ mod tests { common_test_accessor!(Sharing<'_>); impl Accessor for Sharing<'_> { - type Element<'a> - = &'a [u8] - where - Self: 'a; type ElementRef<'a> = &'a [u8]; - type GetError = Infallible; - - async fn get_element(&mut self, _: u32) -> Result<&[u8], Infallible> { - self.local.copy_from_slice(&self.store.data); - Ok(&self.local) - } - } - - #[tokio::test] - async fn test_accessor_patterns() { - let store = Store::new(); - - // A slice against which we compute distances. 
- let base: &[u8] = &[2, 3, 4, 5]; - - { - let mut accessor = Allocating::new(&store); - let computer = accessor.build_distance_computer().unwrap(); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0); - } - - { - let mut accessor = Forwarding::new(&store); - let computer = accessor.build_distance_computer().unwrap(); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0); - } - - { - let mut accessor = Wrapping::new(&store); - let computer = accessor.build_distance_computer().unwrap(); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0); - } - - { - let mut accessor = Sharing::new(&store); - let computer = accessor.build_distance_computer().unwrap(); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0); - } } } From b8e382b3f3752afcd21f1dc30196b22750ebae92 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 13 May 2026 18:04:11 -0700 Subject: [PATCH 02/17] Remove now unneeded code. 
--- .../src/search/provider/disk_provider.rs | 38 +- .../encoded_document_accessor.rs | 128 +++--- .../provider/async_/inmem/full_precision.rs | 74 --- .../graph/provider/async_/inmem/product.rs | 71 --- .../graph/provider/async_/inmem/spherical.rs | 5 +- .../model/graph/provider/layers/betafilter.rs | 11 +- diskann/src/graph/glue.rs | 2 +- diskann/src/graph/index.rs | 23 +- diskann/src/graph/search/multihop_search.rs | 2 +- diskann/src/graph/test/provider.rs | 24 +- diskann/src/graph/workingset/map.rs | 243 +--------- diskann/src/provider.rs | 424 +----------------- 12 files changed, 92 insertions(+), 953 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 64144014f..3e3cad44c 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -428,19 +428,6 @@ where .to_vec(), }) } - - // async fn distances_unordered( - // &mut self, - // vec_id_itr: Itr, - // _computer: &Self::QueryComputer, - // f: F, - // ) -> Result<(), Self::GetError> - // where - // F: Send + FnMut(f32, Self::Id), - // Itr: Iterator, - // { - // self.pq_distances(&vec_id_itr.collect::>(), f) - // } } impl ExpandBeam<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> @@ -604,10 +591,13 @@ where mut f: F, ) -> ANNResult<()> where - F: FnMut(Self::Id, f32) + Send + F: FnMut(Self::Id, f32) + Send, { let start_vertex_id = self.provider.graph_header.metadata().medoid as u32; - let vector = self.provider.pq_data.get_compressed_vector(start_vertex_id.into_usize())?; + let vector = self + .provider + .pq_data + .get_compressed_vector(start_vertex_id.into_usize())?; let distance = computer.evaluate_similarity(vector); f(start_vertex_id, distance); Ok(()) @@ -709,25 +699,7 @@ where Data: GraphDataType, VP: VertexProvider, { - // /// This accessor returns raw slices. There *is* a chance of racing when the fast - // /// providers are used. 
We just have to live with it. - // type Element<'a> - // = &'a [u8] - // where - // Self: 'a; - - /// `ElementRef` can have arbitrary lifetimes. type ElementRef<'a> = &'a [u8]; - - // /// Choose to panic on an out-of-bounds access rather than propagate an error. - // type GetError = ANNError; - - // fn get_element( - // &mut self, - // id: Self::Id, - // ) -> impl Future, Self::GetError>> + Send { - // std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize)) - // } } impl IdIterator> for DiskAccessor<'_, Data, VP> diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 6f269ab42..0d16248dd 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -72,71 +72,71 @@ impl Accessor for EncodedDocumentAccessor where IA: Accessor, { - // type Element<'a> - // = EncodedDocument, RoaringTreemap> - // where - // Self: 'a; + type Element<'a> + = EncodedDocument, RoaringTreemap> + where + Self: 'a; type ElementRef<'a> = EncodedDocument, &'a RoaringTreemap>; - // type GetError = ANNError; - - // async fn get_element(&mut self, id: Self::Id) -> Result, Self::GetError> { - // let future = self.inner_accessor.get_element(id); - // let elem = future.await.escalate("Did not find the vector element")?; - - // let attrs = self - // .attribute_accessor - // .visit_labels_of_point(id, |_, opt_set| { - // match opt_set { - // //TODO: Currently, there is no way but to copy. So we copy the set from the Cow into a - // //hydrated object. - // //IMP NOTE: Removing the copy will also change the signature of "Element" and may cause other - // //downstream issues, so should be done with care! 
- // Some(set) => Ok(set.into_owned()), - // None => Err(ANNError::message( - // ANNErrorKind::IndexError, - // "No labels were found for vector", - // )), - // } - // })?; - - // Ok(EncodedDocument::new(elem, attrs?)) - // } - - // async fn on_elements_unordered( - // &mut self, - // itr: Itr, - // mut f: F, - // ) -> Result<(), Self::GetError> - // where - // Self: Sync, - // Itr: Iterator + Send, - // F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), - // { - // for i in itr { - // let vec = self - // .inner_accessor - // .get_element(i) - // .await - // .escalate("Failed to get vector from inner accessor")?; - // let _ = self - // .attribute_accessor - // .visit_labels_of_point(i, |_, opt_set| { - // let set = match opt_set { - // Some(set) => set, - // None => { - // return Err(ANNError::message( - // ANNErrorKind::IndexError, - // format!("No attributes found for point.{}", i), - // )); - // } - // }; - // let elem = EncodedDocument::new(vec.reborrow(), &*set); - // f(elem, i); - // Ok(()) - // }); - // } - // Ok(()) - // } + type GetError = ANNError; + + async fn get_element(&mut self, id: Self::Id) -> Result, Self::GetError> { + let future = self.inner_accessor.get_element(id); + let elem = future.await.escalate("Did not find the vector element")?; + + let attrs = self + .attribute_accessor + .visit_labels_of_point(id, |_, opt_set| { + match opt_set { + //TODO: Currently, there is no way but to copy. So we copy the set from the Cow into a + //hydrated object. + //IMP NOTE: Removing the copy will also change the signature of "Element" and may cause other + //downstream issues, so should be done with care! 
+ Some(set) => Ok(set.into_owned()), + None => Err(ANNError::message( + ANNErrorKind::IndexError, + "No labels were found for vector", + )), + } + })?; + + Ok(EncodedDocument::new(elem, attrs?)) + } + + async fn on_elements_unordered( + &mut self, + itr: Itr, + mut f: F, + ) -> Result<(), Self::GetError> + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), + { + for i in itr { + let vec = self + .inner_accessor + .get_element(i) + .await + .escalate("Failed to get vector from inner accessor")?; + let _ = self + .attribute_accessor + .visit_labels_of_point(i, |_, opt_set| { + let set = match opt_set { + Some(set) => set, + None => { + return Err(ANNError::message( + ANNErrorKind::IndexError, + format!("No attributes found for point.{}", i), + )); + } + }; + let elem = EncodedDocument::new(vec.reborrow(), &*set); + f(elem, i); + Ok(()) + }); + } + Ok(()) + } } impl SearchExt for EncodedDocumentAccessor diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 2b408c046..09fb81053 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -214,81 +214,7 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { - // /// This accessor returns raw slices. There *is* a chance of racing when the fast - // /// providers are used. We just have to live with it. - // type Element<'a> - // = &'a [T] - // where - // Self: 'a; - - /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = &'a [T]; - - // /// Choose to panic on an out-of-bounds access rather than propagate an error. - // type GetError = Panics; - - // /// Return the full-precision vector stored at index `i`. - // /// - // /// This function always completes synchronously. 
- // #[inline(always)] - // fn get_element( - // &mut self, - // id: Self::Id, - // ) -> impl Future, Self::GetError>> + Send { - // // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // // potentially mixing unsynchronized reads and writes on the underlying memory. - // std::future::ready(Ok(unsafe { - // self.provider.base_vectors.get_vector_sync(id.into_usize()) - // })) - // } - - // /// Perform a bulk operation. - // /// - // /// This implementation uses prefetching. - // fn on_elements_unordered( - // &mut self, - // itr: Itr, - // mut f: F, - // ) -> impl Future> + Send - // where - // Self: Sync, - // Itr: Iterator + Send, - // F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - // { - // // Reuse the internal buffer to collect the results and give us random access - // // capabilities. - // let id_buffer = &mut self.id_buffer; - // id_buffer.clear(); - // id_buffer.extend(itr); - - // let len = id_buffer.len(); - // let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // // Prefetch the first few vectors. - // for id in id_buffer.iter().take(lookahead) { - // self.provider.base_vectors.prefetch_hint(id.into_usize()); - // } - - // for (i, id) in id_buffer.iter().enumerate() { - // // Prefetch `lookahead` iterations ahead as long as it is safe. - // if lookahead > 0 && i + lookahead < len { - // self.provider - // .base_vectors - // .prefetch_hint(id_buffer[i + lookahead].into_usize()); - // } - - // // Invoke the passed closure on the full-precision vector. - // // - // // SAFETY: We're accepting the consequences of potential unsynchronized, - // // concurrent mutation. 
- // f( - // unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - // *id, - // ) - // } - - // std::future::ready(Ok(())) - // } } impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index cd895aa91..f0b3f18f4 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -173,54 +173,8 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { - // /// This accessor returns raw slices. There *is* a chance of racing when the fast - // /// providers are used. We just have to live with it. - // type Element<'a> - // = &'a [u8] - // where - // Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = &'a [u8]; - - // /// Choose to panic on an out-of-bounds access rather than propagate an error. - // type GetError = Panics; - - // /// Return the quantized vector stored at index `i`. - // /// - // /// This function always completes synchronously. - // fn get_element( - // &mut self, - // id: Self::Id, - // ) -> impl Future, Self::GetError>> + Send { - // // SAFETY: We've decided to live with UB that can result from potentially mixing - // // unsynchronized reads and writes on the underlying memory. - // std::future::ready(Ok(unsafe { - // self.provider.aux_vectors.get_vector_sync(id.into_usize()) - // })) - // } - - // /// Perform a bulk operation. - // fn on_elements_unordered( - // &mut self, - // itr: Itr, - // mut f: F, - // ) -> impl Future> + Send - // where - // Self: Sync, - // Itr: Iterator + Send, - // F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - // { - // for i in itr { - // // SAFETY: We're accepting the consequences of potential unsynchronized, - // // concurrent mutation. 
- // f( - // unsafe { self.provider.aux_vectors.get_vector_sync(i.into_usize()) }, - // i, - // ) - // } - // std::future::ready(Ok(())) - // } } impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> @@ -379,33 +333,8 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { - // /// The [`distances::pq::Hybrid`] is an enum consisting of either a full-precision - // /// vector or a quantized vector. - // /// - // /// This accessor can return either. - // type Element<'a> - // = distances::pq::Hybrid<&'a [T], &'a [u8]> - // where - // Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; - - // /// Choose to panic on an out-of-bounds access rather than propagate an error. - // type GetError = Panics; - - // /// The default behavior of `get_element` returns a full-precision vector. The - // /// implementation of [`Fill`] is how the `max_fp_vecs_per_fill` is used. - // fn get_element( - // &mut self, - // id: Self::Id, - // ) -> impl Future, Self::GetError>> + Send { - // // SAFETY: We've decided to live with UB that can result from potentially mixing - // // unsynchronized reads and writes on the underlying memory. - // std::future::ready(Ok(unsafe { - // distances::pq::Hybrid::Full(self.provider.base_vectors.get_vector_sync(id.into_usize())) - // })) - // } } impl BuildDistanceComputer for HybridAccessor<'_, T, D, Ctx> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index c6771bcbf..4dd1189de 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -352,9 +352,8 @@ where for i in self.provider.starting_points()? { // SAFETY: We're accepting the consequences of potential unsynchronized, // concurrent mutation. 
- let distance = computer.evaluate_similarity( - self.provider.aux_vectors.get_vector(i.into_usize())? - ); + let distance = computer + .evaluate_similarity(self.provider.aux_vectors.get_vector(i.into_usize())?); f(i, distance); } diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 7ebf6b9a1..f65ad8087 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -187,13 +187,12 @@ where mut f: F, ) -> impl std::future::Future> + Send where - F: FnMut(Self::Id, f32) + Send { - self.inner.start_point_distances( - computer.inner(), - move |id, distance| { + F: FnMut(Self::Id, f32) + Send, + { + self.inner + .start_point_distances(computer.inner(), move |id, distance| { f(id, computer.apply(id, distance)); - } - ) + }) } } diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 0bc4e7c1c..e1956e7b5 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -83,7 +83,7 @@ use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction}; use crate::{ ANNError, ANNResult, - error::{StandardError}, + error::StandardError, graph::{SearchOutputBuffer, workingset}, neighbor::Neighbor, provider::{ diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 596139394..63c9933a0 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -16,7 +16,7 @@ use diskann_utils::{ Reborrow, future::{AssertSend, SendFuture, boxit}, }; -use diskann_vector::{DistanceFunction}; +use diskann_vector::DistanceFunction; use futures_util::FutureExt; use hashbrown::HashSet; use thiserror::Error; @@ -25,8 +25,8 @@ use tokio::task::JoinSet; use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, Search, glue::{ - self, Batch, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - MultiInsertStrategy, PruneStrategy, SearchExt, 
SearchPostProcess, SearchStrategy, + self, Batch, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ @@ -2027,6 +2027,8 @@ where async move { let beam_width = beam_width.unwrap_or(1); + // paged search can call search_internal multiple times, we only need to initialize + // state if not already initialized. if scratch.visited.is_empty() { accessor .start_point_distances(computer, |id, dist| { @@ -2037,21 +2039,6 @@ where .await?; } - // // paged search can call search_internal multiple times, we only need to initialize - // // state if not already initialized. - // if scratch.visited.is_empty() { - // for id in start_ids { - // scratch.visited.insert(*id); - // let element = accessor - // .get_element(*id) - // .await - // .escalate("start point retrieval must succeed")?; - // let dist = computer.evaluate_similarity(element.reborrow()); - // scratch.best.insert(Neighbor::new(*id, dist)); - // scratch.cmps += 1; - // } - // } - let mut neighbors = Vec::with_capacity(self.max_degree_with_slack()); while scratch.best.has_notvisited_node() && !accessor.terminate_early() { scratch.beam_nodes.clear(); diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 3eb6ef923..04a0fe4ed 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -11,7 +11,7 @@ use hashbrown::HashSet; use super::{Knn, Search, record::SearchRecord, scratch::SearchScratch}; use crate::{ ANNResult, - error::{IntoANNResult}, + error::IntoANNResult, graph::{ glue::{ self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index b1c548cda..50badec86 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1058,29 +1058,7 @@ 
impl provider::HasId for Accessor<'_> { } impl provider::Accessor for Accessor<'_> { - // type Element<'a> - // = &'a [f32] - // where - // Self: 'a; type ElementRef<'a> = &'a [f32]; - // type GetError = AccessError; - - // async fn get_element(&mut self, id: u32) -> Result<&[f32], AccessError> { - // match self.provider.terms.get(&id) { - // Some(term) => { - // if let Some(transient) = &self.transient_ids - // && transient.contains(&id) - // { - // return Err(AccessError::Transient(TransientAccessError::new(id))); - // } - - // self.get_vector.increment(); - // self.buffer.copy_from_slice(&term.data); - // Ok(&*self.buffer) - // } - // None => Err(AccessError::InvalidId(AccessedInvalidId(id))), - // } - // } } impl<'a> provider::DelegateNeighbor<'a> for Accessor<'_> { @@ -1742,7 +1720,7 @@ mod tests { #[test] fn test_set_element() { - use provider::{Accessor, Guard, SetElement}; + use provider::{Guard, SetElement}; let provider = create_test_provider(); let rt = current_thread_runtime(); diff --git a/diskann/src/graph/workingset/map.rs b/diskann/src/graph/workingset/map.rs index 7cbf8f848..8fffeb910 100644 --- a/diskann/src/graph/workingset/map.rs +++ b/diskann/src/graph/workingset/map.rs @@ -129,14 +129,12 @@ use std::{fmt::Debug, hash::Hash, sync::Arc}; -use diskann_utils::{Reborrow}; +use diskann_utils::Reborrow; use hashbrown::hash_map; -use crate::{ - graph::glue, -}; +use crate::graph::glue; -use super::{AsWorkingSet}; +use super::AsWorkingSet; ///////// // Map // @@ -458,37 +456,6 @@ impl VacantEntry<'_, K, V> { } } -// /// Blanket implementation of [`Fill`] for [`Map`]-backed working sets. -// /// -// /// This covers the common case where the accessor's `Extended` type is stored directly -// /// in the map. Accessors that need custom fill logic (e.g. hybrid full-precision/quantized) -// /// should use a different `State` type and provide their own `Fill` impl. 
-// impl Fill> for A -// where -// P: Projection, -// A: for<'a> Accessor = P::ElementRef<'a>>, -// V: Project

+ Send + Sync + 'static, -// for<'a> A::ElementRef<'a>: Into, -// { -// type Error = ::Error; -// type View<'a> -// = View<'a, A::Id, V, P> -// where -// Self: 'a; -// -// fn fill<'a, Itr>( -// &'a mut self, -// map: &'a mut Map, -// itr: Itr, -// ) -> impl SendFuture, Self::Error>> -// where -// Itr: ExactSizeIterator + Clone + Send + Sync, -// Self: 'a, -// { -// map.fill(self, itr) -// } -// } - ///////////////// // Projections // ///////////////// @@ -912,13 +879,11 @@ where mod tests { use super::*; - use std::{borrow::Cow, sync::Arc}; + use std::sync::Arc; use diskann_utils::views::Matrix; - use crate::graph::{ - test::provider::Accessor as TestAccessor, workingset::View as WorkingSetView, - }; + use crate::graph::workingset::View as WorkingSetView; /// Convenience alias matching the test provider's working set type. type TestMap = Map, Ref<[f32]>>; @@ -1864,202 +1829,4 @@ mod tests { let ar = AsReborrowed(&*value); let _ = ar.clone(); } - - // //----------------------------------// - // // Fill / fill_with (async, Tier 1) // - // //----------------------------------// - - // /// Create a grid-backed provider (1-D, size 5). - // /// - // /// IDs 0–4 have vectors `[0.0]` .. `[4.0]`, start point is `u32::MAX`. 
- // fn fill_provider() -> crate::graph::test::provider::Provider { - // use crate::graph::test::synthetic::Grid; - // crate::graph::test::provider::Provider::grid(Grid::One, 5).unwrap() - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn fill_happy_path() { - // let provider = fill_provider(); - // let mut accessor = TestAccessor::new(&provider); - // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // let current = map.generation; - // let view = map - // .fill(&mut accessor, [0u32, 1, 2].into_iter()) - // .await - // .unwrap(); - - // assert_eq!(view.get(0).unwrap(), &[0.0]); - // assert_eq!(view.get(1).unwrap(), &[1.0]); - // assert_eq!(view.get(2).unwrap(), &[2.0]); - // assert!(view.get(99).is_none()); - - // assert_eq!( - // map.generation, - // current + 1, - // "`fill` should bump the generation" - // ); - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn fill_clears_previous_entries() { - // let provider = fill_provider(); - // let mut accessor = TestAccessor::new(&provider); - // let mut map: Map, Ref<[f32]>> = Builder::new(Capacity::None).build(0); - - // // First fill with IDs 0 and 1. - // let view = map - // .fill(&mut accessor, [0u32, 1].into_iter()) - // .await - // .unwrap(); - - // assert!(view.get(0).is_some()); - // assert!(view.get(1).is_some()); - // assert!(view.get(2).is_none()); - - // // Second fill with only ID 2 — previous entries should be cleared. 
- // let view = map.fill(&mut accessor, [2u32].into_iter()).await.unwrap(); - // assert!(view.get(0).is_none(), "fill should have cleared id 0"); - // assert!(view.get(1).is_none(), "fill should have cleared id 1"); - // assert_eq!(view.get(2).unwrap(), &[2.0]); - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn fill_with_preserves_entries() { - // let provider = fill_provider(); - // let mut accessor = TestAccessor::new(&provider); - // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // // Populate with ID 0. - // let view = map - // .fill_with(&mut accessor, [0u32].into_iter(), |e| e.into()) - // .await - // .unwrap(); - // assert!(view.get(0).is_some()); - - // // fill_with with ID 1 — should NOT clear ID 0. - // let view = map - // .fill_with(&mut accessor, [1u32].into_iter(), |e| e.into()) - // .await - // .unwrap(); - // assert_eq!( - // view.get(0).unwrap(), - // &[0.0], - // "fill_with should preserve id 0" - // ); - // assert_eq!(view.get(1).unwrap(), &[1.0]); - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn fill_skips_transient_errors() { - // let provider = fill_provider(); - - // let mut accessor = - // TestAccessor::flaky(&provider, Cow::Owned(std::collections::HashSet::from([1]))); - // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // // ID 1 is transient — should be skipped, not propagated. 
- // let view = map - // .fill(&mut accessor, [0u32, 1, 2].into_iter()) - // .await - // .unwrap(); - // assert_eq!(view.get(0).unwrap(), &[0.0]); - // assert!(view.get(1).is_none(), "transient ID should be absent"); - // assert_eq!(view.get(2).unwrap(), &[2.0]); - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn fill_propagates_critical_errors() { - // let provider = fill_provider(); - // let mut accessor = TestAccessor::new(&provider); - // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // // ID 99 doesn't exist — critical InvalidId error. - // let err = map - // .fill(&mut accessor, [0u32, 99].into_iter()) - // .await - // .unwrap_err(); - // let msg = err.to_string(); - // assert!( - // msg.contains("99"), - // "error should mention the invalid id: {msg}" - // ); - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn fill_with_skips_occupied_entries() { - // let provider = fill_provider(); - // let mut accessor = TestAccessor::new(&provider); - // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // // Pre-insert a sentinel value for ID 0. - // map.insert(0, Box::new([99.0])); - - // // fill_with should skip the occupied entry. - // let view = map - // .fill_with(&mut accessor, [0u32, 1].into_iter(), |e| e.into()) - // .await - // .unwrap(); - - // // ID 0 retains its pre-inserted sentinel. - // assert_eq!( - // view.get(0).unwrap(), - // &[99.0], - // "occupied entry should be preserved" - // ); - // assert_eq!(view.get(1).unwrap(), &[1.0]); - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn fill_with_skips_seeded_entries() { - // let provider = fill_provider(); - // let mut accessor = TestAccessor::new(&provider); - - // // Seed with a batch containing a different value for ID 0. 
- // let batch = Arc::new(Matrix::try_from(Box::new([99.0, 88.0]), 2, 1).unwrap()); - // let overlay = Overlay::>::from_batch(batch, [0u32, 1].into_iter()); - // let mut map = seeded_map(overlay, Capacity::Unbounded); - - // // fill_with requests IDs 0 and 2. ID 0 is seeded → skip, ID 2 is filled. - // let view = map - // .fill_with(&mut accessor, [0u32, 2].into_iter(), |e| e.into()) - // .await - // .unwrap(); - - // // ID 0 comes from the seed (batch row 0 = [99.0]), NOT the accessor. - // assert_eq!(view.get(0).unwrap(), &[99.0]); - // // ID 2 was filled from the accessor. - // assert_eq!(view.get(2).unwrap(), &[2.0]); - // // Verify ID 0 is NOT in the fill layer. - // assert!( - // map.get(&0).is_none(), - // "ID 0 should only be in the seed, not the fill layer" - // ); - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn fill_empty_iterator() { - // let provider = fill_provider(); - // let mut accessor = TestAccessor::new(&provider); - // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // let view = map.fill(&mut accessor, std::iter::empty()).await.unwrap(); - // assert!(view.get(0).is_none()); - // } - - // #[tokio::test(flavor = "current_thread")] - // async fn blanket_fill_trait() { - // let provider = fill_provider(); - // let mut accessor = TestAccessor::new(&provider); - // let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // // Exercise the blanket Fill impl. 
- // let view = <_ as Fill>::fill(&mut accessor, &mut map, [0u32, 1, 2].into_iter()) - // .await - // .unwrap(); - - // assert_eq!(view.get(0).unwrap(), &[0.0]); - // assert_eq!(view.get(1).unwrap(), &[1.0]); - // assert_eq!(view.get(2).unwrap(), &[2.0]); - // } } diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index 968671373..f6f5dd2bb 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -97,7 +97,7 @@ use std::ops::Deref; -use diskann_vector::{DistanceFunction}; +use diskann_vector::DistanceFunction; use sealed::{BoundTo, Sealed}; use crate::{ANNError, ANNResult, error::ToRanked, graph::AdjacencyList, utils::VectorId}; @@ -425,52 +425,6 @@ pub trait Accessor: HasId + Send + Sync { /// [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) will not induce a `'static` /// requirement on `Self`. type ElementRef<'a>; - - // /// The concrete type of the data element associated with this accessor. - // /// - // /// For distance computations, this should be cheaply convertible via [`Reborrow`] to - // /// `Self::ElementRef`. - // type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync - // where - // Self: 'a; - - // /// The error (if any) returned by [`Self::get_element`]. - // type GetError: ToRanked + std::fmt::Debug + Send + Sync + 'static; - - // /// Return the value associated with the key `id`. - // /// - // /// It is expected that index algorithms will only invoke `get_element` on valid IDs, - // /// that can be derived from [`SetElement::set_element`] or by some other means. - // /// - // /// Implementations are suggested to return an error if this invariant is broken, but - // /// may also panic if that is an acceptable error mode. - // fn get_element( - // &mut self, - // id: Self::Id, - // ) -> impl std::future::Future, Self::GetError>> + Send; - - // /// A bulk interface for invoking [`Self::get_element`] on each item in an iterator and - // /// invoking the closure with the reborrowed element. 
- // /// - // /// Algorithms are encouraged to use this interface if appropriate as accessor - // /// implementations may specialize the implementation for better performance. - // fn on_elements_unordered( - // &mut self, - // itr: Itr, - // mut f: F, - // ) -> impl std::future::Future> + Send - // where - // Self: Sync, - // Itr: Iterator + Send, - // F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), - // { - // async move { - // for i in itr { - // f(self.get_element(i).await?.reborrow(), i); - // } - // Ok(()) - // } - // } } /// A specialized [`Accessor`] that provides random-access distance computations. @@ -515,27 +469,6 @@ pub trait BuildQueryComputer: Accessor { &self, from: T, ) -> Result; - - // /// Compute the distances for the elements in the iterator `itr` using the - // /// `computer` and apply the closure `f` to each distance and ID. The default - // /// implementation uses on_elements_unordered to iterate over the elements - // /// and compute the distances using `computer` parameter. - // fn distances_unordered( - // &mut self, - // vec_id_itr: Itr, - // computer: &Self::QueryComputer, - // mut f: F, - // ) -> impl std::future::Future> + Send - // where - // Itr: Iterator + Send, - // F: Send + FnMut(f32, Self::Id), - // { - // self.on_elements_unordered(vec_id_itr, move |element, i| { - // // Default is to use the computer to evaluate the similarity. 
- // let distance = computer.evaluate_similarity(element); - // f(distance, i); - // }) - // } } ///////////////////////// @@ -770,11 +703,10 @@ mod sealed { #[cfg(test)] mod tests { use std::{ - collections::HashMap, future::Future, pin::Pin, sync::{ - Arc, Mutex, + Arc, atomic::{AtomicUsize, Ordering}, }, task, @@ -783,7 +715,7 @@ mod tests { use pin_project::{pin_project, pinned_drop}; use super::*; - use crate::{always_escalate, error::Infallible}; + use crate::always_escalate; //////////////////// // DefaultContext // @@ -984,40 +916,6 @@ mod tests { assert!(deleted.is_deleted()); } - /// A simple data provider that contains values consisting of floats and strings. - /// - /// The start point for this provider is as `u32::MAX`. - struct SimpleProvider { - data: Mutex>, - } - - impl SimpleProvider { - fn new(v: f32, st: String) -> Self { - let mut data = HashMap::new(); - data.insert(u32::MAX, (v, st)); - Self { - data: Mutex::new(data), - } - } - } - - impl DataProvider for SimpleProvider { - type Context = DefaultContext; - // Use the identity mapping for IDs. - type InternalId = u32; - type ExternalId = u32; - type Error = ANNError; - type Guard = NoopGuard; - - fn to_internal_id(&self, _context: &DefaultContext, gid: &u32) -> Result { - Ok(*gid) - } - - fn to_external_id(&self, _context: &DefaultContext, id: u32) -> Result { - Ok(id) - } - } - #[derive(Debug, Clone, Copy, PartialEq)] pub struct Missing; @@ -1036,320 +934,4 @@ mod tests { } always_escalate!(Missing); - - // An accessor for the `f32` portion of the data stored in the SimpleProvider. 
- struct FloatAccessor<'a>(&'a SimpleProvider); - impl HasId for FloatAccessor<'_> { - type Id = u32; - } - impl Accessor for FloatAccessor<'_> { - // type Element<'a> - // = f32 - // where - // Self: 'a; - type ElementRef<'a> = f32; - - // type GetError = Missing; - - // fn get_element( - // &mut self, - // id: u32, - // ) -> impl Future, Self::GetError>> + Send { - // let guard = self.0.data.lock().unwrap(); - // let v = match guard.get(&id) { - // None => Err(Missing), - // Some(v) => Ok(v.0), - // }; - // std::future::ready(v) - // } - - // // Implement `on_elements_unordered` by only acquiring the lock once. - // // - // // Real implementations will need to take care to avoid deadlocks. - // async fn on_elements_unordered( - // &mut self, - // itr: Itr, - // mut f: F, - // ) -> Result<(), Self::GetError> - // where - // Self: Sync, - // Itr: Iterator, - // F: Send + FnMut(f32, u32), - // { - // let guard = self.0.data.lock().unwrap(); - // for i in itr { - // match guard.get(&i) { - // None => return Err(Missing), - // Some(v) => f(v.0, i), - // } - // } - // Ok(()) - // } - } - - // An accessor for the `String` portion of the data stored in the SimpleProvider. - // - // We keep a local buffer `buf` into which the contents of the string are copied. - // This allows us to elide allocating on `get_element` calls. 
- struct StringAccessor<'a> { - provider: &'a SimpleProvider, - buf: String, - } - - impl<'a> StringAccessor<'a> { - fn new(provider: &'a SimpleProvider) -> Self { - Self { - provider, - buf: String::new(), - } - } - } - - impl HasId for StringAccessor<'_> { - type Id = u32; - } - impl Accessor for StringAccessor<'_> { - // type Element<'a> - // = &'a str - // where - // Self: 'a; - type ElementRef<'a> = &'a str; - - // type GetError = Missing; - - // fn get_element( - // &mut self, - // id: u32, - // ) -> impl Future, Self::GetError>> + Send { - // let guard = self.provider.data.lock().unwrap(); - // let v = match guard.get(&id) { - // None => Err(Missing), - // Some(v) => { - // self.buf.clone_from(&v.1); - // Ok(&*self.buf) - // } - // }; - // std::future::ready(v) - // } - } - - // #[tokio::test] - // async fn test_default_implementations() { - // let provider = SimpleProvider::new(-1.0, "hello".to_string()); - // { - // let mut data = provider.data.lock().unwrap(); - // data.insert(0, (0.0, "world".to_string())); - // data.insert(1, (1.0, "foo".to_string())); - // data.insert(2, (2.0, "bar".to_string())); - // } - - // // Float accessor - // { - // let mut accessor = FloatAccessor(&provider); - // assert_eq!(accessor.get_element(0).await.unwrap(), 0.0); - // assert_eq!(accessor.get_element(1).await.unwrap(), 1.0); - // assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), -1.0); - - // let mut v = Vec::new(); - // accessor - // .on_elements_unordered([2, 1, 0].into_iter(), |element, id| v.push((element, id))) - // .await - // .unwrap(); - - // assert_eq!(&v, &[(2.0, 2), (1.0, 1), (0.0, 0)]); - - // // Test error propagation. - // // Trying to access element 3 will result in an error, which should be propagated - // // up. 
- // let err = accessor - // .on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| { - // v.push((element, id)) - // }) - // .await - // .unwrap_err(); - // assert_eq!(err, Missing); - // } - - // // String accessor - // { - // let mut accessor = StringAccessor::new(&provider); - // assert_eq!(accessor.get_element(0).await.unwrap(), "world"); - // assert_eq!(accessor.get_element(1).await.unwrap(), "foo"); - // assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), "hello"); - - // // This method tests the provided implementation of `on_elements_unordered`. - // let expected = [("bar", 2), ("foo", 1), ("world", 0)]; - - // let mut expected_iter = expected.into_iter(); - // accessor - // .on_elements_unordered([2, 1, 0].into_iter(), |element, id| { - // assert_eq!((element, id), expected_iter.next().unwrap()); - // }) - // .await - // .unwrap(); - // assert!(expected_iter.next().is_none()); - - // // Test error propagation. - // // Trying to access element 3 will result in an error, which should be propagated - // // up. - // let mut expected_iter = expected.into_iter(); - // let err = accessor - // .on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| { - // assert_eq!((element, id), expected_iter.next().unwrap()); - // }) - // .await - // .unwrap_err(); - // assert_eq!(err, Missing); - // assert!(expected_iter.next().is_none()); - // } - // } - - ///////////////////////////////// - // Supported Accessor Patterns // - ///////////////////////////////// - - // This suite of tests ensure that patterns we want out of the `Accessor` associated - // trait hierarchy are all supported. - // - // These include: - // - // * Accessors that always allocate. - // * Accessors that simply reference the underlying store directly. - // * Accessors that use a local buffer. 
- - #[derive(Debug)] - struct Store { - data: Box<[u8]>, - } - - impl Store { - fn new() -> Self { - Self { - data: Box::from([1, 2, 3, 4]), - } - } - - fn dim(&self) -> usize { - self.data.len() - } - } - - macro_rules! common_test_accessor { - ($T:ty) => { - impl HasId for $T { - type Id = u32; - } - - impl BuildDistanceComputer for $T { - type DistanceComputerError = Infallible; - type DistanceComputer = ::Distance; - - fn build_distance_computer(&self) -> Result { - Ok(::distance( - diskann_vector::distance::Metric::L2, - None, - )) - } - } - }; - } - - // An accessor that always allocates. - struct Allocating<'a> { - store: &'a Store, - } - - impl<'a> Allocating<'a> { - fn new(store: &'a Store) -> Self { - Self { store } - } - } - - common_test_accessor!(Allocating<'_>); - - impl Accessor for Allocating<'_> { - // type Element<'a> - // = Box<[u8]> - // where - // Self: 'a; - type ElementRef<'a> = &'a [u8]; - // type GetError = Infallible; - - // async fn get_element(&mut self, _: u32) -> Result, Infallible> { - // Ok(self.store.data.clone()) - // } - } - - // An accessor that forwards - returning references directly into the underlying - // store without reallocation or copying. - struct Forwarding<'a> { - store: &'a Store, - } - - impl<'a> Forwarding<'a> { - fn new(store: &'a Store) -> Self { - Self { store } - } - } - - common_test_accessor!(Forwarding<'_>); - - impl<'provider> Accessor for Forwarding<'provider> { - type ElementRef<'a> = &'a [u8]; - } - - // An accessor that returns a non-reference type with a lifetime. 
- struct Wrapping<'a> { - store: &'a Store, - } - - impl<'a> Wrapping<'a> { - fn new(store: &'a Store) -> Self { - Self { store } - } - } - - #[derive(Debug)] - struct Wrapped<'a>(&'a [u8]); - - impl<'a> Reborrow<'a> for Wrapped<'_> { - type Target = &'a [u8]; - fn reborrow(&'a self) -> Self::Target { - self.0 - } - } - - impl From> for Box<[u8]> { - fn from(wrapped: Wrapped<'_>) -> Self { - wrapped.0.into() - } - } - - common_test_accessor!(Wrapping<'_>); - - impl Accessor for Wrapping<'_> { - type ElementRef<'a> = &'a [u8]; - } - - // An accessor that shares local state. - #[derive(Debug)] - struct Sharing<'a> { - store: &'a Store, - local: Box<[u8]>, - } - - impl<'a> Sharing<'a> { - fn new(store: &'a Store) -> Self { - Self { - store, - local: (0..store.dim()).map(|_| 0).collect(), - } - } - } - - common_test_accessor!(Sharing<'_>); - - impl Accessor for Sharing<'_> { - type ElementRef<'a> = &'a [u8]; - } } From 0a241a2dd61b54e07261e7b09a212fa65aba730c Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 14 May 2026 15:11:10 -0700 Subject: [PATCH 03/17] Checkpoint. 
--- .../src/search/provider/disk_provider.rs | 88 ++++++------- .../provider/async_/inmem/full_precision.rs | 123 ++++++++---------- .../graph/provider/async_/inmem/product.rs | 80 +++++------- .../graph/provider/async_/inmem/scalar.rs | 122 ++++++++--------- .../graph/provider/async_/inmem/spherical.rs | 115 ++++++++-------- .../model/graph/provider/layers/betafilter.rs | 74 +++++------ diskann/src/graph/glue.rs | 57 ++++---- diskann/src/graph/index.rs | 28 ++-- diskann/src/graph/search/multihop_search.rs | 6 +- diskann/src/graph/search/range_search.rs | 4 +- diskann/src/graph/test/provider.rs | 30 ++++- diskann/src/provider.rs | 2 +- 12 files changed, 352 insertions(+), 377 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 3e3cad44c..ee8236a99 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -21,8 +21,7 @@ use diskann::{ graph::{ self, glue::{ - self, DefaultPostProcessor, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, - SearchStrategy, + self, DefaultPostProcessor, IdIterator, SearchExt, SearchPostProcess, SearchStrategy, }, search::Knn, search_output_buffer, AdjacencyList, DiskANNIndex, @@ -348,7 +347,8 @@ where ProviderFactory: VertexProviderFactory, { type QueryComputer = DiskQueryComputer; - type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>; + type SearchAccessor<'a> + = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>; type SearchAccessorError = ANNError; fn search_accessor<'a>( @@ -430,50 +430,6 @@ where } } -impl ExpandBeam<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - fn expand_beam( - &mut self, - ids: Itr, - _computer: &Self::QueryComputer, - mut pred: P, - mut f: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: glue::HybridPredicate + Send + Sync, 
- F: FnMut(f32, Self::Id) + Send, - { - let result = (|| { - let io_limit = self.provider.search_io_limit - self.io_tracker.io_count(); - let load_ids: Box<[_]> = ids.take(io_limit).collect(); - - self.ensure_loaded(&load_ids)?; - let mut ids = Vec::new(); - for i in load_ids { - ids.clear(); - ids.extend( - self.scratch - .vertex_provider - .get_adjacency_list(&i)? - .iter() - .copied() - .filter(|id| pred.eval_mut(id)), - ); - - self.pq_distances(&ids, &mut f)?; - } - - Ok(()) - })(); - - std::future::ready(result) - } -} - // Scratch space for disk search operations that need allocations. // These allocations are amortized across searches using the scratch pool. struct DiskSearchScratch @@ -606,6 +562,44 @@ where fn terminate_early(&mut self) -> bool { self.io_tracker.io_count() > self.provider.search_io_limit } + + fn expand_beam( + &mut self, + ids: Itr, + _computer: &Self::QueryComputer, + mut pred: P, + mut f: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let result = (|| { + let io_limit = self.provider.search_io_limit - self.io_tracker.io_count(); + let load_ids: Box<[_]> = ids.take(io_limit).collect(); + + self.ensure_loaded(&load_ids)?; + let mut ids = Vec::new(); + for i in load_ids { + ids.clear(); + ids.extend( + self.scratch + .vertex_provider + .get_adjacency_list(&i)? 
+ .iter() + .copied() + .filter(|id| pred.eval_mut(id)), + ); + + self.pq_distances(&ids, &mut f)?; + } + + Ok(()) + })(); + + std::future::ready(result) + } } impl<'a, Data, VP> DiskAccessor<'a, Data, VP> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 09fb81053..4dd3a74a0 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -12,8 +12,8 @@ use diskann::{ graph::{ AdjacencyList, SearchOutputBuffer, glue::{ - self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + self, DefaultPostProcessor, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, + SearchExt, SearchStrategy, }, workingset, }, @@ -124,7 +124,7 @@ where /// /// The accessor reuses this allocation to amortize allocation cost over multiple bulk /// operations. - id_buffer: Vec, + id_buffer: AdjacencyList, } impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> @@ -176,6 +176,57 @@ where Ok(()) } } + + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let id_buffer = &mut self.id_buffer; + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), id_buffer)?; + + id_buffer.retain(|i| pred.eval_mut(i)); + let len = id_buffer.len(); + let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. 
+ for id in id_buffer.iter().take(lookahead) { + self.provider.base_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .base_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + // Invoke the passed closure on the full-precision vector. + // + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let v = unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }; + let distance = computer.evaluate_similarity(v); + on_neighbors(distance, *id); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> @@ -188,7 +239,7 @@ where pub fn new(provider: &'a FullPrecisionProvider) -> Self { Self { provider, - id_buffer: Vec::new(), + id_buffer: AdjacencyList::new(), } } } @@ -255,70 +306,6 @@ where } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn expand_beam( - &mut self, - ids: Itr, - computer: &Self::QueryComputer, - mut pred: P, - mut on_neighbors: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: glue::HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - let f = move || -> ANNResult<()> { - let mut neighbors = AdjacencyList::new(); - for n in ids { - self.provider - .neighbor_provider - .get_neighbors_sync(n.into_usize(), &mut neighbors)?; - - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(neighbors.iter().filter(|i| pred.eval_mut(i))); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. 
- for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - // Invoke the passed closure on the full-precision vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - let v = unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }; - let distance = computer.evaluate_similarity(v); - on_neighbors(distance, *id); - } - } - Ok(()) - }; - - std::future::ready(f()) - } -} - //-------------------// // In-mem Extensions // //-------------------// diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index f0b3f18f4..63c74bc1d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -11,8 +11,8 @@ use diskann::{ graph::{ AdjacencyList, glue::{ - self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + self, DefaultPostProcessor, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, + SearchExt, SearchStrategy, }, workingset, }, @@ -142,6 +142,40 @@ where Ok(()) } } + + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let mut neighbors = AdjacencyList::new(); + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut neighbors)?; + for i in 
neighbors.iter().filter(|i| pred.eval_mut(i)) { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = computer.evaluate_similarity(unsafe { + self.provider.aux_vectors.get_vector_sync(i.into_usize()) + }); + + on_neighbors(distance, *i); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx> @@ -211,48 +245,6 @@ where } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn expand_beam( - &mut self, - ids: Itr, - computer: &Self::QueryComputer, - mut pred: P, - mut on_neighbors: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: glue::HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - let f = move || -> ANNResult<()> { - let mut neighbors = AdjacencyList::new(); - for n in ids { - self.provider - .neighbor_provider - .get_neighbors_sync(n.into_usize(), &mut neighbors)?; - for i in neighbors.iter().filter(|i| pred.eval_mut(i)) { - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. 
- let distance = computer.evaluate_similarity(unsafe { - self.provider.aux_vectors.get_vector_sync(i.into_usize()) - }); - - on_neighbors(distance, *i); - } - } - Ok(()) - }; - - std::future::ready(f()) - } -} - //-------------------// // In-mem Extensions // //-------------------// diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index b9a878cec..5a6039a4d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -11,8 +11,8 @@ use diskann::{ graph::{ AdjacencyList, glue::{ - self, DefaultPostProcessor, ExpandBeam, FilterStartPoints, InsertStrategy, Pipeline, - PruneStrategy, SearchExt, SearchStrategy, + self, DefaultPostProcessor, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, + SearchExt, SearchStrategy, }, workingset, }, @@ -372,7 +372,7 @@ where /// The accessor for SQ. pub struct QuantAccessor<'a, const NBITS: usize, V, D, Ctx> { provider: &'a DefaultProvider, D, Ctx>, - id_buffer: Vec, + id_buffer: AdjacencyList, is_search: bool, } @@ -420,6 +420,56 @@ where Ok(()) } } + + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let id_buffer = &mut self.id_buffer; + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), id_buffer)?; + + id_buffer.retain(|i| pred.eval_mut(i)); + + let len = id_buffer.len(); + let lookahead = self.provider.aux_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. 
+ for id in id_buffer.iter().take(lookahead) { + self.provider.aux_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .aux_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; + + // Invoke the passed closure on the vector. + let distance = computer.evaluate_similarity(vector); + on_neighbors(distance, *id); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } impl<'a, const NBITS: usize, V, D, Ctx> QuantAccessor<'a, NBITS, V, D, Ctx> @@ -434,7 +484,7 @@ where ) -> Self { Self { provider, - id_buffer: Vec::with_capacity(32), + id_buffer: AdjacencyList::with_capacity(32), is_search, } } @@ -565,70 +615,6 @@ where } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, - Unsigned: Representation, - QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, -{ - fn expand_beam( - &mut self, - ids: Itr, - computer: &Self::QueryComputer, - mut pred: P, - mut on_neighbors: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: glue::HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - let f = move || -> ANNResult<()> { - let mut neighbors = AdjacencyList::new(); - for n in ids { - self.provider - .neighbor_provider - .get_neighbors_sync(n.into_usize(), &mut neighbors)?; - - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(neighbors.iter().filter(|i| pred.eval_mut(i))); - - let len = id_buffer.len(); - let lookahead = self.provider.aux_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. 
- for id in id_buffer.iter().take(lookahead) { - self.provider.aux_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .aux_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; - - // Invoke the passed closure on the vector. - let distance = computer.evaluate_similarity(vector); - on_neighbors(distance, *id); - } - } - Ok(()) - }; - - std::future::ready(f()) - } -} - impl BuildDistanceComputer for QuantAccessor<'_, NBITS, V, D, Ctx> where V: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 4dd1189de..9e6dc5dff 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -12,8 +12,8 @@ use diskann::{ graph::{ AdjacencyList, glue::{ - self, DefaultPostProcessor, ExpandBeam, FilterStartPoints, InsertStrategy, Pipeline, - PruneStrategy, SearchExt, SearchStrategy, + self, DefaultPostProcessor, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, + SearchExt, SearchStrategy, }, workingset, }, @@ -290,7 +290,7 @@ where pub struct QuantAccessor<'a, V, D, Ctx> { provider: &'a DefaultProvider, - id_buffer: Vec, + id_buffer: AdjacencyList, layout: spherical::iface::QueryLayout, is_search: bool, } @@ -308,7 +308,7 @@ where ) -> Self { Self { provider, - id_buffer: Vec::with_capacity(32), + id_buffer: AdjacencyList::with_capacity(32), layout, is_search, } @@ -360,6 +360,53 @@ where Ok(()) } } + + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: 
glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let mut id_buffer = &mut self.id_buffer; + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut id_buffer)?; + + id_buffer.retain(|i| pred.eval_mut(i)); + let len = id_buffer.len(); + let lookahead = self.provider.aux_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.aux_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .aux_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; + let distance = computer.evaluate_similarity(vector); + on_neighbors(distance, *id); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } impl Accessor for QuantAccessor<'_, V, D, Ctx> @@ -407,66 +454,6 @@ where } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn expand_beam( - &mut self, - ids: Itr, - computer: &Self::QueryComputer, - mut pred: P, - mut on_neighbors: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: glue::HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - let f = move || -> ANNResult<()> { - let mut neighbors = AdjacencyList::new(); - for n in ids { - self.provider - .neighbor_provider - .get_neighbors_sync(n.into_usize(), &mut neighbors)?; - - // Reuse the internal buffer to collect the results and give us random access - // capabilities. 
- let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(neighbors.iter().filter(|i| pred.eval_mut(i))); - - let len = id_buffer.len(); - let lookahead = self.provider.aux_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.aux_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .aux_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; - let distance = computer.evaluate_similarity(vector); - on_neighbors(distance, *id); - } - } - Ok(()) - }; - - std::future::ready(f()) - } -} - #[derive(Debug, Error)] #[error("unconstructible")] pub enum Infallible {} diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index f65ad8087..5f992e1e4 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -20,7 +20,7 @@ use diskann::{ error::StandardError, graph::{ SearchOutputBuffer, - glue::{self, ExpandBeam, SearchExt, SearchPostProcessStep, SearchStrategy}, + glue::{self, SearchExt, SearchPostProcessStep, SearchStrategy}, index::QueryLabelProvider, }, neighbor::Neighbor, @@ -117,7 +117,8 @@ where { /// An accessor that returns the ID in addition to the element yielded by the inner /// accessor. - type SearchAccessor<'a> = BetaAccessor>; + type SearchAccessor<'a> + = BetaAccessor>; /// A [`PreprocessedDistanceFunction`] that combines applies the beta filtering factor /// if the vector ID portion of `Element` satisfies the filter predicate. 
@@ -147,7 +148,7 @@ where I: VectorId, O: Send, Provider: DataProvider, - Strategy: glue::DefaultPostProcessor, + Strategy: glue::DefaultPostProcessor , { type Processor = glue::Pipeline; @@ -194,6 +195,24 @@ where f(id, computer.apply(id, distance)); }) } + + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + self.inner + .expand_beam(ids, computer.inner(), pred, move |distance, id| { + on_neighbors(computer.apply(id, distance), id) + }) + } } impl<'a, Inner> DelegateNeighbor<'a> for BetaAccessor @@ -239,29 +258,6 @@ where } } -impl ExpandBeam for BetaAccessor -where - Inner: ExpandBeam, -{ - fn expand_beam( - &mut self, - ids: Itr, - computer: &Self::QueryComputer, - pred: P, - mut on_neighbors: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: glue::HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - self.inner - .expand_beam(ids, computer.inner(), pred, move |distance, id| { - on_neighbors(computer.apply(id, distance), id) - }) - } -} - /// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. pub struct BetaComputer { inner: Inner, @@ -297,19 +293,19 @@ where } } -impl PreprocessedDistanceFunction, f32> for BetaComputer -where - I: VectorId, - Inner: PreprocessedDistanceFunction, -{ - /// Check whether the ID satisfied the predicate computed by the label provider. - /// - /// If so, multiply the distance computed by `Inner` by `beta`. 
- #[inline(always)] - fn evaluate_similarity(&self, x: Pair) -> f32 { - self.apply(x.id, self.inner.evaluate_similarity(x.element)) - } -} +// impl PreprocessedDistanceFunction, f32> for BetaComputer +// where +// I: VectorId, +// Inner: PreprocessedDistanceFunction, +// { +// /// Check whether the ID satisfied the predicate computed by the label provider. +// /// +// /// If so, multiply the distance computed by `Inner` by `beta`. +// #[inline(always)] +// fn evaluate_similarity(&self, x: Pair) -> f32 { +// self.apply(x.id, self.inner.evaluate_similarity(x.element)) +// } +// } // /////////// // // Tests // diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index e1956e7b5..6ac24113a 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -87,8 +87,7 @@ use crate::{ graph::{SearchOutputBuffer, workingset}, neighbor::Neighbor, provider::{ - Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, - DataProvider, HasId, + Accessor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, HasId, }, utils::VectorId, }; @@ -108,6 +107,18 @@ pub trait SearchExt: BuildQueryComputer { where F: FnMut(Self::Id, f32) + Send; + fn expand_beam( + &mut self, + ids: Itr, + computer: &Self::QueryComputer, + pred: P, + on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send; + /// Default is to never terminate early. fn terminate_early(&mut self) -> bool { false @@ -259,19 +270,19 @@ impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash:: /// ## Error Handling /// /// Transient errors yielded by `distances_unordered` are acknowledged and not escalated. 
-pub trait ExpandBeam: BuildQueryComputer + AsNeighbor { - fn expand_beam( - &mut self, - ids: Itr, - computer: &Self::QueryComputer, - pred: P, - on_neighbors: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send; -} +// pub trait ExpandBeam: BuildQueryComputer { +// fn expand_beam( +// &mut self, +// ids: Itr, +// computer: &Self::QueryComputer, +// pred: P, +// on_neighbors: F, +// ) -> impl std::future::Future> + Send +// where +// Itr: Iterator + Send, +// P: HybridPredicate + Send + Sync, +// F: FnMut(f32, Self::Id) + Send; +// } /// A search strategy for query objects of type `T`. /// @@ -285,12 +296,7 @@ where /// /// We could grab this type from the `SearchAccessor` associated type, but it's /// useful enough that we move it up here. - type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as Accessor>::ElementRef<'b>, - f32, - > + Send - + Sync - + 'static; + type QueryComputer: Send + Sync + 'static + for<'a> DoesNotCaptureSelf>; /// An error that can occur when getting a search_accessor. type SearchAccessorError: StandardError; @@ -298,8 +304,7 @@ where /// The concrete type of the accessor that is used to access `Self` during the greedy /// graph search. The query will be provided to the accessor exactly once during search /// to construct the query computer. - type SearchAccessor<'a>: ExpandBeam - + SearchExt; + type SearchAccessor<'a>: SearchExt; /// Construct and return the search accessor. fn search_accessor<'a>( @@ -309,6 +314,9 @@ where ) -> Result, Self::SearchAccessorError>; } +pub trait DoesNotCaptureSelf {} +impl DoesNotCaptureSelf for U {} + /// Opt-in trait for strategies that have a default post-processor. /// /// Strategies implementing this trait can be used with index-level search APIs such as @@ -773,8 +781,7 @@ where /// of associated types. /// /// Lifting the accessor all the way to the trait level makes the caching provider possible. 
- type DeleteSearchAccessor<'a>: ExpandBeam, Id = Provider::InternalId> - + SearchExt>; + type DeleteSearchAccessor<'a>: SearchExt, Id = Provider::InternalId>; /// The processor used during the delete-search phase. type SearchPostProcessor: for<'a> SearchPostProcess, Self::DeleteElement<'a>> diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 63c9933a0..5b47f5151 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -25,8 +25,8 @@ use tokio::task::JoinSet; use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, Search, glue::{ - self, Batch, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + self, Batch, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, PruneStrategy, + SearchExt, SearchPostProcess, SearchStrategy, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ @@ -1318,8 +1318,13 @@ where .collect(); // Collect IDs whose adjacency lists need to be updated. + let prune_strategy = strategy.prune_strategy(); + let mut prune_accessor = prune_strategy + .prune_accessor(self.provider(), context) + .into_ann_result()?; + let ids_to_modify = self - .return_refs_to_deleted_vertex(&mut search_accessor, id, &undeleted_ids) + .return_refs_to_deleted_vertex(&mut prune_accessor, id, &undeleted_ids) .await?; undeleted_ids.truncate(k_value); @@ -1698,9 +1703,9 @@ where .await .escalate("`inplace_delete` requires a successful delete")?; - let search_strategy = strategy.search_strategy(); - let accessor = &mut search_strategy - .search_accessor(&self.data_provider, context) + let prune_strategy = strategy.prune_strategy(); + let mut accessor = prune_strategy + .prune_accessor(&self.data_provider, context) .into_ann_result()?; let InplaceDeleteWorkList { @@ -1715,21 +1720,16 @@ where .await? 
} InplaceDeleteMethod::TwoHopAndOneHop => { - self.get_candidates_using_twohop_and_onehop(context, accessor, vector_id) + self.get_candidates_using_twohop_and_onehop(context, &mut accessor, vector_id) .await? } InplaceDeleteMethod::OneHop => { - self.get_candidates_using_onehop(context, accessor, vector_id) + self.get_candidates_using_onehop(context, &mut accessor, vector_id) .await? } }; - let prune_strategy = strategy.prune_strategy(); let mut working_set = prune_strategy.create_working_set(self.max_occlusion_size()); - let mut accessor = prune_strategy - .prune_accessor(&self.data_provider, context) - .into_ann_result()?; - let mut edges_to_add = HashMap::>::new(); // fetch the filtered adjacency list of `p`. @@ -2020,7 +2020,7 @@ where search_record: &mut SR, ) -> impl SendFuture> where - A: ExpandBeam + SearchExt, + A: SearchExt, SR: SearchRecord + ?Sized, Q: NeighborQueue, { diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 04a0fe4ed..6baa41f7d 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -14,8 +14,8 @@ use crate::{ error::IntoANNResult, graph::{ glue::{ - self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, - SearchPostProcess, SearchStrategy, + self, HybridPredicate, Predicate, PredicateMut, SearchExt, SearchPostProcess, + SearchStrategy, }, index::{ DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats, @@ -180,7 +180,7 @@ pub(crate) async fn multihop_search_internal( ) -> ANNResult where I: VectorId, - A: ExpandBeam + SearchExt, + A: SearchExt, SR: SearchRecord + ?Sized, { let beam_width = search_params.beam_width().get(); diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 400fd3f0c..8de7aeafa 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -13,7 +13,7 @@ use crate::{ ANNError, 
ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{self, ExpandBeam, SearchExt, SearchStrategy}, + glue::{self, SearchExt, SearchStrategy}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::{self, SearchOutputBuffer}, @@ -331,7 +331,7 @@ pub(crate) async fn range_search_internal( ) -> ANNResult where I: crate::utils::VectorId, - A: ExpandBeam + SearchExt, + A: SearchExt, { let beam_width = search_params.beam_width().unwrap_or(1); diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 50badec86..1988ed58f 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1121,9 +1121,7 @@ impl glue::SearchExt<&[f32]> for Accessor<'_> { Ok(()) } } -} -impl glue::ExpandBeam<&[f32]> for Accessor<'_> { fn expand_beam( &mut self, ids: Itr, @@ -1151,6 +1149,34 @@ impl glue::ExpandBeam<&[f32]> for Accessor<'_> { } } +// impl glue::ExpandBeam<&[f32]> for Accessor<'_> { +// fn expand_beam( +// &mut self, +// ids: Itr, +// computer: &Self::QueryComputer, +// mut pred: P, +// mut on_neighbors: F, +// ) -> impl std::future::Future> + Send +// where +// Itr: Iterator + Send, +// P: glue::HybridPredicate + Send + Sync, +// F: FnMut(f32, Self::Id) + Send, +// { +// async move { +// let mut neighbors = AdjacencyList::new(); +// for id in ids { +// self.provider.get_neighbors(id, &mut neighbors)?; +// for &n in neighbors.iter().filter(|i| pred.eval_mut(i)) { +// if let Some(buf) = self.get(n).allow_transient("transient failures allowed")? 
{ +// on_neighbors(computer.evaluate_similarity(buf), n) +// } +// } +// } +// Ok(()) +// } +// } +// } + type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; type View<'a> = workingset::map::View<'a, u32, Box<[f32]>, workingset::map::Ref<[f32]>>; diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index f6f5dd2bb..f92a5908d 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -479,7 +479,7 @@ pub trait BuildQueryComputer: Accessor { /// /// Generally, neighbor access and data access are logically decoupled, being served from /// different stores. However, there are situations where data and neighbors are -/// interleaved in the underlying storage medium. +/// interlfor<'a> eaved in the underlying storage medium. /// /// As such, [`Accessors`] used in congunction with graph operations need to additionally /// provide an implementation of this trait. From 8aedf0c51d3f417c774b0ae397eec05aa7508461 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 14 May 2026 15:50:13 -0700 Subject: [PATCH 04/17] More simplification. 
--- .../src/search/provider/disk_provider.rs | 75 +--------- .../provider/async_/inmem/full_precision.rs | 8 +- .../graph/provider/async_/inmem/product.rs | 10 +- .../graph/provider/async_/inmem/scalar.rs | 84 +---------- .../graph/provider/async_/inmem/spherical.rs | 7 +- .../graph/provider/async_/postprocess.rs | 4 +- .../model/graph/provider/layers/betafilter.rs | 138 +++++------------- .../src/storage/index_storage.rs | 5 +- diskann/src/graph/glue.rs | 29 ++-- diskann/src/graph/index.rs | 11 +- diskann/src/graph/test/provider.rs | 30 +--- diskann/src/graph/workingset/mod.rs | 7 +- diskann/src/provider.rs | 59 ++------ 13 files changed, 96 insertions(+), 371 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index ee8236a99..c9d27ee3f 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -5,7 +5,6 @@ use std::{ collections::HashMap, - future::Future, num::NonZeroUsize, ops::Range, sync::{ @@ -24,13 +23,10 @@ use diskann::{ self, DefaultPostProcessor, IdIterator, SearchExt, SearchPostProcess, SearchStrategy, }, search::Knn, - search_output_buffer, AdjacencyList, DiskANNIndex, + search_output_buffer, DiskANNIndex, }, neighbor::Neighbor, - provider::{ - Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId, - NeighborAccessor, NoopGuard, - }, + provider::{BuildQueryComputer, DataProvider, DefaultContext, HasId, NoopGuard}, utils::{IntoUsize, VectorRepr}, ANNError, ANNResult, }; @@ -43,7 +39,6 @@ use diskann_utils::object_pool::{ObjectPool, PoolOption, TryAsPooled}; use crate::search::pq::{quantizer_preprocess, PQData, PQScratch}; use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction}; -use futures_util::future; use tokio::runtime::Runtime; use tracing::debug; @@ -347,8 +342,7 @@ where ProviderFactory: VertexProviderFactory, { type QueryComputer = 
DiskQueryComputer; - type SearchAccessor<'a> - = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>; + type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>; type SearchAccessorError = ANNError; fn search_accessor<'a>( @@ -688,14 +682,6 @@ where type Id = u32; } -impl Accessor for DiskAccessor<'_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - type ElementRef<'a> = &'a [u8]; -} - impl IdIterator> for DiskAccessor<'_, Data, VP> where Data: GraphDataType, @@ -706,61 +692,6 @@ where } } -impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - AsNeighborAccessor(self) - } -} - -/// A light-weight wrapper around `&mut DiskAccessor` used to tailor the semantics of -/// [`NeighborAccessor`]. -/// -/// This implementation ensures that the vector data for adjacency lists is also retrieved -/// and cached to enhance reranking. -pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>) -where - Data: GraphDataType, - VP: VertexProvider; - -impl HasId for AsNeighborAccessor<'_, '_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - type Id = u32; -} - -impl NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - fn get_neighbors( - self, - id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - if self.0.io_tracker.io_count() > self.0.provider.search_io_limit { - return future::ok(self); // Returning empty results in `neighbors` out param if IO limit is reached. 
- } - - if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) { - return future::err(e); - } - let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) { - Ok(list) => list, - Err(e) => return future::err(e), - }; - neighbors.overwrite_trusted(list); - future::ok(self) - } -} - /// [`DiskIndexSearcher`] is a helper class to make it easy to construct index /// and do repeated search operations. It is a wrapper around the index. /// This is useful for drivers such as search_disk_index.exe in tools. diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 4dd3a74a0..be9869a72 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -19,8 +19,8 @@ use diskann::{ }, neighbor::Neighbor, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, + BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -258,7 +258,7 @@ where } } -impl Accessor for FullAccessor<'_, T, Q, D, Ctx> +impl HasElementRef for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, @@ -343,7 +343,7 @@ pub struct Rerank; impl<'a, A, T> glue::SearchPostProcess for Rerank where T: VectorRepr, - A: BuildQueryComputer<&'a [T], Id = u32> + GetFullPrecision + AsDeletionCheck, + A: BuildQueryComputer<&'a [T]> + HasId + GetFullPrecision + AsDeletionCheck, { type Error = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 63c74bc1d..1c52e883d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ 
b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -17,8 +17,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, + HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -201,13 +201,12 @@ where } } -impl Accessor for QuantAccessor<'_, V, D, Ctx> +impl HasElementRef for QuantAccessor<'_, V, D, Ctx> where V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { - /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = &'a [u8]; } @@ -319,13 +318,12 @@ where } } -impl Accessor for HybridAccessor<'_, T, D, Ctx> +impl HasElementRef for HybridAccessor<'_, T, D, Ctx> where T: VectorRepr, D: AsyncFriendly, Ctx: ExecutionContext, { - /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 5a6039a4d..7e7271c92 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -17,8 +17,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, + HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -490,92 +490,14 @@ where } } -impl Accessor for QuantAccessor<'_, NBITS, V, D, Ctx> +impl HasElementRef for QuantAccessor<'_, NBITS, V, D, Ctx> where V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, Unsigned: Representation, { - // /// This accessor returns raw slices. There *is* a chance of racing when the fast - // /// providers are used. We just have to live with it. 
- // type Element<'a> - // = CVRef<'a, NBITS> - // where - // Self: 'a; - - /// `ElementRef` has an arbitrarily short lifetime. type ElementRef<'a> = CVRef<'a, NBITS>; - - // /// Choose to panic on an out-of-bounds access rather than propagate an error. - // type GetError = ANNError; - - // /// Return the quantized vector stored at index `i`. - // /// - // /// This function always completes synchronously. - // fn get_element( - // &mut self, - // id: Self::Id, - // ) -> impl Future, Self::GetError>> + Send { - // // SAFETY: We've decided to live with UB that can result from potentially mixing - // // unsynchronized reads and writes on the underlying memory. - // std::future::ready( - // match self.provider.aux_vectors.get_vector(id.into_usize()) { - // Ok(v) => Ok(v), - // Err(err) => Err(err.into()), - // }, - // ) - // } - - // /// Perform a bulk operation. - // /// - // /// This implementation uses prefetching. - // fn on_elements_unordered( - // &mut self, - // itr: Itr, - // mut f: F, - // ) -> impl Future> + Send - // where - // Self: Sync, - // Itr: Iterator + Send, - // F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - // { - // // Reuse the internal buffer to collect the results and give us random access - // // capabilities. - // let id_buffer = &mut self.id_buffer; - // id_buffer.clear(); - // id_buffer.extend(itr); - - // let len = id_buffer.len(); - // let lookahead = self.provider.aux_vectors.prefetch_lookahead(); - - // // Prefetch the first few vectors. - // for id in id_buffer.iter().take(lookahead) { - // self.provider.aux_vectors.prefetch_hint(id.into_usize()); - // } - - // for (i, id) in id_buffer.iter().enumerate() { - // // Prefetch `lookahead` iterations ahead as long as it is safe. 
- // if lookahead > 0 && i + lookahead < len { - // self.provider - // .aux_vectors - // .prefetch_hint(id_buffer[i + lookahead].into_usize()); - // } - - // let vector = match self.provider.aux_vectors.get_vector(id.into_usize()) { - // Ok(v) => v, - // Err(e) => return std::future::ready(Err(e.into())), - // }; - - // // Invoke the passed closure on the vector. - // // - // // SAFETY: We're accepting the consequences of potential unsynchronized, - // // concurrent mutation. - // f(vector, *id) - // } - - // std::future::ready(Ok(())) - // } } impl<'a, const NBITS: usize, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, NBITS, V, D, Ctx> diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 9e6dc5dff..23db9d296 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -18,8 +18,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, + HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -409,13 +409,12 @@ where } } -impl Accessor for QuantAccessor<'_, V, D, Ctx> +impl HasElementRef for QuantAccessor<'_, V, D, Ctx> where V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { - /// `ElementRef` has an arbitrarily short lifetime. 
type ElementRef<'a> = spherical::iface::Opaque<'a>; } diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index 3e1849bb0..dbbf08fa4 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -8,7 +8,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, - provider::BuildQueryComputer, + provider::{BuildQueryComputer, HasId}, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to @@ -39,7 +39,7 @@ pub struct RemoveDeletedIdsAndCopy; impl glue::SearchPostProcess for RemoveDeletedIdsAndCopy where - A: BuildQueryComputer + AsDeletionCheck, + A: BuildQueryComputer + HasId + AsDeletionCheck, { type Error = std::convert::Infallible; diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 5f992e1e4..2ab5dad1f 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -24,10 +24,9 @@ use diskann::{ index::QueryLabelProvider, }, neighbor::Neighbor, - provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + provider::{BuildQueryComputer, DataProvider, HasId}, utils::VectorId, }; -use diskann_vector::PreprocessedDistanceFunction; /// A [`SearchStrategy`] type that composes the inner distance computer with beta filtering. /// @@ -68,7 +67,7 @@ pub struct Unwrap; /// Delegate post-processing to the inner strategy's post-processing routine. 
impl SearchPostProcessStep, T, O> for Unwrap where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { type Error = NextError @@ -82,7 +81,7 @@ where next: &Next, accessor: &mut BetaAccessor, query: T, - computer: &BetaComputer, + computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send @@ -91,13 +90,7 @@ where B: SearchOutputBuffer + Send + ?Sized, Next: glue::SearchPostProcess, { - next.post_process( - &mut accessor.inner, - query, - computer.inner(), - candidates, - output, - ) + next.post_process(&mut accessor.inner, query, computer, candidates, output) } } @@ -117,12 +110,11 @@ where { /// An accessor that returns the ID in addition to the element yielded by the inner /// accessor. - type SearchAccessor<'a> - = BetaAccessor>; + type SearchAccessor<'a> = BetaAccessor>; /// A [`PreprocessedDistanceFunction`] that combines applies the beta filtering factor /// if the vector ID portion of `Element` satisfies the filter predicate. - type QueryComputer = BetaComputer; + type QueryComputer = Strategy::QueryComputer; type SearchAccessorError = Strategy::SearchAccessorError; @@ -134,8 +126,10 @@ where ) -> Result, Self::SearchAccessorError> { Ok(BetaAccessor { inner: self.strategy.search_accessor(provider, context)?, - labels: self.labels.clone(), - beta: self.beta, + filter: Filter { + labels: self.labels.clone(), + beta: self.beta, + }, }) } } @@ -148,7 +142,7 @@ where I: VectorId, O: Send, Provider: DataProvider, - Strategy: glue::DefaultPostProcessor , + Strategy: glue::DefaultPostProcessor, { type Processor = glue::Pipeline; @@ -163,15 +157,25 @@ where Inner: HasId, { inner: Inner, - labels: Arc>, + filter: Filter, +} + +struct Filter { + labels: Arc>, beta: f32, } -/// The `Element` and `ElementRef` types used by the [`BetaAccessor`]. 
-#[derive(Debug, Clone, PartialEq)] -pub struct Pair { - id: I, - element: E, +impl Filter +where + I: VectorId, +{ + fn apply(&self, id: I, distance: f32) -> f32 { + if self.labels.is_match(id) { + distance * self.beta + } else { + distance + } + } } impl SearchExt for BetaAccessor @@ -190,9 +194,10 @@ where where F: FnMut(Self::Id, f32) + Send, { + let filter = &self.filter; self.inner - .start_point_distances(computer.inner(), move |id, distance| { - f(id, computer.apply(id, distance)); + .start_point_distances(computer, move |id, distance| { + f(id, filter.apply(id, distance)); }) } @@ -208,23 +213,14 @@ where P: glue::HybridPredicate + Send + Sync, F: FnMut(f32, Self::Id) + Send, { + let filter = &self.filter; self.inner - .expand_beam(ids, computer.inner(), pred, move |distance, id| { - on_neighbors(computer.apply(id, distance), id) + .expand_beam(ids, computer, pred, move |distance, id| { + on_neighbors(filter.apply(id, distance), id) }) } } -impl<'a, Inner> DelegateNeighbor<'a> for BetaAccessor -where - Inner: DelegateNeighbor<'a>, -{ - type Delegate = Inner::Delegate; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.inner.delegate_neighbor() - } -} - impl HasId for BetaAccessor where Inner: HasId, @@ -232,19 +228,14 @@ where type Id = Inner::Id; } -impl Accessor for BetaAccessor -where - Inner: Accessor, -{ - type ElementRef<'a> = Pair>; -} - impl BuildQueryComputer for BetaAccessor where - Inner: BuildQueryComputer, + Inner: BuildQueryComputer + HasId, { - /// Use a [`BetaComputer`] to apply filtering. - type QueryComputer = BetaComputer; + /// Use the inner `QueryComputer`. Application of the beta filter happens in the + /// closure for [`SearchExt::expand_beam`]. + type QueryComputer = Inner::QueryComputer; + /// Use the same error as `Inner`. 
type QueryComputerError = Inner::QueryComputerError; @@ -252,61 +243,10 @@ where &self, from: T, ) -> Result { - self.inner - .build_query_computer(from) - .map(|computer| BetaComputer::new(computer, self.labels.clone(), self.beta)) + self.inner.build_query_computer(from) } } -/// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. -pub struct BetaComputer { - inner: Inner, - labels: Arc>, - beta: f32, -} - -impl BetaComputer -where - I: VectorId, -{ - /// Construct a new `BetaComputer` around `Inner`. - pub fn new(inner: Inner, labels: Arc>, beta: f32) -> Self { - Self { - inner, - labels, - beta, - } - } - - /// Return a reference to the inner computer. - pub fn inner(&self) -> &Inner { - &self.inner - } - - /// Apply the beta-filtering heuristic. - pub fn apply(&self, id: I, distance: f32) -> f32 { - if self.labels.is_match(id) { - distance * self.beta - } else { - distance - } - } -} - -// impl PreprocessedDistanceFunction, f32> for BetaComputer -// where -// I: VectorId, -// Inner: PreprocessedDistanceFunction, -// { -// /// Check whether the ID satisfied the predicate computed by the label provider. -// /// -// /// If so, multiply the distance computed by `Inner` by `beta`. 
-// #[inline(always)] -// fn evaluate_similarity(&self, x: Pair) -> f32 { -// self.apply(x.id, self.inner.evaluate_similarity(x.element)) -// } -// } - // /////////// // // Tests // // /////////// diff --git a/diskann-providers/src/storage/index_storage.rs b/diskann-providers/src/storage/index_storage.rs index 2e76beff3..8e076ea3a 100644 --- a/diskann-providers/src/storage/index_storage.rs +++ b/diskann-providers/src/storage/index_storage.rs @@ -219,10 +219,10 @@ mod tests { use crate::storage::VirtualStorageProvider; use diskann::{ graph::{AdjacencyList, config, glue::InsertStrategy}, - provider::{Accessor, SetElement}, + provider::SetElement, utils::{IntoUsize, ONE}, }; - use diskann_utils::{Reborrow, test_data_root, views::MatrixView}; + use diskann_utils::{test_data_root, views::MatrixView}; use diskann_vector::distance::Metric; use super::*; @@ -231,7 +231,6 @@ mod tests { model::graph::provider::async_::{ SimpleNeighborProviderAsync, common::{FullPrecision, NoDeletes, NoStore, TableBasedDeletes}, - inmem::{self}, }, utils::create_rnd_from_seed_in_tests, }; diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 6ac24113a..11b2ed755 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -79,7 +79,7 @@ use std::{future::Future, sync::Arc}; use diskann_utils::Reborrow; -use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction}; +use diskann_vector::DistanceFunction; use crate::{ ANNError, ANNResult, @@ -87,14 +87,15 @@ use crate::{ graph::{SearchOutputBuffer, workingset}, neighbor::Neighbor, provider::{ - Accessor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, HasId, + AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, HasElementRef, + HasId, }, utils::VectorId, }; /// A trait to override search constraints such as early termination based on constraints /// by implementer. 
-pub trait SearchExt: BuildQueryComputer { +pub trait SearchExt: BuildQueryComputer + HasId + Send + Sync { /// Return a `Vec` containing the starting points. fn starting_points(&self) -> impl std::future::Future>> + Send; @@ -381,7 +382,7 @@ macro_rules! default_post_processor { /// directly into the output buffer. pub trait SearchPostProcess::Id> where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { type Error: StandardError; @@ -407,7 +408,7 @@ pub struct CopyIds; impl SearchPostProcess for CopyIds where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { type Error = std::convert::Infallible; fn post_process( @@ -432,7 +433,7 @@ where /// using a [`Pipeline`]. pub trait SearchPostProcessStep::Id> where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { /// A potentially modified version of the error yielded by the next state in the /// processing pipeline. @@ -441,7 +442,7 @@ where NextError: StandardError; /// The accessor that will be passed to the next processing stage. - type NextAccessor: BuildQueryComputer; + type NextAccessor: BuildQueryComputer + HasId; /// Perform any modification the `input`, `output`, `accessor`, or `computer` objects /// and invoke the [`SearchPostProcess`] routine `next` on stage. @@ -535,7 +536,7 @@ impl Pipeline { impl SearchPostProcess for Pipeline where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, Head: SearchPostProcessStep, Tail: SearchPostProcess + Sync, { @@ -609,9 +610,8 @@ where /// We could grab this type from the `PruneAccessor` associated type, but it's /// useful enough that we move it up here. 
type DistanceComputer<'computer>: for<'a, 'b, 'c, 'd> DistanceFunction< - as Accessor>::ElementRef<'b>, - as Accessor>::ElementRef<'d>, - f32, + as HasElementRef>::ElementRef<'b>, + as HasElementRef>::ElementRef<'d>, > + Send + Sync; @@ -623,10 +623,11 @@ where /// /// Implementations are encouraged to have [`Accessor::get_element`] return the /// highest-precision applicable value for a given element type. - type PruneAccessor<'a>: Accessor - + BuildDistanceComputer> + type PruneAccessor<'a>: BuildDistanceComputer> + AsNeighborMut - + workingset::Fill; + + workingset::Fill + + Send + + Sync; /// An error that can occur when getting the prune accessor. type PruneAccessorError: StandardError; diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 5b47f5151..5a51aa184 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -43,9 +43,8 @@ use crate::{ internal, neighbor::{self, Neighbor, NeighborPriorityQueue, NeighborQueue}, provider::{ - Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, - DataProvider, Delete, ElementStatus, ExecutionContext, Guard, NeighborAccessor, - NeighborAccessorMut, SetElement, + AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, Delete, + ElementStatus, ExecutionContext, Guard, NeighborAccessor, NeighborAccessorMut, SetElement, }, tracked_debug, tracked_error, tracked_trace, utils::{ @@ -2663,7 +2662,7 @@ where options: prune::Options, ) -> impl SendFuture> where - A: Accessor + BuildDistanceComputer + Fill, + A: BuildDistanceComputer + Fill + Send, Set: Send + Sync, { async move { @@ -2721,7 +2720,7 @@ where options: prune::Options, ) -> impl SendFuture>> where - A: Accessor + BuildDistanceComputer + Fill, + A: BuildDistanceComputer + Fill + Send, Set: Send + Sync, { async move { @@ -2812,7 +2811,7 @@ where options: prune::Options, ) -> impl SendFuture>> where - A: Accessor + BuildDistanceComputer + Fill, + A: BuildDistanceComputer + Fill 
+ Send, Set: Send + Sync, Itr: ExactSizeIterator + Clone + Send + Sync, { diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 1988ed58f..b107f1a83 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1057,7 +1057,7 @@ impl provider::HasId for Accessor<'_> { type Id = u32; } -impl provider::Accessor for Accessor<'_> { +impl provider::HasElementRef for Accessor<'_> { type ElementRef<'a> = &'a [f32]; } @@ -1149,34 +1149,6 @@ impl glue::SearchExt<&[f32]> for Accessor<'_> { } } -// impl glue::ExpandBeam<&[f32]> for Accessor<'_> { -// fn expand_beam( -// &mut self, -// ids: Itr, -// computer: &Self::QueryComputer, -// mut pred: P, -// mut on_neighbors: F, -// ) -> impl std::future::Future> + Send -// where -// Itr: Iterator + Send, -// P: glue::HybridPredicate + Send + Sync, -// F: FnMut(f32, Self::Id) + Send, -// { -// async move { -// let mut neighbors = AdjacencyList::new(); -// for id in ids { -// self.provider.get_neighbors(id, &mut neighbors)?; -// for &n in neighbors.iter().filter(|i| pred.eval_mut(i)) { -// if let Some(buf) = self.get(n).allow_transient("transient failures allowed")? { -// on_neighbors(computer.evaluate_similarity(buf), n) -// } -// } -// } -// Ok(()) -// } -// } -// } - type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; type View<'a> = workingset::map::View<'a, u32, Box<[f32]>, workingset::map::Ref<[f32]>>; diff --git a/diskann/src/graph/workingset/mod.rs b/diskann/src/graph/workingset/mod.rs index d879c5d5a..d0a31cc37 100644 --- a/diskann/src/graph/workingset/mod.rs +++ b/diskann/src/graph/workingset/mod.rs @@ -67,7 +67,10 @@ use diskann_utils::{Reborrow, future::SendFuture}; -use crate::{ANNError, provider::Accessor}; +use crate::{ + ANNError, + provider::{HasElementRef, HasId}, +}; ///////////// // Exports // @@ -96,7 +99,7 @@ pub use map::Map; /// directly accessible by the `WorkingSet`/[`View`] types. 
/// /// See Also: [`View`], [`AsWorkingSet`], [`Map`]. -pub trait Fill: Accessor { +pub trait Fill: HasId + HasElementRef { /// Any critical error that occurs during [`fill`](Self::fill). /// /// Implementations of `fill` are expected to swallow any non-critical errors. diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index f92a5908d..51497b676 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -305,6 +305,14 @@ where type Id = T::Id; } +/////////////////// +// HasElementRef // +/////////////////// + +pub trait HasElementRef { + type ElementRef<'a>; +} + //////////////// // SetElement // //////////////// @@ -380,55 +388,8 @@ where } } -////////////// -// Accessor // -////////////// - -/// A lens through which [`DataProvider`]s contextually viewed. -/// -/// Accessors are **not** required to be `'static` and almost always contain a scoped -/// reference to their parent provider. -/// -/// # Element Relationship -/// -/// Accessors are expected to define two associated element types: -/// -/// * `Element<'_>`: The type returned by `get_element`. This is scoped to the borrow -/// of the accessor at the `get_element` call site. As a consequence, there may only -/// be one such `Element` active at a time. -/// -/// * `ElementRef<'_>`: A generalized borrowed form of `Element` obtainable via -/// `Reborrow`. This is the type on which distance computations are defined and is the -/// element type provided to the `on_element_unordered` bulk operation. -/// -/// The below diagram summarizes the relationship. -/// -/// ```text -/// Element<'_> ------ Reborrow ----> ElementRef<'_> -/// ~~~~ ~~~~ -/// ^ ^ -/// | | -/// Lifetime tied Arbitrarily short -/// to the Accessor lifetime decoupled -/// from the Accessor -/// ``` -/// -/// ## Technical Details -/// -/// The need for `ElementRef` arises to allow HRTB bounds to distance computers without -/// inducing a `'static` bound on `Self`. 
In traits like [`BuildQueryComputer`], attempting -/// to use `Element` directly will result in such a requirement on the implementing Accessor. -pub trait Accessor: HasId + Send + Sync { - /// A generalized reference type used for distance computations. - /// - /// Note that the lifetime of `ElementRef` is unconstrained and thus using it in a - /// [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) will not induce a `'static` - /// requirement on `Self`. - type ElementRef<'a>; -} - /// A specialized [`Accessor`] that provides random-access distance computations. -pub trait BuildDistanceComputer: Accessor { +pub trait BuildDistanceComputer: HasElementRef { /// The error type (if any) associated with distance computer construction. /// /// Implementations are encouraged to make distance computer construction infallible. @@ -453,7 +414,7 @@ pub trait BuildDistanceComputer: Accessor { /// /// Query computers are allowed to preprocess the query to enable more efficient distance /// computations. -pub trait BuildQueryComputer: Accessor { +pub trait BuildQueryComputer { /// The error type (if any) associated with distance computer construction. type QueryComputerError: std::error::Error + Into + Send + Sync + 'static; From e2d495a45685fcc429d276ef5dd57d53cee56090 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 14 May 2026 16:46:27 -0700 Subject: [PATCH 05/17] Tidy up betafilter. 
--- .../model/graph/provider/layers/betafilter.rs | 392 +++++--------- diskann/src/graph/glue.rs | 510 +++++++----------- diskann/src/graph/test/provider.rs | 5 + 3 files changed, 333 insertions(+), 574 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 2ab5dad1f..7c4cb10ec 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -247,270 +247,128 @@ where } } -// /////////// -// // Tests // -// /////////// -// -// #[cfg(test)] -// mod tests { -// use diskann::{ -// ANNError, ANNResult, always_escalate, -// graph::AdjacencyList, -// graph::glue::CopyIds, -// provider::{DefaultContext, NeighborAccessor, NoopGuard}, -// }; -// use futures_util::future; -// use thiserror::Error; -// -// use super::*; -// -// /// A very simple data provider. -// struct SimpleProvider; -// impl DataProvider for SimpleProvider { -// type Context = DefaultContext; -// type InternalId = u32; -// type ExternalId = u64; -// type Guard = NoopGuard; -// -// type Error = ANNError; -// -// fn to_internal_id(&self, _context: &DefaultContext, gid: &u64) -> ANNResult { -// Ok((*gid).try_into()?) -// } -// -// fn to_external_id(&self, _context: &DefaultContext, id: u32) -> ANNResult { -// Ok(id.into()) -// } -// } -// -// /// An `Accessor` that doubles its input ID as its output element. -// /// -// /// This also tracks the number of calls made to `get_element` and -// /// `on_elements_unordered` to ensure that `BetaFilter` correctly forwards these methods. 
-// #[derive(Debug, Default, Clone, Copy)] -// struct Doubler { -// get_element: usize, -// on_elements_unordered: usize, -// } -// -// impl SearchExt for Doubler { -// async fn starting_points(&self) -> ANNResult> { -// Ok(vec![0]) -// } -// } -// -// impl Doubler { -// fn reset(&mut self) { -// *self = Self::default(); -// } -// } -// -// /// A simple error type to test error forwarding. -// #[derive(Debug, Error)] -// #[error("the value {0} is not allowed")] -// pub struct NotAllowed(u32); -// -// impl From for ANNError { -// #[inline(never)] -// fn from(value: NotAllowed) -> Self { -// ANNError::log_async_error(value) -// } -// } -// -// impl HasId for Doubler { -// type Id = u32; -// } -// -// impl NeighborAccessor for Doubler { -// fn get_neighbors( -// self, -// _id: Self::Id, -// neighbors: &mut AdjacencyList, -// ) -> impl Future> + Send { -// neighbors.clear(); -// future::ok(self) -// } -// } -// -// always_escalate!(NotAllowed); -// -// impl Accessor for Doubler { -// // type Element<'a> -// // = u64 -// // where -// // Self: 'a; -// type ElementRef<'a> = u64; -// -// // type GetError = NotAllowed; -// -// // fn get_element( -// // &mut self, -// // id: u32, -// // ) -> impl std::future::Future, Self::GetError>> + Send -// // { -// // self.get_element += 1; -// // let is_err = (100..200).contains(&id); -// -// // async move { -// // if is_err { -// // Err(NotAllowed(id)) -// // } else { -// // let id: u64 = id.into(); -// // Ok(2 * id) -// // } -// // } -// // } -// -// async fn on_elements_unordered( -// &mut self, -// itr: Itr, -// mut f: F, -// ) -> Result<(), Self::GetError> -// where -// Self: Sync, -// Itr: Iterator + Send, -// F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), -// { -// self.on_elements_unordered += 1; -// for i in itr { -// f(self.get_element(i).await?, i); -// } -// Ok(()) -// } -// } -// -// struct AddingComputer(u64); -// impl PreprocessedDistanceFunction for AddingComputer { -// fn evaluate_similarity(&self, x: u64) -> 
f32 { -// (self.0 + x) as f32 -// } -// } -// -// impl BuildQueryComputer for Doubler { -// type QueryComputer = AddingComputer; -// type QueryComputerError = ANNError; -// -// fn build_query_computer( -// &self, -// from: u64, -// ) -> Result { -// Ok(AddingComputer(from)) -// } -// } -// -// impl ExpandBeam for Doubler {} -// -// #[derive(Debug)] -// struct SimpleStrategy; -// -// impl SearchStrategy for SimpleStrategy { -// type SearchAccessor<'a> = Doubler; -// type QueryComputer = AddingComputer; -// type SearchAccessorError = ANNError; -// -// fn search_accessor<'a>( -// &'a self, -// _provider: &'a SimpleProvider, -// _context: &'a DefaultContext, -// ) -> Result, Self::SearchAccessorError> { -// Ok(Doubler::default()) -// } -// } -// -// impl glue::DefaultPostProcessor for SimpleStrategy { -// diskann::default_post_processor!(CopyIds); -// } -// -// /// A simple `QueryLabelProvider` that matches multiples of 3. -// #[derive(Debug)] -// struct ThreeFilter; -// -// impl QueryLabelProvider for ThreeFilter { -// fn is_match(&self, id: u32) -> bool { -// id.is_multiple_of(3) -// } -// } -// -// #[tokio::test] -// async fn test_beta_filter() { -// let provider = SimpleProvider; -// let context = &DefaultContext; -// let beta: f32 = 0.25; -// -// let strategy = BetaFilter::new(SimpleStrategy, Arc::new(ThreeFilter), beta); -// -// let mut accessor: BetaAccessor<_> = strategy.search_accessor(&provider, context).unwrap(); -// assert_eq!(accessor.inner.get_element, 0); -// assert_eq!(accessor.inner.on_elements_unordered, 0); -// -// // Test non-erroring path. -// let v = accessor.get_element(1).await.unwrap(); -// assert_eq!(v, Pair::new(1, 2)); -// -// let v = accessor.get_element(2).await.unwrap(); -// assert_eq!(v, Pair::new(2, 4)); -// -// // Test erroring path. 
-// assert!(accessor.get_element(100).await.is_err()); -// assert!(accessor.get_element(101).await.is_err()); -// -// assert_eq!(accessor.inner.get_element, 4); -// assert_eq!(accessor.inner.on_elements_unordered, 0); -// accessor.inner.reset(); -// -// // On elements unordered. -// { -// let mut v = Vec::new(); -// accessor -// .on_elements_unordered([1, 2, 3, 4, 5].into_iter(), |element, id| { -// v.push((element, id)); -// }) -// .await -// .unwrap(); -// -// assert_eq!(accessor.inner.get_element, 5); -// assert_eq!(accessor.inner.on_elements_unordered, 1); -// assert_eq!( -// v, -// &[ -// (Pair::new(1, 2), 1), -// (Pair::new(2, 4), 2), -// (Pair::new(3, 6), 3), -// (Pair::new(4, 8), 4), -// (Pair::new(5, 10), 5) -// ] -// ); -// accessor.inner.reset(); -// } -// -// // On-elements-unordered propagates errors. -// assert!( -// accessor -// .on_elements_unordered([1, 2, 3, 100, 4].into_iter(), |_, _| {}) -// .await -// .is_err() -// ); -// -// // Computation. -// let query = 10; -// let computer = accessor.build_query_computer(query).unwrap(); -// -// assert_eq!( -// computer.evaluate_similarity(accessor.get_element(10).await.unwrap()), -// (10 * 2 + query) as f32 -// ); -// assert_eq!( -// computer.evaluate_similarity(accessor.get_element(11).await.unwrap()), -// (11 * 2 + query) as f32 -// ); -// assert_eq!( -// computer.evaluate_similarity(accessor.get_element(12).await.unwrap()), -// beta * ((12 * 2 + query) as f32) -// ); -// -// // Test dummy implementation of `get_neighbors` for code coverage. 
-// let mut neighbors = AdjacencyList::new(); -// accessor.get_neighbors(0, &mut neighbors).await.unwrap(); -// assert_eq!(neighbors.len(), 0); -// } -// } +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use diskann::graph::{ + glue::{HybridPredicate, Predicate, PredicateMut}, + test::{provider as test_provider, synthetic::Grid}, + }; + use std::collections::HashSet; + + /// A simple `QueryLabelProvider` that matches multiples of 3. + #[derive(Debug)] + struct ThreeFilter; + + impl QueryLabelProvider for ThreeFilter { + fn is_match(&self, id: u32) -> bool { + id.is_multiple_of(3) + } + } + + struct NotIn<'a>(&'a mut HashSet); + + impl Predicate for NotIn<'_> { + fn eval(&self, item: &u32) -> bool { + !self.0.contains(item) + } + } + + impl PredicateMut for NotIn<'_> { + fn eval_mut(&mut self, item: &u32) -> bool { + self.0.insert(*item) + } + } + + impl HybridPredicate for NotIn<'_> {} + + #[tokio::test] + async fn test_beta_filter() { + // The grid of 4x4 will look like this: + // + // | 16 + // | 3 7 11 15 + // | 2 6 10 14 + // | 1 5 9 13 + // | 0 4 8 12 + // +--------------- + // + let provider = test_provider::Provider::grid(Grid::Two, 4).unwrap(); + let context = test_provider::Context::new(); + + let beta: f32 = 0.25; + + let strategy = BetaFilter::new(test_provider::Strategy::new(), Arc::new(ThreeFilter), beta); + + let start_point_ids: Vec<_> = provider.start_point_ids().collect(); + assert_eq!( + start_point_ids.len(), + 1, + "grid should only have a single start point" + ); + let start_point = start_point_ids[0]; + assert_eq!( + start_point, + u32::MAX, + "`Provider::grid` is documented to use `u32::MAX` as its start point", + ); + + let mut visited = HashSet::new(); + let mut buf = Vec::new(); + + let mut accessor = strategy.search_accessor(&provider, &context).unwrap(); + + assert_eq!( + &*accessor.starting_points().await.unwrap(), + &*start_point_ids, + "the underlying start points should match", + ); + + // 
Build a query computer for the point 0, 0. + let computer = accessor.build_query_computer(&[0.0, 0.0]).unwrap(); + + accessor + .expand_beam( + [0, 5, 10, 15].into_iter(), + &computer, + NotIn(&mut visited), + |distance, id| buf.push((distance, id)), + ) + .await + .unwrap(); + + // The expansion order is unknown, but we know from the structure of the grid what + // the result should be. + // + // Each beam entry is connected to its axis-aligned grid neighbors; `NotIn` drops repeats. + buf.sort_by_key(|(_, id)| *id); + assert_eq!( + &*buf, + [ + (1.0, 1), + (1.0, 4), + (5.0 * beta, 6), + (5.0 * beta, 9), + (13.0, 11), + (13.0, 14), + ] + ); + + buf.clear(); + accessor + .start_point_distances(&computer, |id, distance| buf.push((distance, id))) + .await + .unwrap(); + + assert_eq!( + &*buf, + [(32.0 * beta, start_point)], + "u32::MAX is a multiple of 3" + ); + } +} diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 11b2ed755..c07a92033 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -108,6 +108,50 @@ pub trait SearchExt: BuildQueryComputer + HasId + Send + Sync { where F: FnMut(Self::Id, f32) + Send; + /// A primitive routine used by graph search. This is purposely implemented as a + /// coarse-grained operation to enable optimization opportunities by the backend. + /// + /// # Description + /// + /// For each `i` in `ids`, fetch the adjacency list `v_i` for `i`. For each `v_i`, then + /// for each id `j` in `v_i`, compute the distance `d` using `computer` to the data + /// associated with `j` and invoke the closure `on_neighbors` with `d` and `j`, provided + /// `pred.eval_mut(j)` evaluates to `true`. + /// + /// No specification is made on the traversal order of `ids`, the computation order of + /// the leaf elements, nor the order in which `pred` is evaluated. + /// + /// Restrictions on the implementation are as follows: + /// + /// * If `pred.eval_mut()` returns `true` for an id `i`, then `on_neighbors` must be + /// invoked for that item.
+ /// + /// If an item `i` is already passed to `on_neighbors`, the implementation is not + /// obligated to provide it again, though it **may** do so provided `pred.eval_mut()` + /// continues to return `true`. + /// + /// * If `pred.eval_mut()` returns `false` for an item, then `on_neighbors` must not be + /// invoked for that item. + /// + /// * `pred.eval()` may be invoked an arbitrary number of times. Proper predicate + /// implementations will ensure that + /// + /// - `pred.eval() == true` implies `pred.eval_mut() == true` if `pred.eval_mut()` is + /// invoked immediately after `pred.eval()`. + /// + /// - `pred.eval() == false` implies `pred.eval_mut() == false` and vice-versa. + /// + /// * `pred.eval_mut()` must be invoked at most once for all transitive items in the beam. + /// + /// ## Predicate Requirements + /// + /// Well behaved predicates must never return `true` (allow an id to be forwarded to + /// `on_neighbors`) if it previously returned `false`. Implementations of `expand_beam` + /// are allowed to assume this holds. + /// + /// Additionally, the callback `on_neighbors` and the predicate have to cooperate. If the + /// callback requires unique items, the predicate must be structured such that `eval_mut` + /// correctly filters out duplicates. fn expand_beam( &mut self, ids: Itr, @@ -212,79 +256,6 @@ where /// The interfaces `contains` and `insert` agree with each other. impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash::Hash {} -/// A primitive routine used by graph search. This is purposely implemented as a -/// coarse-grained operation to enable optimization opportunities by the backend. -/// -/// # Description -/// -/// For each `i` in `itr`, fetch the adjacency list `v_i` for `i`. For each `v_i`, then for -/// each id `j` in `v_i`, compute the distance `d` using `computer` to the data associated with -/// `j` and invoke the closure `f` with `d` and `j`, provided `pred.eval_mut(j)` evaluates to -/// `true`.
-/// -/// No specification is made on the traversal order of `ids`, the computation order of the -/// leaf elements, nor the order in which `pred` is evaluated. -/// -/// Restriction in the implementation are as follows: -/// -/// * If `pred.eval_mut()` returns `true` for an id `i`, then `on_neighbors` must be invoked -/// for that item. -/// -/// If an item `i` is already passed to `on_neighbors`, the implementation is not obligated -/// to provided it again, though it **may** do so provided `pred.eval_mut()` continues to -/// return `true`. -/// -/// * If `pred.eval_mut()` returns `false` for an item, then `on_neighbors` must not be -/// invoked for that item. -/// -/// * `pred.eval()` may be invoked an arbitrary number of times. Proper predicate -/// implementations will ensure that -/// -/// - `pred.eval() == true` implies `pred.eval_mut() == true` if `pred.eval_mut()` is -/// invoked immediately after `pred.eval()`. -/// -/// - `pred.eval() == false` implies `pred.eval_mut() == false` and vice-versa. -/// -/// * `pred.eval_mut()` must be invoked at most once for all transitive items in the beam. -/// -/// ## Predicate Requirements -/// -/// Well behaved predicates must never return `true` (allow an id to be forwarded to -/// `on_neighbors`) if it previously returned `false`. Implementations of `ExpandBeam` are -/// allowed to assume this holds. -/// -/// Additionally, the callback `on_neighbors` and the predicate have to cooperate. If the -/// callback requires unique items, the predicate must be structured such that `eval_mut` -/// correctly filters out duplicates. -/// -/// # Provided Implementation -/// -/// The provided implementation works on each element of `ids` sequentially, pre-filters -/// the resulting candidate list using `pred.eval()` before invoking -/// [`BuildQueryComputer::distances_unordered`]. -/// -/// The callback `on_neighbors` is decorated to the uses `pred.eval_mut()`. 
-/// -/// This ensures that if `distances_unordered` errors, the predicate is not erroneously -/// updated. -/// -/// ## Error Handling -/// -/// Transient errors yielded by `distances_unordered` are acknowledged and not escalated. -// pub trait ExpandBeam: BuildQueryComputer { -// fn expand_beam( -// &mut self, -// ids: Itr, -// computer: &Self::QueryComputer, -// pred: P, -// on_neighbors: F, -// ) -> impl std::future::Future> + Send -// where -// Itr: Iterator + Send, -// P: HybridPredicate + Send + Sync, -// F: FnMut(f32, Self::Id) + Send; -// } - /// A search strategy for query objects of type `T`. /// /// This trait should be overloaded by data providers wishing to extend @@ -836,237 +807,162 @@ where // Tests // /////////// -// #[cfg(test)] -// mod tests { -// use std::sync::{ -// Arc, -// atomic::{AtomicUsize, Ordering}, -// }; -// -// use diskann_vector::PreprocessedDistanceFunction; -// use futures_util::future; -// -// use super::*; -// use crate::{ -// ANNResult, neighbor, -// provider::{DelegateNeighbor, ExecutionContext, HasId, NeighborAccessor}, -// }; -// -// // A really simple provider that just holds floats and uses the absolute value for its -// // distances. -// struct SimpleProvider { -// items: Vec, -// } -// -// #[derive(Default, Clone)] -// struct CountGetVector { -// count: Arc, -// } -// impl ExecutionContext for CountGetVector {} -// -// impl CountGetVector { -// fn count(&self) -> usize { -// self.count.load(Ordering::Relaxed) -// } -// -// fn clear(&self) { -// self.count.store(0, Ordering::Relaxed) -// } -// } -// -// impl DataProvider for SimpleProvider { -// type Context = CountGetVector; -// type InternalId = u32; -// type ExternalId = u32; -// type Error = ANNError; -// type Guard = crate::provider::NoopGuard; -// -// /// Translate an external id to its corresponding internal id. 
-// fn to_internal_id( -// &self, -// _context: &CountGetVector, -// gid: &Self::ExternalId, -// ) -> Result { -// Ok(*gid) -// } -// -// /// Translate an internal id to its corresponding external id. -// fn to_external_id( -// &self, -// _context: &CountGetVector, -// id: Self::InternalId, -// ) -> Result { -// Ok(id) -// } -// } -// -// #[derive(Clone, Copy)] -// struct Retriever<'a> { -// provider: &'a SimpleProvider, -// count: &'a CountGetVector, -// } -// -// impl SearchExt for Retriever<'_> { -// async fn starting_points(&self) -> ANNResult> { -// Ok(vec![0]) -// } -// } -// -// impl<'a> Retriever<'a> { -// fn new(provider: &'a SimpleProvider, count: &'a CountGetVector) -> Self { -// Self { provider, count } -// } -// } -// -// impl HasId for Retriever<'_> { -// type Id = u32; -// } -// -// impl Accessor for Retriever<'_> { -// // type Element<'a> -// // = f32 -// // where -// // Self: 'a; -// type ElementRef<'a> = f32; -// -// // type GetError = ANNError; -// // fn get_element( -// // &mut self, -// // id: Self::Id, -// // ) -> impl std::future::Future, Self::GetError>> + Send -// // { -// // let result = match self.provider.items.get(id as usize) { -// // Some(v) => { -// // self.count.count.fetch_add(1, Ordering::Relaxed); -// // Ok(*v) -// // } -// // None => panic!("invalid id: {}", id), -// // }; -// // async move { result } -// // } -// } -// -// impl NeighborAccessor for Retriever<'_> { -// fn get_neighbors( -// self, -// _id: Self::Id, -// neighbors: &mut AdjacencyList, -// ) -> impl Future> + Send { -// neighbors.clear(); -// future::ok(self) -// } -// } -// -// struct QueryComputer; -// impl PreprocessedDistanceFunction for QueryComputer { -// fn evaluate_similarity(&self, _changing: f32) -> f32 { -// panic!("this method should not be called") -// } -// } -// -// impl BuildQueryComputer for Retriever<'_> { -// type QueryComputerError = ANNError; -// type QueryComputer = QueryComputer; -// fn build_query_computer(&self, _from: f32) -> Result { -// 
Ok(QueryComputer) -// } -// } -// -// impl ExpandBeam for Retriever<'_> {} -// -// // This strategy explicitly does not define `post_process` so we can test the provided -// // implementation. -// struct Strategy; -// -// impl SearchStrategy for Strategy { -// type QueryComputer = QueryComputer; -// type SearchAccessorError = ANNError; -// type SearchAccessor<'a> = Retriever<'a>; -// -// fn search_accessor<'a>( -// &'a self, -// provider: &'a SimpleProvider, -// context: &'a CountGetVector, -// ) -> Result, Self::SearchAccessorError> { -// Ok(Retriever::new(provider, context)) -// } -// } -// -// impl DefaultPostProcessor for Strategy { -// default_post_processor!(CopyIds); -// } -// -// #[tokio::test(flavor = "current_thread")] -// async fn test_default_post_process() { -// let ctx = CountGetVector::default(); -// let strategy = Strategy; -// -// let num_points: usize = 100; -// let provider = SimpleProvider { -// items: (0..num_points).map(|i| i as f32).collect(), -// }; -// -// assert_eq!(provider.to_internal_id(&ctx, &10).unwrap(), 10); -// assert_eq!(provider.to_external_id(&ctx, 10).unwrap(), 10); -// -// let mut accessor = strategy.search_accessor(&provider, &ctx).unwrap(); -// assert_eq!(accessor.starting_points().await.unwrap().as_slice(), &[0]); -// for i in 0..num_points { -// assert_eq!(accessor.get_element(i as u32).await.unwrap(), i as f32); -// } -// -// // Check dummy get_neighbors implmeentation for code coverage -// let mut neighbors = AdjacencyList::new(); -// accessor -// .delegate_neighbor() -// .get_neighbors(0, &mut neighbors) -// .await -// .unwrap(); -// assert_eq!(neighbors.len(), 0); -// -// // Check that the correct number of reads were emitted. 
-// assert_eq!(ctx.count(), num_points); -// ctx.clear(); -// -// let query = 11.5; -// let computer = accessor.build_query_computer(query).unwrap(); -// -// for input_len in 0..10 { -// let input: Vec<_> = (0..input_len) -// .map(|i| Neighbor::::new(i as u32, i as f32)) -// .collect(); -// for output_len in 0..10 { -// let mut output = vec![Neighbor::::default(); output_len]; -// -// let count = strategy -// .default_post_processor() -// .post_process( -// &mut accessor, -// query, -// &computer, -// input.iter().copied(), -// &mut neighbor::BackInserter::new(output.as_mut_slice()), -// ) -// .await -// .unwrap(); -// -// assert_eq!(count, input_len.min(output_len)); -// -// // Check that the in-range values were properly copied. -// for (i, n) in output.iter().take(count).enumerate() { -// assert_eq!(i, n.id as usize); -// assert_eq!(i as f32, n.distance); -// } -// -// // Check that out-of-range values were untouched. -// for n in output.iter().skip(count) { -// assert_eq!(n.id, 0); -// assert_eq!(n.distance, 0.0); -// } -// } -// } -// -// // Ensure that no reads were emitted. -// assert_eq!(ctx.count(), 0); -// } -// } +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ANNResult, neighbor, + provider::{DefaultContext, HasId}, + }; + + // A minimal provider that stores no data: internal and external ids map to themselves. + // The test below only checks that ids and distances are copied, so no storage is needed. + struct SimpleProvider; + + impl DataProvider for SimpleProvider { + type Context = DefaultContext; + type InternalId = u32; + type ExternalId = u32; + type Error = ANNError; + type Guard = crate::provider::NoopGuard; + + /// Translate an external id to its corresponding internal id. + fn to_internal_id( + &self, + _context: &DefaultContext, + gid: &Self::ExternalId, + ) -> Result { + Ok(*gid) + } + + /// Translate an internal id to its corresponding external id.
+ fn to_external_id( + &self, + _context: &DefaultContext, + id: Self::InternalId, + ) -> Result { + Ok(id) + } + } + + #[derive(Clone, Copy)] + struct Accessor; + + impl SearchExt for Accessor { + async fn starting_points(&self) -> ANNResult> { + unimplemented!(); + } + + async fn start_point_distances( + &mut self, + _computer: &Self::QueryComputer, + _f: F, + ) -> ANNResult<()> + where + F: FnMut(Self::Id, f32) + Send, + { + unimplemented!(); + } + + async fn expand_beam( + &mut self, + _ids: Itr, + _computer: &Self::QueryComputer, + _pred: P, + _on_neighbors: F, + ) -> ANNResult<()> + where + Itr: Iterator + Send, + P: HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + unimplemented!(); + } + } + + impl HasId for Accessor { + type Id = u32; + } + + struct QueryComputer; + + impl BuildQueryComputer for Accessor { + type QueryComputerError = ANNError; + type QueryComputer = QueryComputer; + fn build_query_computer(&self, _from: f32) -> Result { + Ok(QueryComputer) + } + } + + // This strategy explicitly does not define `post_process` so we can test the provided + // implementation. 
+ struct Strategy; + + impl SearchStrategy for Strategy { + type QueryComputer = QueryComputer; + type SearchAccessorError = ANNError; + type SearchAccessor<'a> = Accessor; + + fn search_accessor<'a>( + &'a self, + _provider: &'a SimpleProvider, + _context: &'a DefaultContext, + ) -> Result, Self::SearchAccessorError> { + Ok(Accessor) + } + } + + impl DefaultPostProcessor for Strategy { + default_post_processor!(CopyIds); + } + + #[tokio::test(flavor = "current_thread")] + async fn test_default_post_process() { + let ctx = DefaultContext; + let strategy = Strategy; + let provider = SimpleProvider; + + assert_eq!(provider.to_internal_id(&ctx, &10).unwrap(), 10); + assert_eq!(provider.to_external_id(&ctx, 10).unwrap(), 10); + + let mut accessor = strategy.search_accessor(&provider, &ctx).unwrap(); + + let query = 11.5; + let computer = accessor.build_query_computer(query).unwrap(); + + for input_len in 0..10 { + let input: Vec<_> = (0..input_len) + .map(|i| Neighbor::::new(i as u32, i as f32)) + .collect(); + for output_len in 0..10 { + let mut output = vec![Neighbor::::default(); output_len]; + + let count = strategy + .default_post_processor() + .post_process( + &mut accessor, + query, + &computer, + input.iter().copied(), + &mut neighbor::BackInserter::new(output.as_mut_slice()), + ) + .await + .unwrap(); + + assert_eq!(count, input_len.min(output_len)); + + // Check that the in-range values were properly copied. + for (i, n) in output.iter().take(count).enumerate() { + assert_eq!(i, n.id as usize); + assert_eq!(i as f32, n.distance); + } + + // Check that out-of-range values were untouched. 
+ for n in output.iter().skip(count) { + assert_eq!(n.id, 0); + assert_eq!(n.distance, 0.0); + } + } + } + } +} diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index b107f1a83..33c37f65f 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -472,6 +472,11 @@ impl Provider { .map(|ref_multi| *ref_multi.key()) .filter(|id| !self.is_start_point(*id)) } + + // Return all start point Ids. + pub fn start_point_ids(&self) -> impl Iterator + '_ { + self.config.start_points.keys().copied() + } } /// Provider level metrics. From 292c7aa1144079a257d21ffb9512767455a21fb9 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 13:44:09 -0700 Subject: [PATCH 06/17] Checkpoint before one more experiment. --- diskann/src/graph/glue.rs | 4 ++-- diskann/src/provider.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index c07a92033..77aad705f 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -268,10 +268,10 @@ where /// /// We could grab this type from the `SearchAccessor` associated type, but it's /// useful enough that we move it up here. - type QueryComputer: Send + Sync + 'static + for<'a> DoesNotCaptureSelf>; + type QueryComputer: Send + Sync + 'static; /// An error that can occur when getting a search_accessor. - type SearchAccessorError: StandardError; + type SearchAccessorError: StandardError + for<'a> DoesNotCaptureSelf>; /// The concrete type of the accessor that is used to access `Self` during the greedy /// graph search. 
The query will be provided to the accessor exactly once during search diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index 51497b676..c1372ace1 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -440,7 +440,7 @@ pub trait BuildQueryComputer { /// /// Generally, neighbor access and data access are logically decoupled, being served from /// different stores. However, there are situations where data and neighbors are -/// interlfor<'a> eaved in the underlying storage medium. +/// interleaved in the underlying storage medium. /// /// As such, [`Accessors`] used in congunction with graph operations need to additionally /// provide an implementation of this trait. From 70cfe6216c393343a7cee3f92453aa4c2052a7f4 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 15:07:15 -0700 Subject: [PATCH 07/17] Overhaul PagedSearch. --- diskann-providers/src/index/diskann_async.rs | 19 +- diskann-providers/src/index/wrapped_async.rs | 92 +++--- diskann/src/graph/index.rs | 312 +++++++++---------- diskann/src/graph/search/scratch.rs | 1 + diskann/src/graph/test/cases/paged_search.rs | 82 ++--- 5 files changed, 233 insertions(+), 273 deletions(-) diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 69c254c86..ddb642295 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -436,25 +436,16 @@ pub(crate) mod tests { Q: Copy + std::fmt::Debug + Send + Sync, { assert!(max_candidates <= groundtruth.len()); - let mut state = index - .start_paged_search(strategy, ¶meters.context, query, parameters.search_l) + let mut search = index + .paged_search(strategy, ¶meters.context, query, parameters.search_l) .await .unwrap(); - let mut buffer = vec![Neighbor::::default(); parameters.search_k]; let mut iter = 0; let mut seen = 0; while !groundtruth.is_empty() { - let count = index - .next_search_results::( - ¶meters.context, - &mut 
state, - parameters.search_k, - &mut buffer, - ) - .await - .unwrap(); - for (i, b) in buffer.iter().enumerate().take(count) { + let page = search.next_page(parameters.search_k).await.unwrap(); + for (i, b) in page.iter().enumerate() { let m = is_match(groundtruth, *b, 0.01); match m { None => { @@ -469,7 +460,7 @@ pub(crate) mod tests { b, iter, i, - &buffer[i..], + &page[i..], ); } Some(j) => groundtruth.remove(j), diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index e8d26a609..73e5308cf 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -13,7 +13,7 @@ use diskann::{ Batch, DefaultSearchStrategy, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, PruneStrategy, SearchStrategy, }, - index::{DegreeStats, PagedSearchState, PartitionedNeighbors, SearchState}, + index::{DegreeStats, PartitionedNeighbors}, search_output_buffer, }, neighbor::Neighbor, @@ -339,59 +339,51 @@ where ) } - #[allow(clippy::type_complexity)] - pub fn start_paged_search( - &self, + /// Begin a paged search over the index (synchronous wrapper). + /// + /// Returns a [`PagedSearch`] handle. See + /// [`graph::index::PagedSearch::next_page`] for retrieving results. + pub fn paged_search<'a, S, T>( + &'a self, strategy: S, - context: &DP::Context, + context: &'a DP::Context, query: T, l_value: usize, - ) -> ANNResult> + ) -> ANNResult> where S: SearchStrategy + 'static, - T: Copy + Send, + T: Copy + Send + 'a, { - self.handle.block_on( - self.inner - .start_paged_search(strategy, context, query, l_value), - ) + let inner = self + .handle + .block_on(self.inner.paged_search(strategy, context, query, l_value))?; + Ok(PagedSearch { + handle: self.handle.clone(), + inner, + }) } - #[allow(clippy::type_complexity)] - pub fn start_paged_search_with_init_ids( - &self, + /// Begin a paged search with explicit initial seed IDs (synchronous wrapper). 
+ pub fn paged_search_with_init_ids<'a, S, T>( + &'a self, strategy: S, - context: &DP::Context, + context: &'a DP::Context, query: T, l_value: usize, - init_ids: Option<&[DP::InternalId]>, - ) -> ANNResult> + init_ids: Option<&'a [DP::InternalId]>, + ) -> ANNResult> where S: SearchStrategy + 'static, - T: Copy + Send, + T: Copy + Send + 'a, { - self.handle.block_on( + let inner = self.handle.block_on( self.inner - .start_paged_search_with_init_ids(strategy, context, query, l_value, init_ids), - ) - } - - pub fn next_search_results( - &self, - context: &DP::Context, - search_state: &mut SearchState, - k: usize, - result_output: &mut [Neighbor], - ) -> ANNResult - where - S: SearchStrategy, - { - self.handle.block_on(self.inner.next_search_results( - context, - search_state, - k, - result_output, - )) + .paged_search_with_init_ids(strategy, context, query, l_value, init_ids), + )?; + Ok(PagedSearch { + handle: self.handle.clone(), + inner, + }) } pub fn count_reachable_nodes( @@ -416,6 +408,28 @@ where } } +/// Synchronous wrapper around [`graph::index::PagedSearch`] that owns a tokio runtime handle. +/// +/// Created by [`DiskANNIndex::paged_search`]. Each call to [`next_page`](Self::next_page) +/// blocks the current thread to drive the underlying async search forward. +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + handle: tokio::runtime::Handle, + inner: graph::index::PagedSearch<'a, DP, S, T>, +} + +impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +where + DP: DataProvider, + S: SearchStrategy, +{ + /// Returns the next page of at most `k` nearest-neighbor results. + /// + /// Blocks the current thread. Returns an empty `Vec` when the search is exhausted. 
+ pub fn next_page(&mut self, k: usize) -> ANNResult>> { + self.handle.block_on(self.inner.next_page(k)) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index d696135a0..0f0637975 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -137,34 +137,6 @@ pub struct PartitionedNeighbors { pub deleted: Vec, } -/// Placeholder for extra state. -/// -/// The contents of the search state are designed for the synchronous index. -/// However, use cases in the asynchronous index require some extra state. -/// -/// This placeholder is used in the synchronous code-paths. -pub struct NoExtraState; - -/// Represents the state of the pagged search. -/// It can be used to do paged search by doing multiple `nextSearchResults()` queries. -/// -/// Generic extra state can be included to facilitate extra use-cases. -/// However, this extra state **must** be '`static' as we do not know how long the search -/// state will live for. -#[derive(Debug)] -pub struct SearchState { - /// Scratch space for query processing. - pub scratch: SearchScratch, - /// The computed search results ready to be returned in `nextSearchResults()` query - pub computed_result: Vec>, - /// The index of the next result to be returned. - pub next_result_index: usize, - /// The search computes results in the multiple of `search_param_l`. - pub search_param_l: usize, - /// Any extra data needed by down-stream implementations. - pub extra: ExtraState, -} - /// Edge pending submission for multi-insert. #[derive(Debug)] struct PendingEdge { @@ -204,16 +176,6 @@ where /// and `Err` paths. type BatchResult = Result; -/// State used during by paged search to perform multiple, consecutive searches over the index. -/// -/// Type parameters: -/// -/// * `DP`: The type of the [`DataProvider`]. -/// * `S`: The type of the [`SearchStrategy`]. -/// * `C`: The type of `S`'s [`BuildQueryComputer`] computer. 
This exists as a separate -/// type parameter because the type of the query computer depends on the type of the query. -pub type PagedSearchState = SearchState<::InternalId, (S, C)>; - impl DiskANNIndex where DP: DataProvider, @@ -2280,34 +2242,42 @@ where // Paged Search // ////////////////// - pub fn start_paged_search( - &self, + /// Begin a paged search over the index. + /// + /// Returns a [`PagedSearch`] handle whose [`next_page`](PagedSearch::next_page) method + /// yields successive pages of nearest-neighbor results. + pub fn paged_search<'a, S, T>( + &'a self, strategy: S, - context: &DP::Context, + context: &'a DP::Context, query: T, l_value: usize, - ) -> impl SendFuture>> + ) -> impl SendFuture>> where S: SearchStrategy + 'static, - T: Copy + Send, + T: Copy + Send + 'a, { async move { - self.start_paged_search_with_init_ids(strategy, context, query, l_value, None) + self.paged_search_with_init_ids(strategy, context, query, l_value, None) .await } } - pub fn start_paged_search_with_init_ids( - &self, + /// Begin a paged search with explicit initial seed IDs. + /// + /// This is the same as [`paged_search`](Self::paged_search) but allows the caller to + /// provide custom starting points for the graph traversal. 
+ pub fn paged_search_with_init_ids<'a, S, T>( + &'a self, strategy: S, - context: &DP::Context, + context: &'a DP::Context, query: T, l_value: usize, - init_ids: Option<&[DP::InternalId]>, - ) -> impl SendFuture>> + init_ids: Option<&'a [DP::InternalId]>, + ) -> impl SendFuture>> where S: SearchStrategy + 'static, - T: Copy + Send, + T: Copy + Send + 'a, { async move { let (computer, scratch) = { @@ -2350,121 +2320,20 @@ where (computer, scratch) }; - ANNResult::Ok(SearchState { + ANNResult::Ok(PagedSearch { + index: self, + context, scratch, computed_result: vec![Neighbor::default(); l_value], next_result_index: l_value, search_param_l: l_value, - extra: (strategy, computer), + strategy, + computer, + _query: std::marker::PhantomData, }) } } - pub fn next_search_results( - &self, - context: &DP::Context, - search_state: &mut SearchState, - k: usize, - result_output: &mut [Neighbor], - ) -> impl SendFuture> - where - S: SearchStrategy, - { - async move { - if k > search_state.search_param_l { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be less than or equal to search_param_l".to_string(), - )); - } - if k == 0 { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be greater than 0".to_string(), - )); - } - if result_output.len() < k { - return ANNResult::Err(ANNError::log_paged_search_error( - "The size of result_output should be greater than or equal to k".to_string(), - )); - } - - let copy_to_output = - |search_state: &mut SearchState, - count: usize, - result_output: &mut [Neighbor], - result_output_offset: usize| { - result_output[result_output_offset..result_output_offset + count] - .copy_from_slice( - &search_state.computed_result[search_state.next_result_index - ..search_state.next_result_index + count], - ); - search_state.next_result_index += count; - }; - - let used_computed_result_count: usize = cmp::min( - k, - search_state.computed_result.len() - search_state.next_result_index, - ); - if 
used_computed_result_count > 0 { - copy_to_output( - search_state, - used_computed_result_count, - result_output, - 0, // result_output_offset - ); - - if used_computed_result_count == k { - return ANNResult::Ok(k); - } - } - - let start_points = { - let mut accessor = search_state - .extra - .0 - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - - let start_ids = accessor.starting_points().await?; - self.search_internal( - None, // beam_width - &start_ids, - &mut accessor, - &search_state.extra.1, - &mut search_state.scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - start_ids - }; - - let (mut candidates, total_considered) = self - .filter_search_candidates(&start_points, k, &mut search_state.scratch.best) - .await?; - search_state.scratch.best.drain_best(total_considered); - - let computed_result_count = candidates.len(); - search_state.computed_result.clear(); - search_state.computed_result.append(&mut candidates); - - search_state.next_result_index = 0; - if computed_result_count != search_state.search_param_l { - search_state.computed_result.truncate(computed_result_count); - } - - let leftover_results = cmp::min(k - used_computed_result_count, computed_result_count); - - copy_to_output( - search_state, - leftover_results, // count of results to copy - result_output, - used_computed_result_count, // result_output_offset - ); - - ANNResult::Ok(used_computed_result_count + leftover_results) - } - } - /// Count the number of nodes in the graph reachable from the given `start_points`. /// /// This function has a large memory footprint for large graphs and should not be called @@ -3188,6 +3057,135 @@ struct BatchIdMismatch { ids_len: usize, } +////////////////// +// Paged Search // +////////////////// + +/// A paged search handle that owns all search state internally. +/// +/// Created by [`DiskANNIndex::paged_search`] or +/// [`DiskANNIndex::paged_search_with_init_ids`]. 
Each call to +/// [`next_page`](Self::next_page) resumes the graph search and returns the next page of +/// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. +/// +/// # Type Parameters +/// +/// * `'a` — lifetime of the borrowed [`DiskANNIndex`] and of the borrowed +/// [`DataProvider::Context`]. +/// * `DP` — the [`DataProvider`] type. +/// * `S` — the [`SearchStrategy`] type. +/// * `T` — the original query type (carried only for trait-bound resolution). +#[derive(Debug)] +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + index: &'a DiskANNIndex, + context: &'a DP::Context, + scratch: SearchScratch, + computed_result: Vec>, + next_result_index: usize, + search_param_l: usize, + strategy: S, + computer: S::QueryComputer, + // Note: marker only; no value of the query type `T` is stored. + _query: std::marker::PhantomData T>, +} + +impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +where + DP: DataProvider, + S: SearchStrategy, +{ + /// Returns the next page of at most `k` nearest-neighbor results. + /// + /// Results across pages are non-overlapping and ordered by non-decreasing distance. + /// When the search is exhausted, returns an empty `Vec`. + pub fn next_page( + &mut self, + k: usize, + ) -> impl SendFuture>>> { + async move { + if k > self.search_param_l { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be less than or equal to search_param_l".to_string(), + )); + } + if k == 0 { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be greater than 0".to_string(), + )); + } + + let mut result = Vec::with_capacity(k); + + // Drain any already-computed results first.
+ let available = self + .computed_result + .len() + .saturating_sub(self.next_result_index); + let from_cache = cmp::min(k, available); + if from_cache > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + from_cache], + ); + self.next_result_index += from_cache; + + if result.len() == k { + return ANNResult::Ok(result); + } + } + + // Resume graph search to fill the next batch. + let start_points = { + let mut accessor = self + .strategy + .search_accessor(&self.index.data_provider, self.context) + .into_ann_result()?; + + let start_ids = accessor.starting_points().await?; + self.index + .search_internal( + None, // beam_width + &start_ids, + &mut accessor, + &self.computer, + &mut self.scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + start_ids + }; + + let (mut candidates, total_considered) = self + .index + .filter_search_candidates(&start_points, k, &mut self.scratch.best) + .await?; + self.scratch.best.drain_best(total_considered); + + let computed_result_count = candidates.len(); + self.computed_result.clear(); + self.computed_result.append(&mut candidates); + + self.next_result_index = 0; + if computed_result_count != self.search_param_l { + self.computed_result.truncate(computed_result_count); + } + + let remaining_need = k - result.len(); + let leftover = cmp::min(remaining_need, computed_result_count); + if leftover > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + leftover], + ); + self.next_result_index += leftover; + } + + ANNResult::Ok(result) + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/diskann/src/graph/search/scratch.rs b/diskann/src/graph/search/scratch.rs index 98a5b3127..2670f118c 100644 --- a/diskann/src/graph/search/scratch.rs +++ b/diskann/src/graph/search/scratch.rs @@ -158,6 +158,7 @@ where } /// Return the currently configured `search_l`: the number of best candidates to track. 
+ #[cfg(test)] pub fn search_l(&self) -> usize { self.best.search_l() } diff --git a/diskann/src/graph/test/cases/paged_search.rs b/diskann/src/graph/test/cases/paged_search.rs index 0b8fa4706..8c2f79ee1 100644 --- a/diskann/src/graph/test/cases/paged_search.rs +++ b/diskann/src/graph/test/cases/paged_search.rs @@ -6,8 +6,8 @@ //! Tests for paged (iterative) search. //! //! Paged search returns results in pages of k neighbors via a stateful -//! `SearchState`. Tests cover basic pagination, single-page retrieval, -//! and small page sizes that stress the iteration machinery. +//! [`PagedSearch`](crate::graph::index::PagedSearch) handle. Tests cover basic pagination, +//! single-page retrieval, and small page sizes that stress the iteration machinery. use std::sync::Arc; @@ -155,8 +155,8 @@ fn basic_paged_search() { let page_size = 4; let ctx = test_provider::Context::new(); - let mut state = rt - .block_on(index.start_paged_search( + let mut search = rt + .block_on(index.paged_search( test_provider::Strategy::new(), &ctx, query.as_slice(), @@ -165,24 +165,13 @@ fn basic_paged_search() { .unwrap(); let mut pages: Vec>> = Vec::new(); - let mut buffer = vec![Neighbor::::default(); page_size]; loop { - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - if count == 0 { + let page = rt.block_on(search.next_page(page_size)).unwrap(); + if page.is_empty() { break; } - pages.push(buffer[..count].to_vec()); + pages.push(page); } let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); @@ -209,8 +198,8 @@ fn single_page() { let page_size = 200; // larger than total points (125) let ctx = test_provider::Context::new(); - let mut state = rt - .block_on(index.start_paged_search( + let mut search = rt + .block_on(index.paged_search( test_provider::Strategy::new(), &ctx, query.as_slice(), @@ -218,21 +207,8 @@ fn single_page() { )) .unwrap(); - let mut buffer = 
vec![Neighbor::::default(); page_size]; - - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - let results: Vec> = buffer[..count].to_vec(); - let pages = vec![results.clone()]; + let results = rt.block_on(search.next_page(page_size)).unwrap(); + let pages = vec![results]; let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); @@ -242,18 +218,9 @@ fn single_page() { assert_no_duplicates_across_pages(&pages); assert_non_decreasing_distances(&pages); - // Verify second call returns 0 (nothing left) - let count2 = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - assert_eq!(count2, 0, "second page should be empty"); + // Verify second call returns empty (nothing left) + let page2 = rt.block_on(search.next_page(page_size)).unwrap(); + assert!(page2.is_empty(), "second page should be empty"); } #[test] @@ -270,8 +237,8 @@ fn small_page_size() { let page_size = 1; // one result per page, maximum iterations let ctx = test_provider::Context::new(); - let mut state = rt - .block_on(index.start_paged_search( + let mut search = rt + .block_on(index.paged_search( test_provider::Strategy::new(), &ctx, query.as_slice(), @@ -280,24 +247,13 @@ fn small_page_size() { .unwrap(); let mut pages: Vec>> = Vec::new(); - let mut buffer = vec![Neighbor::::default(); page_size]; loop { - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - if count == 0 { + let page = rt.block_on(search.next_page(page_size)).unwrap(); + if page.is_empty() { break; } - pages.push(buffer[..count].to_vec()); + pages.push(page); } let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); From 219302537bfcdf3a83e183f300f850411fc0071d Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 15:08:23 -0700 
Subject: [PATCH 08/17] Query Computers no longer need to be '`static`. --- diskann/src/graph/index.rs | 2 +- diskann/src/provider.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 0f0637975..cb6e75907 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -3085,7 +3085,7 @@ pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { search_param_l: usize, strategy: S, computer: S::QueryComputer, - // Note: + // Note: The use of `fn` here is so _query: std::marker::PhantomData T>, } diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index f55af356c..262e9b15c 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -508,8 +508,7 @@ pub trait BuildQueryComputer: Accessor { /// elements yielded by the [`Accessor`]. type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> + Send - + Sync - + 'static; + + Sync; /// Build the query computer for this accessor. /// From 9c773eb2b779ed14f2c90e0fc0f57d8e80330ecc Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 15:10:52 -0700 Subject: [PATCH 09/17] Clarify use of `fn` pointer. --- diskann/src/graph/index.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index cb6e75907..27c6baa81 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -3067,14 +3067,6 @@ struct BatchIdMismatch { /// [`DiskANNIndex::paged_search_with_init_ids`]. Each call to /// [`next_page`](Self::next_page) resumes the graph search and returns the next page of /// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. -/// -/// # Type Parameters -/// -/// * `'idx` — lifetime of the borrowed [`DiskANNIndex`]. -/// * `'ctx` — lifetime of the borrowed [`DataProvider::Context`]. -/// * `DP` — the [`DataProvider`] type. -/// * `S` — the [`SearchStrategy`] type. 
-/// * `T` — the original query type (carried only for trait-bound resolution). #[derive(Debug)] pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { index: &'a DiskANNIndex, @@ -3085,8 +3077,8 @@ pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { search_param_l: usize, strategy: S, computer: S::QueryComputer, - // Note: The use of `fn` here is so - _query: std::marker::PhantomData T>, + // Note: The `fn` is so we derive `Send` and `Sync` more easily: `fn` is always Send/Sync. + _query: std::marker::PhantomData, } impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> From b63d009893f362e07ace518314f7c85733cba212 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 18 May 2026 12:23:09 -0700 Subject: [PATCH 10/17] Paged search 2.0 --- diskann-providers/src/index/wrapped_async.rs | 4 +- diskann/src/graph/index.rs | 158 +-------------- diskann/src/graph/search/mod.rs | 3 + diskann/src/graph/search/paged.rs | 163 ++++++++++++++++ diskann/src/graph/test/cases/paged_search.rs | 4 + rfcs/01078-paged-search.md | 193 +++++++++++++++++++ 6 files changed, 369 insertions(+), 156 deletions(-) create mode 100644 diskann/src/graph/search/paged.rs create mode 100644 rfcs/01078-paged-search.md diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 73e5308cf..858218b55 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -408,13 +408,13 @@ where } } -/// Synchronous wrapper around [`graph::index::PagedSearch`] that owns a tokio runtime handle. +/// Synchronous wrapper around [`graph::search::PagedSearch`] that owns a tokio runtime handle. /// /// Created by [`DiskANNIndex::paged_search`]. Each call to [`next_page`](Self::next_page) /// blocks the current thread to drive the underlying async search forward. 
pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { handle: tokio::runtime::Handle, - inner: graph::index::PagedSearch<'a, DP, S, T>, + inner: graph::search::PagedSearch<'a, DP, S, T>, } impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 27c6baa81..54d40cc81 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -30,7 +30,7 @@ use super::{ }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ - Knn, + Knn, PagedSearch, record::{NoopSearchRecord, SearchRecord, VisitedSearchRecord}, scratch::{self, PriorityQueueConfiguration, SearchScratch, SearchScratchParams}, }, @@ -42,7 +42,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::{ErrorExt, IntoANNResult}, internal, - neighbor::{self, Neighbor, NeighborPriorityQueue, NeighborQueue}, + neighbor::{self, Neighbor, NeighborQueue}, provider::{ Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, Delete, ElementStatus, ExecutionContext, Guard, NeighborAccessor, @@ -2052,35 +2052,6 @@ where } } - /// Filter out start nodes from the best candidates in the scratch. - fn filter_search_candidates( - &self, - start_points: &[DP::InternalId], - l_value: usize, - best: &mut NeighborPriorityQueue, - ) -> impl SendFuture>, usize)>> { - async move { - let mut total = 0usize; - let mut candidates = Vec::with_capacity(l_value); - for n in best.iter() { - total += 1; - if !start_points.contains(&n.id) { - candidates.push(n); - if candidates.len() >= l_value { - break; - } - } - } - - debug_assert!( - l_value.min(best.size().saturating_sub(start_points.len())) <= candidates.len(), - "Not enough candidates after filtering starting points", - ); - - Ok((candidates, total)) - } - } - /// Execute a search using the unified search interface. /// /// This method provides a single entry point for all search types. 
The `search_params` argument @@ -2254,7 +2225,7 @@ where l_value: usize, ) -> impl SendFuture>> where - S: SearchStrategy + 'static, + S: SearchStrategy, T: Copy + Send + 'a, { async move { @@ -2276,7 +2247,7 @@ where init_ids: Option<&'a [DP::InternalId]>, ) -> impl SendFuture>> where - S: SearchStrategy + 'static, + S: SearchStrategy, T: Copy + Send + 'a, { async move { @@ -3057,127 +3028,6 @@ struct BatchIdMismatch { ids_len: usize, } -////////////////// -// Paged Search // -////////////////// - -/// A paged search handle that owns all search state internally. -/// -/// Created by [`DiskANNIndex::paged_search`] or -/// [`DiskANNIndex::paged_search_with_init_ids`]. Each call to -/// [`next_page`](Self::next_page) resumes the graph search and returns the next page of -/// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. -#[derive(Debug)] -pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { - index: &'a DiskANNIndex, - context: &'a DP::Context, - scratch: SearchScratch, - computed_result: Vec>, - next_result_index: usize, - search_param_l: usize, - strategy: S, - computer: S::QueryComputer, - // Note: The `fn` is so we derive `Send` and `Sync` more easily: `fn` is always Send/Sync. - _query: std::marker::PhantomData, -} - -impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> -where - DP: DataProvider, - S: SearchStrategy, -{ - /// Returns the next page of at most `k` nearest-neighbor results. - /// - /// Results across pages are non-overlapping and ordered by non-decreasing distance. - /// When the search is exhausted, returns an empty `Vec`. 
- pub fn next_page( - &mut self, - k: usize, - ) -> impl SendFuture>>> { - async move { - if k > self.search_param_l { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be less than or equal to search_param_l".to_string(), - )); - } - if k == 0 { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be greater than 0".to_string(), - )); - } - - let mut result = Vec::with_capacity(k); - - // Drain any already-computed results first. - let available = self - .computed_result - .len() - .saturating_sub(self.next_result_index); - let from_cache = cmp::min(k, available); - if from_cache > 0 { - result.extend_from_slice( - &self.computed_result - [self.next_result_index..self.next_result_index + from_cache], - ); - self.next_result_index += from_cache; - - if result.len() == k { - return ANNResult::Ok(result); - } - } - - // Resume graph search to fill the next batch. - let start_points = { - let mut accessor = self - .strategy - .search_accessor(&self.index.data_provider, self.context) - .into_ann_result()?; - - let start_ids = accessor.starting_points().await?; - self.index - .search_internal( - None, // beam_width - &start_ids, - &mut accessor, - &self.computer, - &mut self.scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - start_ids - }; - - let (mut candidates, total_considered) = self - .index - .filter_search_candidates(&start_points, k, &mut self.scratch.best) - .await?; - self.scratch.best.drain_best(total_considered); - - let computed_result_count = candidates.len(); - self.computed_result.clear(); - self.computed_result.append(&mut candidates); - - self.next_result_index = 0; - if computed_result_count != self.search_param_l { - self.computed_result.truncate(computed_result_count); - } - - let remaining_need = k - result.len(); - let leftover = cmp::min(remaining_need, computed_result_count); - if leftover > 0 { - result.extend_from_slice( - &self.computed_result - [self.next_result_index..self.next_result_index + 
leftover], - ); - self.next_result_index += leftover; - } - - ANNResult::Ok(result) - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index fac279421..350930671 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -46,6 +46,9 @@ mod knn_search; mod multihop_search; mod range_search; +mod paged; +pub use paged::PagedSearch; + pub mod record; pub(crate) mod scratch; diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs new file mode 100644 index 000000000..3b1cebcbb --- /dev/null +++ b/diskann/src/graph/search/paged.rs @@ -0,0 +1,163 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_utils::future::SendFuture; + +use crate::{ + ANNError, ANNResult, + error::IntoANNResult, + graph::{ + DiskANNIndex, + glue::{SearchExt, SearchStrategy}, + search::{record::NoopSearchRecord, scratch::SearchScratch}, + }, + neighbor::{Neighbor, NeighborPriorityQueue}, + provider::DataProvider, + utils::VectorId, +}; + +/// A paged search handle that owns all search state internally. +/// +/// Created by [`DiskANNIndex::paged_search`] or +/// [`DiskANNIndex::paged_search_with_init_ids`]. Each call to +/// [`next_page`](Self::next_page) resumes the graph search and returns the next page of +/// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. 
+#[derive(Debug)] +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + pub(in crate::graph) index: &'a DiskANNIndex, + pub(in crate::graph) context: &'a DP::Context, + pub(in crate::graph) scratch: SearchScratch, + pub(in crate::graph) computed_result: Vec>, + pub(in crate::graph) next_result_index: usize, + pub(in crate::graph) search_param_l: usize, + pub(in crate::graph) strategy: S, + pub(in crate::graph) computer: S::QueryComputer, + // Note: The `fn` is so we derive `Send` and `Sync` more easily: `fn` is always Send/Sync. + pub(in crate::graph) _query: std::marker::PhantomData, +} + +impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +where + DP: DataProvider, + S: SearchStrategy, +{ + /// Returns the next page of at most `k` nearest-neighbor results. + /// + /// Results across pages are non-overlapping and ordered by non-decreasing distance. + /// When the search is exhausted, returns an empty `Vec`. + pub fn next_page( + &mut self, + k: usize, + ) -> impl SendFuture>>> { + async move { + if k > self.search_param_l { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be less than or equal to search_param_l".to_string(), + )); + } + if k == 0 { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be greater than 0".to_string(), + )); + } + + let mut result = Vec::with_capacity(k); + + // Drain any already-computed results first. + let available = self + .computed_result + .len() + .saturating_sub(self.next_result_index); + let from_cache = std::cmp::min(k, available); + if from_cache > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + from_cache], + ); + self.next_result_index += from_cache; + + if result.len() == k { + return ANNResult::Ok(result); + } + } + + // Resume graph search to fill the next batch. 
+ let start_points = { + let mut accessor = self + .strategy + .search_accessor(&self.index.data_provider, self.context) + .into_ann_result()?; + + let start_ids = accessor.starting_points().await?; + self.index + .search_internal( + None, // beam_width + &start_ids, + &mut accessor, + &self.computer, + &mut self.scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + start_ids + }; + + let (mut candidates, total_considered) = + filter_search_candidates(&start_points, k, &mut self.scratch.best); + self.scratch.best.drain_best(total_considered); + + let computed_result_count = candidates.len(); + self.computed_result.clear(); + self.computed_result.append(&mut candidates); + + self.next_result_index = 0; + if computed_result_count != self.search_param_l { + self.computed_result.truncate(computed_result_count); + } + + let remaining_need = k - result.len(); + let leftover = std::cmp::min(remaining_need, computed_result_count); + if leftover > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + leftover], + ); + self.next_result_index += leftover; + } + + ANNResult::Ok(result) + } + } +} + +// FIXME: Wire proper post-processing support into paged search. 
+fn filter_search_candidates( + start_points: &[I], + l_value: usize, + best: &mut NeighborPriorityQueue, +) -> (Vec>, usize) +where + I: VectorId, +{ + let mut total = 0usize; + let mut candidates = Vec::with_capacity(l_value); + for n in best.iter() { + total += 1; + if !start_points.contains(&n.id) { + candidates.push(n); + if candidates.len() >= l_value { + break; + } + } + } + + debug_assert!( + l_value.min(best.size().saturating_sub(start_points.len())) <= candidates.len(), + "Not enough candidates after filtering starting points", + ); + + (candidates, total) +} diff --git a/diskann/src/graph/test/cases/paged_search.rs b/diskann/src/graph/test/cases/paged_search.rs index 8c2f79ee1..9b7fd27f6 100644 --- a/diskann/src/graph/test/cases/paged_search.rs +++ b/diskann/src/graph/test/cases/paged_search.rs @@ -221,6 +221,10 @@ fn single_page() { // Verify second call returns empty (nothing left) let page2 = rt.block_on(search.next_page(page_size)).unwrap(); assert!(page2.is_empty(), "second page should be empty"); + + // Verify repeated calls after exhaustion remain empty (idempotent, no panic). + let page3 = rt.block_on(search.next_page(page_size)).unwrap(); + assert!(page3.is_empty(), "third page should still be empty"); } #[test] diff --git a/rfcs/01078-paged-search.md b/rfcs/01078-paged-search.md new file mode 100644 index 000000000..50ac58d2c --- /dev/null +++ b/rfcs/01078-paged-search.md @@ -0,0 +1,193 @@ +# Overhaul Paged Search + +| | | +|---|---| +| **Authors** | Mark Hildebrand | +| **Contributors** | | +| **Created** | 2026-05-18 | +| **Updated** | 2026-05-18 | + +## Summary + +Replace the `SearchState<..., ExtraState: 'static>` pattern for paged search with a lifetime-bound `PagedSearch<'a, ...>` in `diskann`, and document a channel-based spawned-task pattern for downstream consumers that need to cross `tokio::spawn` or FFI boundaries. +This removes the `'static` requirement on query computers and search strategies, enabling future trait simplification. 
+ +## Motivation + +### Background + +Paged search allows callers to retrieve nearest-neighbor results incrementally (one "page" at a time) without restarting the graph traversal. +The search state (scratch buffers, the priority queue, the query computer) must persist across page boundaries. + +Earlier synchronous versions of DiskANN did this by persisting the search state manually and passing the search state explicitly to the next-page requests. +The async rewrite stuck with this pattern where callers were required to manage a `SearchState` struct whose `ExtraState` type parameter carried a `'static` bound: +```rust +// OLD: ExtraState must be 'static +pub struct SearchState { ... } +pub type PagedSearchState = SearchState<::InternalId, (S, C)>; + +// Note +// * S: SearchStrategy for some T +// * C: S::QueryComputer +``` +In downstream crates that expose paged search across FFI or task boundaries, the state was traditionally type-erased behind `Box` (which also captured the `DiskANNIndex`) and sent as an opaque pointer. +This required both the strategy and the query computer (parameters `S` and `C`) to be `'static`. + +### Problem Statement + +The `'static` bound on `BuildQueryComputer::QueryComputer` propagates throughout the trait hierarchy: + +```rust +// OLD +type QueryComputer: ... + Send + Sync + 'static; +``` + +This prevents: + +1. Query computers that borrow from the index or context (common with quantization tables). +2. Fusing the query-computer into the accessor. + Due to the lifetime needed by accessors, they can't be persisted in a `'static` struct this way. + However, query computers may contain non-trivial pre-processed state, meaning recreating them on each new page retrieval is a performance footgun. +3. Simplifying the `SearchStrategy` / `Accessor` trait tower by removing unnecessary indirection introduced solely to satisfy `'static`. + +### Goals + +1. Remove the `'static` bound from `BuildQueryComputer::QueryComputer`. +2.
Remove `SearchState`, `NoExtraState`, and `PagedSearchState` from the public API. +3. Provide a `PagedSearch<'a, ...>` handle that is lifetime-bound to the index and context, encapsulating all search state. +4. Document the channel-based pattern for downstream consumers that need to cross task/FFI boundaries (where `'static` is inherently required by the runtime, not by `diskann`). + +## Proposal + +The key idea is this: DiskANN for better or worse is already fully async, and async Rust despite its flaws already provides a clean way of doing this without requiring our traits to bend over backwards. +So let's embrace async to actually help us for a change. + +### Core library (`diskann`) + +Replace the old split API: + +```rust +// OLD +index.start_paged_search(strategy, ctx, query, l) -> SearchState<...> +index.next_search_results(ctx, &mut state, k, &mut buf) -> usize +``` + +With a self-contained handle: + +```rust +// NEW +impl DiskANNIndex { + pub fn paged_search<'a, S, T>( + &'a self, + strategy: S, + context: &'a DP::Context, + query: T, + l_value: usize, + ) -> impl SendFuture>> + where + S: SearchStrategy, + T: Copy + Send + 'a; +} + +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + index: &'a DiskANNIndex, + context: &'a DP::Context, + scratch: SearchScratch, + computed_result: Vec>, + next_result_index: usize, + search_param_l: usize, + strategy: S, + computer: S::QueryComputer, + _query: PhantomData, // covariant, always Send+Sync +} + +impl PagedSearch<'a, DP, S, T> { + pub fn next_page(&mut self, k: usize) -> impl SendFuture>>>; +} +``` + +The key change is that `PagedSearch` borrows all the necessary components. +There is no need to deconstruct the search components into `'static` pieces after each paged search and "reassemble" them on subsequent searches. + +### Crossing spawn boundaries: the channel pattern + +`PagedSearch<'a, ...>` borrows the index, so it cannot be sent to a `tokio::spawn`'d task directly. 
+When a long-lived session or an FFI boundary requires `'static` ownership, the recommended pattern is to **spawn a task that owns the search state as a local variable** and communicate with it via channels: + +```rust +// Types are illustrative — adapt names to your crate. + +type PageResult = ANNResult>>; + +/// Spawn a paged search session. The index is held by Arc so the task is 'static. +/// +/// Returns a request channel and a result channel. The caller sends the desired +/// page size (`k`) and awaits the corresponding result on the other end. +fn spawn_paged_session( + index: Arc>, + context: Arc, + query: T, + l: usize, +) -> (mpsc::Sender, mpsc::Receiver) { + let (req_tx, mut req_rx) = mpsc::channel::(1); + let (res_tx, res_rx) = mpsc::channel::(1); + + tokio::spawn(async move { + // Borrow from the Arc — these references are scoped to the task. + let mut search = index.paged_search(strategy, &*context, query, l).await.unwrap(); + + while let Some(k) = req_rx.recv().await { + let page = search.next_page(k).await; + if res_tx.send(page).await.is_err() { + break; // caller dropped the result receiver + } + } + // Request channel closed -> caller dropped sender -> clean shutdown. + }); + + (req_tx, res_rx) +} +``` + +Key properties of this pattern: + +1. **`'static` is confined to the spawn boundary**: the `Arc` satisfies the runtime's requirement, while the borrow from it lives entirely inside the task's local scope. + Importantly, even though `PagedSearch` borrows, it can be embedded inside a `'static` future. +2. **State is fully encapsulated**: callers never see `SearchScratch`, `QueryComputer`, or any internal types. +3. **Clean shutdown**: dropping the request sender closes the channel; the task exits gracefully. +4. **Per-request context**: the request channel can carry additional metadata (profiling tokens, cancellation flags, etc.) without polluting the core API. 
+ +### Migration guide + +| Old pattern | New pattern | +|---|---| +| `index.start_paged_search(s, ctx, q, l)` | `index.paged_search(s, ctx, q, l).await` | +| `index.next_search_results(ctx, &mut state, k, &mut buf)` | `search.next_page(k).await` | +| `SearchState` | `PagedSearch<'a, DP, S, T>` | +| `PagedSearchState` | `PagedSearch<'a, DP, S, T>` | +| Check return count for exhaustion | Check `page.is_empty()` | +| Type-erased `Box` across task/FFI boundaries | Channel + spawned task (see above) | + +## Feasibility via FFI + +This is a pretty big change in API, but it enables some significant future simplifications to our trait hierarchy by removing the `'static` special case introduced by paged search. +An internal user of paged search was ported to this new approach to check the feasibility. +While it was a bit of work to overcome the impedance mismatch of the quirks of that integration, the end result is cleaner, has fewer overall task spawns, and fewer FFI-related race conditions. +And really, this integration was already basically doing this same thing behind the scenes. + +## Alternatives + +The main alternative I see is to keep the status quo with explicit state management. +While some of the planned trait simplifications are still on the table, I think the opportunity to align paged search with the rest of the trait hierarchy is well worth it. + +## Benchmark Results + +No performance change expected (nor observed in the simulator for the aforementioned internal FFI user) since the search algorithm is identical. +Existing Rust code will have a similar pattern of future usage as before, just packaged slightly differently. + +## References + +1. [RFC 3498 — Lifetime Capture Rules 2024](https://rust-lang.github.io/rfcs/3498-lifetime-capture-rules-2024.html) — + Rust edition 2024 changes that make `impl Trait + 'a` returns more ergonomic. +2.
[tokio::sync::mpsc](https://docs.rs/tokio/latest/tokio/sync/mpsc/index.html) — the channel + primitive used in the spawned-task pattern. From c357f024b9ddfacd53d314664ec440320db1ed32 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 18 May 2026 14:55:02 -0700 Subject: [PATCH 11/17] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-providers/src/index/wrapped_async.rs | 2 +- diskann/src/graph/test/cases/paged_search.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 858218b55..2f7ee997a 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -342,7 +342,7 @@ where /// Begin a paged search over the index (synchronous wrapper). /// /// Returns a [`PagedSearch`] handle. See - /// [`graph::index::PagedSearch::next_page`] for retrieving results. + /// [`PagedSearch::next_page`] for retrieving results. pub fn paged_search<'a, S, T>( &'a self, strategy: S, diff --git a/diskann/src/graph/test/cases/paged_search.rs b/diskann/src/graph/test/cases/paged_search.rs index 9b7fd27f6..61e2c8c45 100644 --- a/diskann/src/graph/test/cases/paged_search.rs +++ b/diskann/src/graph/test/cases/paged_search.rs @@ -6,7 +6,7 @@ //! Tests for paged (iterative) search. //! //! Paged search returns results in pages of k neighbors via a stateful -//! [`PagedSearch`](crate::graph::index::PagedSearch) handle. Tests cover basic pagination, +//! [`PagedSearch`](crate::graph::search::PagedSearch) handle. Tests cover basic pagination, //! single-page retrieval, and small page sizes that stress the iteration machinery. use std::sync::Arc; From 02e2ccc41ef7ef59bf488d61bace13250d0bcbb3 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 18 May 2026 14:55:22 -0700 Subject: [PATCH 12/17] Fix documentation. 
--- diskann/src/graph/search/paged.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs index 3b1cebcbb..ab77829e3 100644 --- a/diskann/src/graph/search/paged.rs +++ b/diskann/src/graph/search/paged.rs @@ -18,12 +18,12 @@ use crate::{ utils::VectorId, }; -/// A paged search handle that owns all search state internally. +/// Intermediate state for paged search. /// -/// Created by [`DiskANNIndex::paged_search`] or -/// [`DiskANNIndex::paged_search_with_init_ids`]. Each call to -/// [`next_page`](Self::next_page) resumes the graph search and returns the next page of -/// nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. +/// Each call to [`next_page`](Self::next_page) resumes the graph search and returns the +/// next page of nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. +/// +/// See also: [`DiskANNIndex::paged_search`], [`DiskANNIndex::paged_search_with_init_ids`]. #[derive(Debug)] pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { pub(in crate::graph) index: &'a DiskANNIndex, @@ -45,7 +45,11 @@ where { /// Returns the next page of at most `k` nearest-neighbor results. /// - /// Results across pages are non-overlapping and ordered by non-decreasing distance. + /// Results across pages are non-overlapping but not guaranteed to be monotonic with + /// respect to distance. + /// + /// Within a page, results are ordered by non-decreasing distance. + /// /// When the search is exhausted, returns an empty `Vec`. pub fn next_page( &mut self, From feb2e97cf9473dbb71875dddb43bb21b4c09a0b6 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Mon, 18 May 2026 14:59:47 -0700 Subject: [PATCH 13/17] Commit simplification.
--- diskann/src/graph/search/paged.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs index ab77829e3..34bdd16d6 100644 --- a/diskann/src/graph/search/paged.rs +++ b/diskann/src/graph/search/paged.rs @@ -116,11 +116,7 @@ where let computed_result_count = candidates.len(); self.computed_result.clear(); self.computed_result.append(&mut candidates); - self.next_result_index = 0; - if computed_result_count != self.search_param_l { - self.computed_result.truncate(computed_result_count); - } let remaining_need = k - result.len(); let leftover = std::cmp::min(remaining_need, computed_result_count); From 7344b0023070fd42201f2ca5b4c533474fa3b73f Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 19 May 2026 15:31:29 -0700 Subject: [PATCH 14/17] Add specialized noawait paged-search. --- Cargo.lock | 37 +- diskann-providers/Cargo.toml | 2 +- diskann-providers/src/index/wrapped_async.rs | 343 +++++++++++++++++++ 3 files changed, 362 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94c12d370..a4f7df368 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1236,9 +1236,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", @@ -1251,9 +1251,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -1261,15 +1261,15 @@ dependencies = [ [[package]] name = "futures-core" 
-version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -1278,15 +1278,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", @@ -1295,15 +1295,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" @@ -1313,9 +1313,9 @@ checksum = 
"f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-channel", "futures-core", @@ -1325,7 +1325,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] diff --git a/diskann-providers/Cargo.toml b/diskann-providers/Cargo.toml index c5447dca2..ef8d79182 100644 --- a/diskann-providers/Cargo.toml +++ b/diskann-providers/Cargo.toml @@ -36,7 +36,7 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tempfile = { workspace = true, optional = true } bf-tree = { workspace = true, optional = true } prost = "0.14.1" -futures-util.workspace = true +futures-util = { workspace = true, features = ["async-await"] } serde_json = { workspace = true, optional = true } vfs = { workspace = true, optional = true } diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 2f7ee997a..d546fcc81 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -20,6 +20,7 @@ use diskann::{ provider::{AsNeighbor, AsNeighborMut, DataProvider, Delete, SetElement}, utils::ONE, }; +use diskann_utils::Reborrow; use crate::storage::{LoadWith, StorageReadProvider}; @@ -386,6 +387,28 @@ where }) } + /// Begin a synchronous paged search over the index. + /// + /// This will construct a [`noawait::PagedSearch`] and initialize search with the + /// providers start points. Pages can be retrieved with [`noawait::PagedSearch::next`]. 
+ /// + /// **Caution**: This method should only be used if it is known that all functions reachable + /// via the implementation of [`SearchStrategy`] are known to be synchronous and never + /// truly await. This allows [`noawait::PagedSearch`] to be much more efficient. + pub fn paged_search_no_await( + &self, + strategy: S, + context: DP::Context, + query: T, + l_value: usize, + ) -> ANNResult> + where + T: for<'a> Reborrow<'a, Target: Copy + Send> + 'static, + S: for<'a> SearchStrategy>::Target> + 'static, + { + noawait::PagedSearch::new(self.inner.clone(), strategy, context, query, l_value) + } + pub fn count_reachable_nodes( &self, start_points: &[DP::InternalId], @@ -430,6 +453,210 @@ where } } +pub mod noawait { + //! Implementations of a synchronous wrapper around [`diskann::graph::DiskANNIndex`] that + //! assume the [`Accessor`] and associated implementations never truly `await` and are + //! in fact synchronous. + //! + //! With this assumption, we can perform lighter-weight communication with the index + //! by assuming that each `poll` returns ready. + //! + //! **Do not use this if your index ever actually awaits**: Doing so will lead to deadlock! + + use super::*; + + use std::{ + cell::RefCell, + pin::Pin, + rc::Rc, + task::{Context, Poll, Waker}, + }; + + use diskann::{ANNErrorKind, utils::VectorId}; + use diskann_utils::Reborrow; + use thiserror::Error; + + type Input = Rc>>; + type Output = Rc>>>>; + + fn channel() -> (Input, Output) + where + I: VectorId, + { + let input = Rc::new(RefCell::new(None)); + let output = Rc::new(RefCell::new(None)); + (input, output) + } + + fn step(fut: Pin<&mut dyn Future>) -> Option { + let mut cx = Context::from_waker(Waker::noop()); + match fut.poll(&mut cx) { + Poll::Ready(v) => Some(v), + Poll::Pending => None, + } + } + + /// A synchronous wrapper for [`graph::search::PagedSearch`] + /// + /// See: [`super::DiskANNIndex::paged_search_no_await`].
+ pub struct PagedSearch { + // The `input` is wrapped in an `Option` so we can fuse `searcher` if it exits + // with an error. Polling a completed future risks panicking. + // + // We construct `searcher` to pull its next-page size from this input. + input: Option, + + // Output yielded from polling `searcher`. + output: Output, + + // We shut down the future by running `Drop`. Thus, the only way it can actually + // finish is if it returns with an error. + searcher: Pin>>, + } + + impl PagedSearch + where + I: VectorId, + { + /// Construct a new [`PagedSearch`]. + /// + /// This works by creating a small async task using [`graph::search::PagedSearch`] + /// internally. The requested k-nearest neighbors are sent using a `Rc>` + /// channel and the actual neighbors are retrieved from a similar data structure. + /// + /// Under the assumption that the implementation of [`graph::search::PagedSearch`]'s + /// implementations are fully synchronous, we can directly poll this task instead + /// of going through a runtime since we (theoretically) control the only suspension + /// point. + /// + /// Doing so allows stepping the task state machine to be done with a single function + /// call to `Future::poll`. + /// + /// Obviously, if the "noawait" assumption is broken, then the inner async job may + /// yield before our control point, but we can detect this situation since no output + /// will be generated on the output channel. + /// + /// We rely on `Drop` to clean up the paged search resources. + pub(super) fn new( + index: Arc>, + strategy: S, + context: DP::Context, + query: T, + l_value: usize, + ) -> ANNResult + where + DP: DataProvider, + T: for<'a> Reborrow<'a, Target: Copy + Send> + 'static, + S: for<'a> SearchStrategy>::Target> + 'static, + { + // Prepare the input and output channels used to communicate with the search task. + let (input, output) = channel::(); + let input_clone = input.clone(); + let output_clone = output.clone(); + + // Create the search task.
+ let mut searcher: Pin>> = Box::pin(async move { + // The assumption of `noawait` is that this call will always resolve to + // `Poll::Ready`. + let mut state = match index + .paged_search(strategy, &context, query.reborrow(), l_value) + .await + { + Ok(state) => state, + Err(err) => return err, + }; + + loop { + // This is the await point that pauses the future. + // + // Under the "noawait" assumption, this should be the only point where + // this future ever yields `Pending` and is where we expect the future + // to stop every time we poll it. + futures_util::pending!(); + + // We control the invocation of poll and should always ensure that + // input is available. + let k_value = match input_clone.take() { + Some(value) => value, + None => return InternalInvariantViolated::MissingInput.into(), + }; + + // Step paged search and propagate any errors. + let page = match state.next_page(k_value).await { + Ok(page) => page, + Err(err) => return err, + }; + + // Send output to the caller. + output_clone.replace(Some(page)); + } + }); + + // Drive the inner future one step to initialize paged search. + if let Some(err) = step(searcher.as_mut()) { + return Err(err); + } + + let this = Self { + input: Some(input), + output, + searcher, + }; + Ok(this) + } + + /// Retrieve the next results from paged search, returning any errors. + /// + /// If [`next`](Self::next) previously returned with an error, it will continue + /// to do so. + pub fn next(&mut self, k: usize) -> ANNResult>> { + // Prepare input. We use the presence of the input channel to decide whether + // or not it is safe to poll the search task. + match self.input.as_ref() { + Some(input) => input.replace(Some(k)), + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "paged searcher errored and is no longer runnable", + )); + } + }; + + // Progress the future. + // + // The only reason to return `Some` is if the inner future aborts with + // an error.
Here, we fuse the searcher to prevent panics on re-enters and + // forward the error. + if let Some(result) = step(self.searcher.as_mut()) { + self.input = None; + return Err(result); + } + + // Profit! + match self.output.take() { + Some(v) => Ok(v), + None => Err(InternalInvariantViolated::MissingOutput.into()), + } + } + } + + #[derive(Debug, Clone, Copy, Error)] + enum InternalInvariantViolated { + #[error("INTERNAL: input channel was not configured")] + MissingInput, + #[error("noawait contract violated: future suspended before expected yield point")] + MissingOutput, + } + + impl From for ANNError { + #[track_caller] + #[cold] + fn from(err: InternalInvariantViolated) -> Self { + Self::new(ANNErrorKind::Opaque, err) + } + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -552,4 +779,120 @@ mod tests { assert_eq!(ids[0], 0); assert_eq!(distances[0], 0.0); } + + fn wrapped_test_provider() -> DiskANNIndex { + let provider = + graph::test::provider::Provider::grid(graph::test::synthetic::Grid::One, 100).unwrap(); + + DiskANNIndex::new_with_current_thread_runtime( + graph::config::Builder::new( + provider.max_degree(), + diskann::graph::config::MaxDegree::same(), + 100, + (Metric::L2).into(), + ) + .build() + .unwrap(), + provider, + ) + } + + // Test the `noawait` paged searcher. + // + // This relies on the test-provider being no-await. 
+ #[test] + fn test_paged_search_noawait() { + let index = wrapped_test_provider(); + + for page_size in [1, 5, 9, 12] { + let mut paged = index + .paged_search_no_await::<_, Vec>( + graph::test::provider::Strategy::new(), + graph::test::provider::Context::new(), + vec![0.0], + 10.max(page_size), + ) + .unwrap(); + + let mut i = 0u32; + loop { + let v = paged.next(page_size).unwrap(); + assert!( + v.len() <= page_size, + "candidates returned ({}) exceeded page size ({})", + v.len(), + page_size, + ); + + if v.is_empty() { + break; + } + + for neighbor in v { + assert_ne!( + neighbor.id, + u32::MAX, + "paged search should not return start point", + ); + assert_eq!( + neighbor.id, i, + "monotonicity should at least hold for the 1d grid" + ); + assert_eq!( + neighbor.distance, + (i as f32) * (i as f32), + "distance was computed incorrectly!", + ); + i += 1; + } + } + + // Search is exhausted - make sure that subsequent searches yield empty vectors. + let exhausted = paged.next(5).unwrap(); + assert!( + exhausted.is_empty(), + "expected an empty vector when exhausted - instead got {:?}", + exhausted + ); + } + } + + // Verify that the searcher is properly fused when it returns with an error. + #[test] + fn test_paged_search_noawait_fuse() { + let index = wrapped_test_provider(); + + // To do this test, we request more neighbors than the search-L, which triggers + // an inner error. + let search_l = 10; + let bigger_than_search_l = 20; + + let mut paged = index + .paged_search_no_await::<_, Vec>( + graph::test::provider::Strategy::new(), + graph::test::provider::Context::new(), + vec![0.0], + search_l, + ) + .unwrap(); + + let expected = "search_param_l"; + let err = paged.next(bigger_than_search_l).unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains(expected), + "expected error message to contain \"{}\" - instead got\n\n{}", + expected, + msg, + ); + + // Now that we've yielded an error - the next time we request pages should also error. 
+ let err = paged.next(10).unwrap_err(); + let err_msg = err.to_string(); + assert!( + err_msg.contains("paged searcher errored"), + "unexpected error message:\n\n{}", + err_msg + ); + } } From 5e403681ce435103fec63ae9fbbed9b11ea914a8 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 19 May 2026 16:56:07 -0700 Subject: [PATCH 15/17] Checkpoint. --- diskann-disk/src/search/provider/disk_provider.rs | 4 ++-- diskann-garnet/src/provider.rs | 6 +++--- .../inline_beta_search/encoded_document_accessor.rs | 6 +++--- diskann-label-filter/src/lib.rs | 10 +++++----- .../model/graph/provider/async_/bf_tree/provider.rs | 8 ++++---- .../graph/provider/async_/inmem/full_precision.rs | 6 +++--- .../src/model/graph/provider/async_/inmem/product.rs | 6 +++--- .../src/model/graph/provider/async_/inmem/scalar.rs | 6 +++--- .../model/graph/provider/async_/inmem/spherical.rs | 6 +++--- .../src/model/graph/provider/async_/inmem/test.rs | 4 ++-- .../src/model/graph/provider/layers/betafilter.rs | 8 ++++---- diskann/src/graph/glue.rs | 11 ++++++----- diskann/src/graph/index.rs | 10 +++++----- diskann/src/graph/search/diverse_search.rs | 2 +- diskann/src/graph/search/knn_search.rs | 2 +- diskann/src/graph/search/multihop_search.rs | 4 ++-- diskann/src/graph/search/range_search.rs | 4 ++-- diskann/src/graph/test/provider.rs | 2 +- 18 files changed, 53 insertions(+), 52 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index c9d27ee3f..353a65d60 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -20,7 +20,7 @@ use diskann::{ graph::{ self, glue::{ - self, DefaultPostProcessor, IdIterator, SearchExt, SearchPostProcess, SearchStrategy, + self, DefaultPostProcessor, Explore, IdIterator, SearchPostProcess, SearchStrategy, }, search::Knn, search_output_buffer, DiskANNIndex, @@ -525,7 +525,7 @@ where } } -impl 
SearchExt<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> +impl Explore<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 0be5be9a4..695650f43 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -10,8 +10,8 @@ use diskann::{ AdjacencyList, SearchOutputBuffer, config::defaults::MAX_OCCLUSION_SIZE, glue::{ - self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + self, DefaultPostProcessor, ExpandBeam, Explore, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchPostProcess, SearchStrategy, }, workingset::{self, map::Entry}, }, @@ -448,7 +448,7 @@ impl HasId for FullAccessor<'_, T> { type Id = u32; } -impl SearchExt for FullAccessor<'_, T> { +impl Explore for FullAccessor<'_, T> { fn starting_points(&self) -> impl Future>> + Send { let points = if self.provider.start_points_exist() { vec![0] diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 0d16248dd..5562059fa 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -7,7 +7,7 @@ use std::sync::{Arc, RwLock}; use diskann::{ error::{ErrorExt, IntoANNResult}, - graph::glue::{ExpandBeam, SearchExt}, + graph::glue::{Explore}, provider::{Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, HasId}, ANNError, ANNErrorKind, }; @@ -139,9 +139,9 @@ where } } -impl SearchExt for EncodedDocumentAccessor +impl Explore for EncodedDocumentAccessor where - IA: SearchExt, + IA: Explore, { fn starting_points( &self, diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs index 0a6abcb9e..106845f98 100644 --- 
a/diskann-label-filter/src/lib.rs +++ b/diskann-label-filter/src/lib.rs @@ -17,11 +17,11 @@ pub mod utils { pub mod jsonl_reader; } -// pub mod inline_beta_search { -// pub mod encoded_document_accessor; -// pub mod inline_beta_filter; -// pub mod predicate_evaluator; -// } +pub mod inline_beta_search { + pub mod encoded_document_accessor; + pub mod inline_beta_filter; + pub mod predicate_evaluator; +} // Persisent Index Traits pub mod traits { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 8fde6f465..b2fa77074 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -20,8 +20,8 @@ use diskann::{ graph::{ AdjacencyList, DiskANNIndex, SearchOutputBuffer, glue::{ - self, Batch, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - MultiInsertStrategy, PruneStrategy, SearchExt, SearchStrategy, + self, Batch, DefaultPostProcessor, ExpandBeam, Explore, InplaceDeleteStrategy, + InsertStrategy, MultiInsertStrategy, PruneStrategy, SearchStrategy, }, workingset::{self, map}, }, @@ -934,7 +934,7 @@ where type Id = u32; } -impl SearchExt for FullAccessor<'_, T, Q, D> +impl Explore for FullAccessor<'_, T, Q, D> where T: VectorRepr, Q: AsyncFriendly, @@ -1098,7 +1098,7 @@ where type Id = u32; } -impl SearchExt for QuantAccessor<'_, T, D> +impl Explore for QuantAccessor<'_, T, D> where T: VectorRepr, D: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index be9869a72..2682443b8 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -12,8 +12,8 @@ use diskann::{ graph::{ AdjacencyList, SearchOutputBuffer, 
glue::{ - self, DefaultPostProcessor, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, - SearchExt, SearchStrategy, + self, DefaultPostProcessor, Explore, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchStrategy, }, workingset, }, @@ -144,7 +144,7 @@ where type Id = u32; } -impl SearchExt<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +impl Explore<&[T]> for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 1c52e883d..aadcbddc8 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -11,8 +11,8 @@ use diskann::{ graph::{ AdjacencyList, glue::{ - self, DefaultPostProcessor, InplaceDeleteStrategy, InsertStrategy, PruneStrategy, - SearchExt, SearchStrategy, + self, DefaultPostProcessor, Explore, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchStrategy, }, workingset, }, @@ -110,7 +110,7 @@ impl HasId for QuantAccessor<'_, V, D, Ctx> { type Id = u32; } -impl SearchExt<&[T]> for QuantAccessor<'_, V, D, Ctx> +impl Explore<&[T]> for QuantAccessor<'_, V, D, Ctx> where T: VectorRepr, V: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 7e7271c92..b9b5580dc 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -11,8 +11,8 @@ use diskann::{ graph::{ AdjacencyList, glue::{ - self, DefaultPostProcessor, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, - SearchExt, SearchStrategy, + self, DefaultPostProcessor, Explore, FilterStartPoints, InsertStrategy, Pipeline, + PruneStrategy, SearchStrategy, }, workingset, }, @@ -391,7 +391,7 @@ impl HasId for 
QuantAccessor<'_, NBITS, V, D, Ctx type Id = u32; } -impl SearchExt<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> +impl Explore<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> where T: VectorRepr, V: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 23db9d296..d022eccec 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -12,8 +12,8 @@ use diskann::{ graph::{ AdjacencyList, glue::{ - self, DefaultPostProcessor, FilterStartPoints, InsertStrategy, Pipeline, PruneStrategy, - SearchExt, SearchStrategy, + self, DefaultPostProcessor, Explore, FilterStartPoints, InsertStrategy, Pipeline, + PruneStrategy, SearchStrategy, }, workingset, }, @@ -329,7 +329,7 @@ impl HasId for QuantAccessor<'_, V, D, Ctx> { type Id = u32; } -impl SearchExt<&[T]> for QuantAccessor<'_, V, D, Ctx> +impl Explore<&[T]> for QuantAccessor<'_, V, D, Ctx> where T: VectorRepr, V: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 84dabdd8e..79aaf307b 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -14,7 +14,7 @@ // graph::{ // glue::{ // CopyIds, DefaultPostProcessor, ExpandBeam, InsertStrategy, MultiInsertStrategy, -// PruneStrategy, SearchExt, SearchStrategy, +// PruneStrategy, Explore, SearchStrategy, // }, // workingset::map, // }, @@ -139,7 +139,7 @@ // // type FullAccessor<'a> = super::FullAccessor<'a, f32, DefaultQuant, Tda, DefaultContext>; // -// impl SearchExt for FlakyAccessor<'_> { +// impl Explore for FlakyAccessor<'_> { // fn starting_points(&self) -> impl Future>> { // std::future::ready(self.provider.starting_points()) // } diff --git 
a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 7c4cb10ec..a56402dc2 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -20,7 +20,7 @@ use diskann::{ error::StandardError, graph::{ SearchOutputBuffer, - glue::{self, SearchExt, SearchPostProcessStep, SearchStrategy}, + glue::{self, Explore, SearchPostProcessStep, SearchStrategy}, index::QueryLabelProvider, }, neighbor::Neighbor, @@ -178,9 +178,9 @@ where } } -impl SearchExt for BetaAccessor +impl Explore for BetaAccessor where - Inner: SearchExt, + Inner: Explore, { fn starting_points(&self) -> impl Future>> + Send { self.inner.starting_points() @@ -233,7 +233,7 @@ where Inner: BuildQueryComputer + HasId, { /// Use the inner `QueryComputer`. Application of the beta filter happens in the - /// closure for [`SearchExt::expand_beam`]. + /// closure for [`Explore::expand_beam`]. type QueryComputer = Inner::QueryComputer; /// Use the same error as `Inner`. diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 77aad705f..527a09ef0 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -95,7 +95,7 @@ use crate::{ /// A trait to override search constraints such as early termination based on constraints /// by implementer. -pub trait SearchExt: BuildQueryComputer + HasId + Send + Sync { +pub trait Explore: BuildQueryComputer + HasId + Send + Sync { /// Return a `Vec` containing the starting points. fn starting_points(&self) -> impl std::future::Future>> + Send; @@ -276,13 +276,14 @@ where /// The concrete type of the accessor that is used to access `Self` during the greedy /// graph search. The query will be provided to the accessor exactly once during search /// to construct the query computer. 
- type SearchAccessor<'a>: SearchExt; + type SearchAccessor<'a>: Explore; /// Construct and return the search accessor. fn search_accessor<'a>( &'a self, provider: &'a Provider, context: &'a Provider::Context, + query: T, ) -> Result, Self::SearchAccessorError>; } @@ -438,7 +439,7 @@ pub struct FilterStartPoints; impl SearchPostProcessStep for FilterStartPoints where - A: BuildQueryComputer + SearchExt, + A: BuildQueryComputer + Explore, T: Copy + Send + Sync, { /// A this level, sub-errors are converted into [`ANNError`] to provide additional @@ -753,7 +754,7 @@ where /// of associated types. /// /// Lifting the accessor all the way to the trait level makes the caching provider possible. - type DeleteSearchAccessor<'a>: SearchExt, Id = Provider::InternalId>; + type DeleteSearchAccessor<'a>: Explore, Id = Provider::InternalId>; /// The processor used during the delete-search phase. type SearchPostProcessor: for<'a> SearchPostProcess, Self::DeleteElement<'a>> @@ -848,7 +849,7 @@ mod tests { #[derive(Clone, Copy)] struct Accessor; - impl SearchExt for Accessor { + impl Explore for Accessor { async fn starting_points(&self) -> ANNResult> { unimplemented!(); } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 5a51aa184..ede7aec56 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -25,8 +25,8 @@ use tokio::task::JoinSet; use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, Search, glue::{ - self, Batch, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, PruneStrategy, - SearchExt, SearchPostProcess, SearchStrategy, + self, Batch, Explore, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, + PruneStrategy, SearchPostProcess, SearchStrategy, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ @@ -1274,7 +1274,7 @@ where let search_strategy = strategy.search_strategy(); let mut search_accessor = search_strategy - .search_accessor(&self.data_provider, context) + 
.search_accessor(&self.data_provider, context, v.reborrow()) .into_ann_result()?; let computer = search_accessor @@ -2019,7 +2019,7 @@ where search_record: &mut SR, ) -> impl SendFuture> where - A: SearchExt, + A: Explore, SR: SearchRecord + ?Sized, Q: NeighborQueue, { @@ -2304,7 +2304,7 @@ where async move { let (computer, scratch) = { let mut accessor = strategy - .search_accessor(&self.data_provider, context) + .search_accessor(&self.data_provider, context, query) .into_ann_result()?; let computer = accessor.build_query_computer(query).into_ann_result()?; diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index f09a63cb9..0f8e2342e 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -14,7 +14,7 @@ use crate::{ error::IntoANNResult, graph::{ DiverseSearchParams, - glue::{SearchExt, SearchPostProcess, SearchStrategy}, + glue::{Explore, SearchPostProcess, SearchStrategy}, index::{DiskANNIndex, SearchStats}, search_output_buffer::SearchOutputBuffer, }, diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 9468e4204..2eb7618bc 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -15,7 +15,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{SearchExt, SearchPostProcess, SearchStrategy}, + glue::{Explore, SearchPostProcess, SearchStrategy}, index::{DiskANNIndex, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::SearchOutputBuffer, diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index 6baa41f7d..70ec0680c 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -14,7 +14,7 @@ use crate::{ error::IntoANNResult, graph::{ glue::{ - self, HybridPredicate, Predicate, PredicateMut, SearchExt, SearchPostProcess, + 
self, Explore, HybridPredicate, Predicate, PredicateMut, SearchPostProcess, SearchStrategy, }, index::{ @@ -180,7 +180,7 @@ pub(crate) async fn multihop_search_internal( ) -> ANNResult where I: VectorId, - A: SearchExt, + A: Explore, SR: SearchRecord + ?Sized, { let beam_width = search_params.beam_width().get(); diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 8de7aeafa..1d4dd1b3b 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -13,7 +13,7 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{self, SearchExt, SearchStrategy}, + glue::{self, Explore, SearchStrategy}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::{self, SearchOutputBuffer}, @@ -331,7 +331,7 @@ pub(crate) async fn range_search_internal( ) -> ANNResult where I: crate::utils::VectorId, - A: SearchExt, + A: Explore, { let beam_width = search_params.beam_width().unwrap_or(1); diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 33c37f65f..50debe7a2 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1103,7 +1103,7 @@ impl provider::BuildDistanceComputer for Accessor<'_> { // Glue // //------// -impl glue::SearchExt<&[f32]> for Accessor<'_> { +impl glue::Explore<&[f32]> for Accessor<'_> { fn starting_points(&self) -> impl Future>> + Send { futures_util::future::ok(self.provider.config.start_points.keys().copied().collect()) } From 5083dc5998ad71f6f1d11807a4afb78d7d8485a7 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 20 May 2026 11:55:46 -0700 Subject: [PATCH 16/17] Not sure about this line of thought. 
--- .../src/search/provider/disk_provider.rs | 4 +- .../encoded_document_accessor.rs | 2 +- diskann-providers/src/index/diskann_async.rs | 6 +- diskann-providers/src/index/wrapped_async.rs | 136 ++++++------ .../graph/provider/async_/bf_tree/provider.rs | 39 ++-- .../provider/async_/inmem/full_precision.rs | 46 ++-- .../graph/provider/async_/inmem/product.rs | 45 ++-- .../graph/provider/async_/inmem/scalar.rs | 58 +++-- .../graph/provider/async_/inmem/spherical.rs | 58 +++-- .../graph/provider/async_/postprocess.rs | 5 +- .../model/graph/provider/layers/betafilter.rs | 38 +--- diskann/src/graph/glue.rs | 78 ++----- diskann/src/graph/index.rs | 46 ++-- diskann/src/graph/search/diverse_search.rs | 2 +- diskann/src/graph/search/knn_search.rs | 11 +- diskann/src/graph/search/multihop_search.rs | 14 +- diskann/src/graph/search/paged.rs | 36 ++-- diskann/src/graph/search/range_search.rs | 12 +- diskann/src/graph/test/cases/paged_search.rs | 24 +-- diskann/src/graph/test/provider.rs | 199 +++++++++--------- diskann/src/provider.rs | 26 --- 21 files changed, 373 insertions(+), 512 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index a4395a75b..bf720a53f 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -18,9 +18,7 @@ use diskann::{ error::IntoANNResult, graph::{ self, - glue::{ - self, DefaultPostProcessor, Explore, SearchPostProcess, SearchStrategy, - }, + glue::{self, DefaultPostProcessor, Explore, SearchPostProcess, SearchStrategy}, search::Knn, search_output_buffer, DiskANNIndex, }, diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 5562059fa..64c81620b 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ 
b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -7,7 +7,7 @@ use std::sync::{Arc, RwLock}; use diskann::{ error::{ErrorExt, IntoANNResult}, - graph::glue::{Explore}, + graph::glue::Explore, provider::{Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, HasId}, ANNError, ANNErrorKind, }; diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index ceb9b63e9..ee2adcd4c 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -174,7 +174,7 @@ pub(crate) mod tests { }, neighbor::Neighbor, provider::{ - AsNeighbor, AsNeighborMut, BuildQueryComputer, DataProvider, DefaultContext, Delete, + AsNeighbor, AsNeighborMut, DataProvider, DefaultContext, Delete, ExecutionContext, Guard, NeighborAccessor, NeighborAccessorMut, SetElement, }, utils::{IntoUsize, ONE}, @@ -2315,7 +2315,9 @@ pub(crate) mod tests { // Ensure that the query computer used for insertion uses the `SameAsData` layout. let strategy = inmem::spherical::Quantized::build(); - let accessor = strategy.search_accessor(index.provider(), ctx, data.row(0)).unwrap(); + let accessor = strategy + .search_accessor(index.provider(), ctx, data.row(0)) + .unwrap(); let computer = accessor.build_query_computer(data.row(0)).unwrap(); assert_eq!( computer.layout(), diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 60fb814ef..40a9aeea7 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -340,52 +340,52 @@ where ) } - /// Begin a paged search over the index (synchronous wrapper). - /// - /// Returns a [`PagedSearch`] handle. See - /// [`PagedSearch::next_page`] for retrieving results. 
- pub fn paged_search<'a, S, T>( - &'a self, - strategy: S, - context: &'a DP::Context, - query: T, - l_value: usize, - ) -> ANNResult> - where - S: SearchStrategy + 'static, - T: Copy + Send + 'a, - { - let inner = self - .handle - .block_on(self.inner.paged_search(strategy, context, query, l_value))?; - Ok(PagedSearch { - handle: self.handle.clone(), - inner, - }) - } - - /// Begin a paged search with explicit initial seed IDs (synchronous wrapper). - pub fn paged_search_with_init_ids<'a, S, T>( - &'a self, - strategy: S, - context: &'a DP::Context, - query: T, - l_value: usize, - init_ids: Option<&'a [DP::InternalId]>, - ) -> ANNResult> - where - S: SearchStrategy + 'static, - T: Copy + Send + 'a, - { - let inner = self.handle.block_on( - self.inner - .paged_search_with_init_ids(strategy, context, query, l_value, init_ids), - )?; - Ok(PagedSearch { - handle: self.handle.clone(), - inner, - }) - } + // /// Begin a paged search over the index (synchronous wrapper). + // /// + // /// Returns a [`PagedSearch`] handle. See + // /// [`PagedSearch::next_page`] for retrieving results. + // pub fn paged_search<'a, S, T>( + // &'a self, + // strategy: S, + // context: &'a DP::Context, + // query: T, + // l_value: usize, + // ) -> ANNResult> + // where + // S: SearchStrategy + 'static, + // T: Copy + Send + 'a, + // { + // let inner = self + // .handle + // .block_on(self.inner.paged_search(strategy, context, query, l_value))?; + // Ok(PagedSearch { + // handle: self.handle.clone(), + // inner, + // }) + // } + + // /// Begin a paged search with explicit initial seed IDs (synchronous wrapper). 
+ // pub fn paged_search_with_init_ids<'a, S, T>( + // &'a self, + // strategy: S, + // context: &'a DP::Context, + // query: T, + // l_value: usize, + // init_ids: Option<&'a [DP::InternalId]>, + // ) -> ANNResult> + // where + // S: SearchStrategy + 'static, + // T: Copy + Send + 'a, + // { + // let inner = self.handle.block_on( + // self.inner + // .paged_search_with_init_ids(strategy, context, query, l_value, init_ids), + // )?; + // Ok(PagedSearch { + // handle: self.handle.clone(), + // inner, + // }) + // } /// Begin a synchronous paged search over the index. /// @@ -431,28 +431,28 @@ where } } -/// Synchronous wrapper around [`graph::search::PagedSearch`] that owns a tokio runtime handle. -/// -/// Created by [`DiskANNIndex::paged_search`]. Each call to [`next_page`](Self::next_page) -/// blocks the current thread to drive the underlying async search forward. -pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { - handle: tokio::runtime::Handle, - inner: graph::search::PagedSearch<'a, DP, S, T>, -} - -impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> -where - DP: DataProvider, - S: SearchStrategy, - T: Copy + Send, -{ - /// Returns the next page of at most `k` nearest-neighbor results. - /// - /// Blocks the current thread. Returns an empty `Vec` when the search is exhausted. - pub fn next_page(&mut self, k: usize) -> ANNResult>> { - self.handle.block_on(self.inner.next_page(k)) - } -} +// /// Synchronous wrapper around [`graph::search::PagedSearch`] that owns a tokio runtime handle. +// /// +// /// Created by [`DiskANNIndex::paged_search`]. Each call to [`next_page`](Self::next_page) +// /// blocks the current thread to drive the underlying async search forward. 
+// pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { +// handle: tokio::runtime::Handle, +// inner: graph::search::PagedSearch<'a, DP, S, T>, +// } +// +// impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +// where +// DP: DataProvider, +// S: SearchStrategy, +// T: Copy + Send, +// { +// /// Returns the next page of at most `k` nearest-neighbor results. +// /// +// /// Blocks the current thread. Returns an empty `Vec` when the search is exhausted. +// pub fn next_page(&mut self, k: usize) -> ANNResult>> { +// self.handle.block_on(self.inner.next_page(k)) +// } +// } pub mod noawait { //! Implementations of a synchronous wrapper around [`diskann::graph::DiskANNIndex`] that diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 27e8dc275..f87d1e936 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -27,7 +27,7 @@ use diskann::{ }, neighbor::Neighbor, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext, + Accessor, BuildDistanceComputer, DataProvider, DefaultContext, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, @@ -911,7 +911,6 @@ where /// /// * [`Accessor`] for the [`BfTreeProvider`]. /// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. 
/// pub struct FullAccessor<'a, T, Q, D> where @@ -1034,22 +1033,22 @@ where } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} +// impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D> +// where +// T: VectorRepr, +// Q: AsyncFriendly, +// D: AsyncFriendly, +// { +// type QueryComputerError = Panics; +// type QueryComputer = T::QueryDistance; +// +// fn build_query_computer( +// &self, +// from: &[T], +// ) -> Result { +// Ok(T::query_distance(from, self.provider.metric)) +// } +// } impl ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D> where T: VectorRepr, @@ -1415,7 +1414,7 @@ where Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, { - type QueryComputer = T::QueryDistance; + // type QueryComputer = T::QueryDistance; type SearchAccessor<'a> = FullAccessor<'a, T, Q, D>; type SearchAccessorError = Panics; @@ -1494,7 +1493,7 @@ where T: VectorRepr, D: AsyncFriendly + DeletionCheck, { - type QueryComputer = pq::distance::QueryComputer>; + // type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, T, D>; type SearchAccessorError = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 26897e065..76e1167df 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -19,7 +19,7 @@ use diskann::{ }, neighbor::Neighbor, provider::{ - BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, + BuildDistanceComputer, DefaultContext, DelegateNeighbor, ExecutionContext, HasElementRef, HasId, }, 
utils::{IntoUsize, VectorRepr}, @@ -112,7 +112,6 @@ where /// /// * [`Accessor`] for the [`DefaultProvider`]. /// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. pub struct FullAccessor<'a, T, Q, D, Ctx> where T: VectorRepr, @@ -144,7 +143,7 @@ where type Id = u32; } -impl Explore<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +impl Explore for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, @@ -157,7 +156,6 @@ where fn start_point_distances( &mut self, - computer: &Self::QueryComputer, mut f: F, ) -> impl std::future::Future> + Send where @@ -180,7 +178,6 @@ where fn expand_beam( &mut self, ids: Itr, - computer: &Self::QueryComputer, mut pred: P, mut on_neighbors: F, ) -> impl std::future::Future> + Send @@ -288,23 +285,23 @@ where } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} +// impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +// where +// T: VectorRepr, +// Q: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type QueryComputerError = Panics; +// type QueryComputer = T::QueryDistance; +// +// fn build_query_computer( +// &self, +// from: &[T], +// ) -> Result { +// Ok(T::query_distance(from, self.provider.metric)) +// } +// } //-------------------// // In-mem Extensions // @@ -343,7 +340,7 @@ pub struct Rerank; impl<'a, A, T> glue::SearchPostProcess for Rerank where T: VectorRepr, - A: BuildQueryComputer<&'a [T]> + HasId + GetFullPrecision + AsDeletionCheck, + A: HasId + GetFullPrecision + AsDeletionCheck, { type Error = Panics; @@ -351,7 +348,6 @@ where &self, accessor: &mut A, query: &'a [T], - _computer: &A::QueryComputer, candidates: I, 
output: &mut B, ) -> impl Future> + Send @@ -400,7 +396,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = T::QueryDistance; + // type QueryComputer = T::QueryDistance; type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; type SearchAccessorError = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 63f333567..429ae02c6 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -17,7 +17,7 @@ use diskann::{ workingset, }, provider::{ - BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, + BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, @@ -91,7 +91,6 @@ where /// This type implements the following traits: /// /// * [`Accessor`] for the `DefaultProvider`. -/// * [`BuildQueryComputer`]. 
pub struct QuantAccessor<'a, V, D, Ctx> { provider: &'a DefaultProvider, } @@ -110,7 +109,7 @@ impl HasId for QuantAccessor<'_, V, D, Ctx> { type Id = u32; } -impl Explore<&[T]> for QuantAccessor<'_, V, D, Ctx> +impl Explore for QuantAccessor<'_, V, D, Ctx> where T: VectorRepr, V: AsyncFriendly, @@ -123,7 +122,6 @@ where fn start_point_distances( &mut self, - computer: &Self::QueryComputer, mut f: F, ) -> impl std::future::Future> + Send where @@ -146,7 +144,6 @@ where fn expand_beam( &mut self, ids: Itr, - computer: &Self::QueryComputer, mut pred: P, mut on_neighbors: F, ) -> impl std::future::Future> + Send @@ -210,23 +207,23 @@ where type ElementRef<'a> = &'a [u8]; } -impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = ANNError; - type QueryComputer = pq::distance::QueryComputer>; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - self.provider.aux_vectors.query_computer(from) - } -} +// impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> +// where +// T: VectorRepr, +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type QueryComputerError = ANNError; +// type QueryComputer = pq::distance::QueryComputer>; +// +// fn build_query_computer( +// &self, +// from: &[T], +// ) -> Result { +// self.provider.aux_vectors.query_computer(from) +// } +// } impl BuildDistanceComputer for QuantAccessor<'_, V, D, Ctx> where @@ -491,7 +488,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = pq::distance::QueryComputer>; + // type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; type SearchAccessorError = Panics; @@ -641,7 +638,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = pq::distance::QueryComputer>; + // type QueryComputer = 
pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, NoStore, D, Ctx>; type SearchAccessorError = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index dd9868323..888a44217 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -17,7 +17,7 @@ use diskann::{ workingset, }, provider::{ - BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, + BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, @@ -391,7 +391,7 @@ impl HasId for QuantAccessor<'_, NBITS, V, D, Ctx type Id = u32; } -impl Explore<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> +impl Explore for QuantAccessor<'_, NBITS, V, D, Ctx> where T: VectorRepr, V: AsyncFriendly, @@ -406,7 +406,6 @@ where fn start_point_distances( &mut self, - computer: &Self::QueryComputer, mut f: F, ) -> impl std::future::Future> + Send where @@ -424,7 +423,6 @@ where fn expand_beam( &mut self, ids: Itr, - computer: &Self::QueryComputer, mut pred: P, mut on_neighbors: F, ) -> impl std::future::Future> + Send @@ -512,30 +510,30 @@ where } } -impl BuildQueryComputer<&[T]> - for QuantAccessor<'_, NBITS, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, - Unsigned: Representation, - QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, -{ - type QueryComputerError = ANNError; - type QueryComputer = QueryComputer; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - // Allow rescaling if this is search. - Ok(self - .provider - .aux_vectors - .query_computer(from, self.is_search)?) 
- } -} +// impl BuildQueryComputer<&[T]> +// for QuantAccessor<'_, NBITS, V, D, Ctx> +// where +// T: VectorRepr, +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// Unsigned: Representation, +// QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, +// { +// type QueryComputerError = ANNError; +// type QueryComputer = QueryComputer; +// +// fn build_query_computer( +// &self, +// from: &[T], +// ) -> Result { +// // Allow rescaling if this is search. +// Ok(self +// .provider +// .aux_vectors +// .query_computer(from, self.is_search)?) +// } +// } impl BuildDistanceComputer for QuantAccessor<'_, NBITS, V, D, Ctx> where @@ -584,7 +582,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - type QueryComputer = QueryComputer; + // type QueryComputer = QueryComputer; type SearchAccessor<'a> = QuantAccessor<'a, NBITS, FullPrecisionStore, D, Ctx>; type SearchAccessorError = ANNError; @@ -622,7 +620,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - type QueryComputer = QueryComputer; + // type QueryComputer = QueryComputer; type SearchAccessor<'a> = QuantAccessor<'a, NBITS, NoStore, D, Ctx>; type SearchAccessorError = ANNError; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 7c6e840aa..565e6214f 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -18,7 +18,7 @@ use diskann::{ workingset, }, provider::{ - BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, + BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, @@ -329,7 +329,7 @@ impl HasId for QuantAccessor<'_, V, D, Ctx> { type Id = u32; } -impl Explore<&[T]> for QuantAccessor<'_, V, D, Ctx> 
+impl Explore for QuantAccessor<'_, V, D, Ctx> where T: VectorRepr, V: AsyncFriendly, @@ -342,7 +342,6 @@ where fn start_point_distances( &mut self, - computer: &Self::QueryComputer, mut f: F, ) -> impl std::future::Future> + Send where @@ -364,7 +363,6 @@ where fn expand_beam( &mut self, ids: Itr, - computer: &Self::QueryComputer, mut pred: P, mut on_neighbors: F, ) -> impl std::future::Future> + Send @@ -430,28 +428,28 @@ where } } -impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Bridge; - type QueryComputer = - UnwrapErr; - - fn build_query_computer( - &self, - query: &[T], - ) -> Result { - self.provider - .aux_vectors - .query_computer(query, self.layout, self.is_search) - .bridge_err() - .map(UnwrapErr::new) - } -} +// impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> +// where +// T: VectorRepr, +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type QueryComputerError = Bridge; +// type QueryComputer = +// UnwrapErr; +// +// fn build_query_computer( +// &self, +// query: &[T], +// ) -> Result { +// self.provider +// .aux_vectors +// .query_computer(query, self.layout, self.is_search) +// .bridge_err() +// .map(UnwrapErr::new) +// } +// } #[derive(Debug, Error)] #[error("unconstructible")] @@ -537,8 +535,8 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = - UnwrapErr; + // type QueryComputer = + // UnwrapErr; type SearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; type SearchAccessorError = ANNError; @@ -571,8 +569,8 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = - UnwrapErr; + // type QueryComputer = + // UnwrapErr; type SearchAccessor<'a> = QuantAccessor<'a, NoStore, D, Ctx>; type SearchAccessorError = ANNError; diff --git 
a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index dbbf08fa4..ff62f1ea0 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -8,7 +8,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, - provider::{BuildQueryComputer, HasId}, + provider::{HasId}, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to @@ -39,7 +39,7 @@ pub struct RemoveDeletedIdsAndCopy; impl glue::SearchPostProcess for RemoveDeletedIdsAndCopy where - A: BuildQueryComputer + HasId + AsDeletionCheck, + A: HasId + AsDeletionCheck, { type Error = std::convert::Infallible; @@ -47,7 +47,6 @@ where &self, accessor: &mut A, _query: T, - _computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 4eafead10..49b75c527 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -24,7 +24,7 @@ use diskann::{ index::QueryLabelProvider, }, neighbor::Neighbor, - provider::{BuildQueryComputer, DataProvider, HasId}, + provider::{DataProvider, HasId}, utils::VectorId, }; @@ -67,7 +67,7 @@ pub struct Unwrap; /// Delegate post-processing to the inner strategy's post-processing routine. 
impl SearchPostProcessStep, T, O> for Unwrap where - A: BuildQueryComputer + HasId, + A: HasId, { type Error = NextError @@ -81,7 +81,6 @@ where next: &Next, accessor: &mut BetaAccessor, query: T, - computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send @@ -90,7 +89,7 @@ where B: SearchOutputBuffer + Send + ?Sized, Next: glue::SearchPostProcess, { - next.post_process(&mut accessor.inner, query, computer, candidates, output) + next.post_process(&mut accessor.inner, query, candidates, output) } } @@ -112,9 +111,9 @@ where /// accessor. type SearchAccessor<'a> = BetaAccessor>; - /// A [`PreprocessedDistanceFunction`] that combines applies the beta filtering factor - /// if the vector ID portion of `Element` satisfies the filter predicate. - type QueryComputer = Strategy::QueryComputer; + // /// A [`PreprocessedDistanceFunction`] that combines applies the beta filtering factor + // /// if the vector ID portion of `Element` satisfies the filter predicate. + // type QueryComputer = Strategy::QueryComputer; type SearchAccessorError = Strategy::SearchAccessorError; @@ -123,7 +122,7 @@ where &'a self, provider: &'a Provider, context: &'a Provider::Context, - query: T + query: T, ) -> Result, Self::SearchAccessorError> { Ok(BetaAccessor { inner: self.strategy.search_accessor(provider, context, query)?, @@ -229,25 +228,6 @@ where type Id = Inner::Id; } -impl BuildQueryComputer for BetaAccessor -where - Inner: BuildQueryComputer + HasId, -{ - /// Use the inner `QueryComputer`. Application of the beta filter happens in the - /// closure for [`Explore::expand_beam`]. - type QueryComputer = Inner::QueryComputer; - - /// Use the same error as `Inner`. 
- type QueryComputerError = Inner::QueryComputerError; - - fn build_query_computer( - &self, - from: T, - ) -> Result { - self.inner.build_query_computer(from) - } -} - /////////// // Tests // /////////// @@ -322,7 +302,9 @@ mod tests { let mut visited = HashSet::new(); let mut buf = Vec::new(); - let mut accessor = strategy.search_accessor(&provider, &context, &[0.0, 0.0]).unwrap(); + let mut accessor = strategy + .search_accessor(&provider, &context, &[0.0, 0.0]) + .unwrap(); assert_eq!( &*accessor.starting_points().await.unwrap(), diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 54eb4c43c..778db1e39 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -86,22 +86,18 @@ use crate::{ error::StandardError, graph::{SearchOutputBuffer, workingset}, neighbor::Neighbor, - provider::{ - AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, HasElementRef, - HasId, - }, + provider::{AsNeighborMut, BuildDistanceComputer, DataProvider, HasElementRef, HasId}, }; /// A trait to override search constraints such as early termination based on constraints /// by implementer. -pub trait Explore: BuildQueryComputer + HasId + Send + Sync { +pub trait Explore: HasId + Send + Sync { /// Return a `Vec` containing the starting points. fn starting_points(&self) -> impl std::future::Future>> + Send; fn start_point_distances( &mut self, - computer: &Self::QueryComputer, f: F, ) -> impl std::future::Future> + Send where @@ -154,7 +150,6 @@ pub trait Explore: BuildQueryComputer + HasId + Send + Sync { fn expand_beam( &mut self, ids: Itr, - computer: &Self::QueryComputer, pred: P, on_neighbors: F, ) -> impl std::future::Future> + Send @@ -263,19 +258,13 @@ pub trait SearchStrategy: Send + Sync where Provider: DataProvider, { - /// The computer used by the associated accessor. - /// - /// We could grab this type from the `SearchAccessor` associated type, but it's - /// useful enough that we move it up here. 
- type QueryComputer: Send + Sync + 'static; - /// An error that can occur when getting a search_accessor. type SearchAccessorError: StandardError + for<'a> DoesNotCaptureSelf>; /// The concrete type of the accessor that is used to access `Self` during the greedy /// graph search. The query will be provided to the accessor exactly once during search /// to construct the query computer. - type SearchAccessor<'a>: Explore; + type SearchAccessor<'a>: Explore; /// Construct and return the search accessor. fn search_accessor<'a>( @@ -353,7 +342,7 @@ macro_rules! default_post_processor { /// directly into the output buffer. pub trait SearchPostProcess::Id> where - A: BuildQueryComputer + HasId, + A: HasId, { type Error: StandardError; @@ -363,7 +352,6 @@ where &self, accessor: &mut A, query: T, - computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -379,14 +367,13 @@ pub struct CopyIds; impl SearchPostProcess for CopyIds where - A: BuildQueryComputer + HasId, + A: HasId, { type Error = std::convert::Infallible; fn post_process( &self, _accessor: &mut A, _query: T, - _computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -404,7 +391,7 @@ where /// using a [`Pipeline`]. pub trait SearchPostProcessStep::Id> where - A: BuildQueryComputer + HasId, + A: HasId, { /// A potentially modified version of the error yielded by the next state in the /// processing pipeline. @@ -413,7 +400,7 @@ where NextError: StandardError; /// The accessor that will be passed to the next processing stage. - type NextAccessor: BuildQueryComputer + HasId; + type NextAccessor: HasId; /// Perform any modification the `input`, `output`, `accessor`, or `computer` objects /// and invoke the [`SearchPostProcess`] routine `next` on stage. 
@@ -422,7 +409,6 @@ where next: &Next, accessor: &mut A, query: T, - computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future>> + Send @@ -438,7 +424,7 @@ pub struct FilterStartPoints; impl SearchPostProcessStep for FilterStartPoints where - A: BuildQueryComputer + Explore, + A: Explore, T: Copy + Send + Sync, { /// A this level, sub-errors are converted into [`ANNError`] to provide additional @@ -456,7 +442,6 @@ where next: &Next, accessor: &mut A, query: T, - computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> ANNResult @@ -466,18 +451,12 @@ where Next: SearchPostProcess + Sync, { let filter = accessor.is_not_start_point().await?; - next.post_process( - accessor, - query, - computer, - candidates.filter(|n| filter(n.id)), - output, - ) - .await - .map_err(|err| { - let err = err.into(); - err.context("after filtering start points") - }) + next.post_process(accessor, query, candidates.filter(|n| filter(n.id)), output) + .await + .map_err(|err| { + let err = err.into(); + err.context("after filtering start points") + }) } } @@ -507,7 +486,7 @@ impl Pipeline { impl SearchPostProcess for Pipeline where - A: BuildQueryComputer + HasId, + A: HasId, Head: SearchPostProcessStep, Tail: SearchPostProcess + Sync, { @@ -517,7 +496,6 @@ where &self, accessor: &mut A, query: T, - computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -526,7 +504,7 @@ where B: SearchOutputBuffer + Send + ?Sized, { self.head - .post_process_step(&self.tail, accessor, query, computer, candidates, output) + .post_process_step(&self.tail, accessor, query, candidates, output) } } @@ -754,7 +732,7 @@ where /// of associated types. /// /// Lifting the accessor all the way to the trait level makes the caching provider possible. - type DeleteSearchAccessor<'a>: Explore, Id = Provider::InternalId>; + type DeleteSearchAccessor<'a>: Explore; /// The processor used during the delete-search phase. 
type SearchPostProcessor: for<'a> SearchPostProcess, Self::DeleteElement<'a>> @@ -831,16 +809,12 @@ mod tests { #[derive(Clone, Copy)] struct Accessor; - impl Explore for Accessor { + impl Explore for Accessor { async fn starting_points(&self) -> ANNResult> { unimplemented!(); } - async fn start_point_distances( - &mut self, - _computer: &Self::QueryComputer, - _f: F, - ) -> ANNResult<()> + async fn start_point_distances(&mut self, _f: F) -> ANNResult<()> where F: FnMut(Self::Id, f32) + Send, { @@ -850,7 +824,6 @@ mod tests { async fn expand_beam( &mut self, _ids: Itr, - _computer: &Self::QueryComputer, _pred: P, _on_neighbors: F, ) -> ANNResult<()> @@ -867,22 +840,11 @@ mod tests { type Id = u32; } - struct QueryComputer; - - impl BuildQueryComputer for Accessor { - type QueryComputerError = ANNError; - type QueryComputer = QueryComputer; - fn build_query_computer(&self, _from: f32) -> Result { - Ok(QueryComputer) - } - } - // This strategy explicitly does not define `post_process` so we can test the provided // implementation. 
struct Strategy; impl SearchStrategy for Strategy { - type QueryComputer = QueryComputer; type SearchAccessorError = ANNError; type SearchAccessor<'a> = Accessor; @@ -911,7 +873,6 @@ mod tests { let query = 11.5; let mut accessor = strategy.search_accessor(&provider, &ctx, query).unwrap(); - let computer = accessor.build_query_computer(query).unwrap(); for input_len in 0..10 { let input: Vec<_> = (0..input_len) @@ -925,7 +886,6 @@ mod tests { .post_process( &mut accessor, query, - &computer, input.iter().copied(), &mut neighbor::BackInserter::new(output.as_mut_slice()), ) diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 907f4be5f..09496e74b 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -44,8 +44,8 @@ use crate::{ internal, neighbor::{self, Neighbor, NeighborQueue}, provider::{ - AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, Delete, - ElementStatus, ExecutionContext, Guard, NeighborAccessor, NeighborAccessorMut, SetElement, + AsNeighbor, AsNeighborMut, BuildDistanceComputer, DataProvider, Delete, ElementStatus, + ExecutionContext, Guard, NeighborAccessor, NeighborAccessorMut, SetElement, }, tracked_debug, tracked_error, tracked_trace, utils::{ @@ -278,8 +278,6 @@ where .insert_search_accessor(&self.data_provider, context, vector) .into_ann_result()?; - let computer = accessor.build_query_computer(vector).into_ann_result()?; - // NOTE: We don't filter the start points out of `visited_nodes`, as those are // needed to generate out edges from the start points. 
let start_ids = accessor.starting_points().await?; @@ -308,7 +306,6 @@ where self.search_internal( None, // beam_width &mut accessor, - &computer, &mut scratch, &mut search_record, ) @@ -417,9 +414,6 @@ where .insert_search_accessor(&self.data_provider, context, vector) .into_ann_result()?; - let computer = accessor - .build_query_computer(vector) - .into_ann_result()?; let start_ids = accessor.starting_points().await?; let mut scratch = self.search_scratch(self.l_build(), start_ids.len()); @@ -438,7 +432,6 @@ where self.search_internal( None, // beam_width &mut accessor, - &computer, &mut scratch, &mut search_record, ) @@ -1242,10 +1235,6 @@ where .search_accessor(&self.data_provider, context, v.reborrow()) .into_ann_result()?; - let computer = search_accessor - .build_query_computer(v.reborrow()) - .into_ann_result()?; - let start_ids = search_accessor.starting_points().await?; let mut scratch = self.search_scratch(l_value, start_ids.len()); @@ -1253,7 +1242,6 @@ where self.search_internal( None, // beam_width &mut search_accessor, - &computer, &mut scratch, &mut NoopSearchRecord::new(), ) @@ -1268,7 +1256,6 @@ where .post_process( &mut search_accessor, v.reborrow(), - &computer, scratch.best.iter(), &mut neighbor::BackInserter::new(output.as_mut_slice()), ) @@ -1974,17 +1961,15 @@ where } } - // A is the accessor type, T is the query type used for BuildQueryComputer - pub(crate) fn search_internal( + pub(crate) fn search_internal( &self, beam_width: Option, accessor: &mut A, - computer: &A::QueryComputer, scratch: &mut SearchScratch, search_record: &mut SR, ) -> impl SendFuture> where - A: Explore, + A: Explore, SR: SearchRecord + ?Sized, Q: NeighborQueue, { @@ -1995,7 +1980,7 @@ where // state if not already initialized. 
if scratch.visited.is_empty() { accessor - .start_point_distances(computer, |id, dist| { + .start_point_distances(|id, dist| { scratch.visited.insert(id); scratch.best.insert(Neighbor::new(id, dist)); scratch.cmps += 1; @@ -2020,7 +2005,6 @@ where accessor .expand_beam( scratch.beam_nodes.iter().copied(), - computer, glue::NotInMut::new(&mut scratch.visited), |distance, id| neighbors.push(Neighbor::new(id, distance)), ) @@ -2120,11 +2104,11 @@ where /// yields successive pages of nearest-neighbor results. pub fn paged_search<'a, S, T>( &'a self, - strategy: S, + strategy: &'a S, context: &'a DP::Context, query: T, l_value: usize, - ) -> impl SendFuture>> + ) -> impl SendFuture>>> where S: SearchStrategy, T: Copy + Send + 'a, @@ -2141,24 +2125,22 @@ where /// provide custom starting points for the graph traversal. pub fn paged_search_with_init_ids<'a, S, T>( &'a self, - strategy: S, + strategy: &'a S, context: &'a DP::Context, query: T, l_value: usize, init_ids: Option<&'a [DP::InternalId]>, - ) -> impl SendFuture>> + ) -> impl SendFuture>>> where S: SearchStrategy, T: Copy + Send + 'a, { async move { - let (computer, scratch) = { + let (accessor, scratch) = { let mut accessor = strategy .search_accessor(&self.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; let num_start_points = start_ids.len(); @@ -2178,7 +2160,6 @@ where accessor .expand_beam( init_ids.iter().copied(), - &computer, glue::NotInMut::new(&mut scratch.visited), |distance, id| neighbors.push(Neighbor::new(id, distance)), ) @@ -2189,19 +2170,16 @@ where .iter() .for_each(|neighbor| scratch.best.insert(*neighbor)); - (computer, scratch) + (accessor, scratch) }; ANNResult::Ok(PagedSearch { index: self, - context, scratch, computed_result: vec![Neighbor::default(); l_value], next_result_index: l_value, search_param_l: l_value, - strategy, - computer, - query, + accessor, }) } 
} diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index 0f8e2342e..89544278a 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -19,7 +19,7 @@ use crate::{ search_output_buffer::SearchOutputBuffer, }, neighbor::{AttributeValueProvider, DiverseNeighborQueue, NeighborQueue}, - provider::{BuildQueryComputer, DataProvider}, + provider::DataProvider, }; /// Parameters for diversity-aware search. diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index 8e9aea627..9434addde 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -20,7 +20,7 @@ use crate::{ search::record::NoopSearchRecord, search_output_buffer::SearchOutputBuffer, }, - provider::{BuildQueryComputer, DataProvider}, + provider::DataProvider, }; /// Error type for [`Knn`] parameter validation. @@ -194,23 +194,20 @@ where .search_accessor(&index.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; - let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len()); let stats = index .search_internal( Some(self.beam_width.get()), &mut accessor, - &computer, &mut scratch, &mut NoopSearchRecord::new(), ) .await?; let result_count = processor - .post_process(&mut accessor, query, &computer, scratch.best.iter(), output) + .post_process(&mut accessor, query, scratch.best.iter(), output) .await .into_ann_result()?; @@ -269,7 +266,6 @@ where .search_accessor(&index.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; let mut scratch = index.search_scratch(self.inner.l_value.get(), start_ids.len()); @@ -278,14 +274,13 @@ where .search_internal( 
Some(self.inner.beam_width.get()), &mut accessor, - &computer, &mut scratch, self.recorder, ) .await?; let result_count = processor - .post_process(&mut accessor, query, &computer, scratch.best.iter(), output) + .post_process(&mut accessor, query, scratch.best.iter(), output) .await .into_ann_result()?; diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index b4d77435f..fe4cf1a16 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -24,7 +24,7 @@ use crate::{ search_output_buffer::SearchOutputBuffer, }, neighbor::Neighbor, - provider::{BuildQueryComputer, DataProvider}, + provider::DataProvider, utils::VectorId, }; @@ -77,7 +77,6 @@ where let mut accessor = strategy .search_accessor(&index.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; @@ -87,7 +86,6 @@ where index.max_degree_with_slack(), &self.inner, &mut accessor, - &computer, &mut scratch, &mut NoopSearchRecord::new(), self.label_evaluator, @@ -98,7 +96,6 @@ where .post_process( &mut accessor, query, - &computer, scratch.best.iter().take(self.inner.l_value().get()), output, ) @@ -169,18 +166,17 @@ impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId { /// /// Performs label-filtered search by expanding through non-matching nodes /// to find matching neighbors within two hops. 
-pub(crate) async fn multihop_search_internal( +pub(crate) async fn multihop_search_internal( max_degree_with_slack: usize, search_params: &Knn, accessor: &mut A, - computer: &A::QueryComputer, scratch: &mut SearchScratch, search_record: &mut SR, query_label_evaluator: &dyn QueryLabelProvider, ) -> ANNResult where I: VectorId, - A: Explore, + A: Explore, SR: SearchRecord + ?Sized, { let beam_width = search_params.beam_width().get(); @@ -193,7 +189,7 @@ where }; accessor - .start_point_distances(computer, |id, dist| { + .start_point_distances(|id, dist| { scratch.visited.insert(id); scratch.best.insert(Neighbor::new(id, dist)); scratch.cmps += 1; @@ -224,7 +220,6 @@ where accessor .expand_beam( scratch.beam_nodes.iter().copied(), - computer, glue::NotInMut::new(&mut scratch.visited), |distance, id| one_hop_neighbors.push(Neighbor::new(id, distance)), ) @@ -270,7 +265,6 @@ where accessor .expand_beam( two_hop_expansion_candidate_ids.iter().copied(), - computer, NotInMutWithLabelCheck::new(&mut scratch.visited, query_label_evaluator), |distance, id| { two_hop_neighbors.push(Neighbor::new(id, distance)); diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs index 4334a298f..f3670f494 100644 --- a/diskann/src/graph/search/paged.rs +++ b/diskann/src/graph/search/paged.rs @@ -7,10 +7,9 @@ use diskann_utils::future::SendFuture; use crate::{ ANNError, ANNResult, - error::IntoANNResult, graph::{ DiskANNIndex, - glue::{Explore, SearchStrategy}, + glue::Explore, search::{record::NoopSearchRecord, scratch::SearchScratch}, }, neighbor::{Neighbor, NeighborPriorityQueue}, @@ -25,23 +24,23 @@ use crate::{ /// /// See also: [`DiskANNIndex::paged_search`], [`DiskANNIndex::paged_search_with_init_ids`]. 
#[derive(Debug)] -pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { +pub struct PagedSearch<'a, DP, T> +where + DP: DataProvider, + T: Explore, +{ pub(in crate::graph) index: &'a DiskANNIndex, - pub(in crate::graph) context: &'a DP::Context, - pub(in crate::graph) scratch: SearchScratch, - pub(in crate::graph) computed_result: Vec>, + pub(in crate::graph) scratch: SearchScratch, + pub(in crate::graph) computed_result: Vec>, pub(in crate::graph) next_result_index: usize, pub(in crate::graph) search_param_l: usize, - pub(in crate::graph) strategy: S, - pub(in crate::graph) computer: S::QueryComputer, - pub(in crate::graph) query: T, + pub(in crate::graph) accessor: T, } -impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +impl<'a, DP, T> PagedSearch<'a, DP, T> where DP: DataProvider, - S: SearchStrategy, - T: Copy + Send, + T: Explore, { /// Returns the next page of at most `k` nearest-neighbor results. /// @@ -89,17 +88,16 @@ where // Resume graph search to fill the next batch. 
let start_points = { - let mut accessor = self - .strategy - .search_accessor(&self.index.data_provider, self.context, self.query) - .into_ann_result()?; + // let mut accessor = self + // .strategy + // .search_accessor(&self.index.data_provider, self.context, self.query) + // .into_ann_result()?; - let start_ids = accessor.starting_points().await?; + let start_ids = self.accessor.starting_points().await?; self.index .search_internal( None, // beam_width - &mut accessor, - &self.computer, + &mut self.accessor, &mut self.scratch, &mut NoopSearchRecord::new(), ) diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index 100dcdc84..6cbea685a 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -19,7 +19,7 @@ use crate::{ search_output_buffer::{self, SearchOutputBuffer}, }, neighbor::Neighbor, - provider::{BuildQueryComputer, DataProvider}, + provider::DataProvider, utils::IntoUsize, }; @@ -183,7 +183,6 @@ where let mut accessor = strategy .search_accessor(&index.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; let mut scratch = index.search_scratch(self.starting_l(), start_ids.len()); @@ -192,7 +191,6 @@ where .search_internal( self.beam_width(), &mut accessor, - &computer, &mut scratch, &mut NoopSearchRecord::new(), ) @@ -221,7 +219,6 @@ where index.max_degree_with_slack(), &self, &mut accessor, - &computer, &mut scratch, ) .await?; @@ -251,7 +248,6 @@ where .post_process( &mut accessor, query, - &computer, scratch.in_range.iter().copied(), &mut filtered, ) @@ -322,16 +318,15 @@ where /// /// Expands the search frontier to find all points within the specified radius. /// Called after the initial graph search has identified starting candidates. 
-pub(crate) async fn range_search_internal( +pub(crate) async fn range_search_internal( max_degree_with_slack: usize, search_params: &Range, accessor: &mut A, - computer: &A::QueryComputer, scratch: &mut SearchScratch, ) -> ANNResult where I: crate::utils::VectorId, - A: Explore, + A: Explore, { let beam_width = search_params.beam_width().unwrap_or(1); @@ -359,7 +354,6 @@ where accessor .expand_beam( scratch.beam_nodes.iter().copied(), - computer, glue::NotInMut::new(&mut scratch.visited), |distance, id| neighbors.push(Neighbor::new(id, distance)), ) diff --git a/diskann/src/graph/test/cases/paged_search.rs b/diskann/src/graph/test/cases/paged_search.rs index 61e2c8c45..511779201 100644 --- a/diskann/src/graph/test/cases/paged_search.rs +++ b/diskann/src/graph/test/cases/paged_search.rs @@ -155,13 +155,9 @@ fn basic_paged_search() { let page_size = 4; let ctx = test_provider::Context::new(); + let strategy = test_provider::Strategy::new(); let mut search = rt - .block_on(index.paged_search( - test_provider::Strategy::new(), - &ctx, - query.as_slice(), - search_l, - )) + .block_on(index.paged_search(&strategy, &ctx, query.as_slice(), search_l)) .unwrap(); let mut pages: Vec>> = Vec::new(); @@ -197,14 +193,10 @@ fn single_page() { let search_l = 200; let page_size = 200; // larger than total points (125) let ctx = test_provider::Context::new(); + let strategy = test_provider::Strategy::new(); let mut search = rt - .block_on(index.paged_search( - test_provider::Strategy::new(), - &ctx, - query.as_slice(), - search_l, - )) + .block_on(index.paged_search(&strategy, &ctx, query.as_slice(), search_l)) .unwrap(); let results = rt.block_on(search.next_page(page_size)).unwrap(); @@ -240,14 +232,10 @@ fn small_page_size() { let search_l = 32; let page_size = 1; // one result per page, maximum iterations let ctx = test_provider::Context::new(); + let strategy = test_provider::Strategy::new(); let mut search = rt - .block_on(index.paged_search( - test_provider::Strategy::new(), 
- &ctx, - query.as_slice(), - search_l, - )) + .block_on(index.paged_search(&strategy, &ctx, query.as_slice(), search_l)) .unwrap(); let mut pages: Vec>> = Vec::new(); diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index e881646b9..a3432d23f 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1006,7 +1006,7 @@ impl provider::NeighborAccessorMut for NeighborAccessor<'_> { #[derive(Debug)] pub struct Accessor<'a> { provider: &'a Provider, - buffer: Box<[f32]>, + distance: ::QueryDistance, get_vector: LocalCounter<'a>, /// IDs that will produce transient errors when accessed. transient_ids: Option>>, @@ -1019,39 +1019,66 @@ impl<'a> Accessor<'a> { } /// Creates an accessor with no flaky behavior (backward-compatible). - pub fn new(provider: &'a Provider) -> Self { - Self::new_inner(provider, None) + pub fn new(provider: &'a Provider, query: &[f32]) -> Self { + Self::new_inner(provider, query, None) } /// Creates an accessor where `get_element` returns a transient error for /// any ID in `transient_ids`. The ID must still exist in the provider — /// accessing a truly missing ID remains a critical `InvalidId` error. - pub fn flaky(provider: &'a Provider, transient_ids: Cow<'a, HashSet>) -> Self { - Self::new_inner(provider, Some(transient_ids)) + pub fn flaky( + provider: &'a Provider, + query: &[f32], + transient_ids: Cow<'a, HashSet>, + ) -> Self { + Self::new_inner(provider, query, Some(transient_ids)) } - fn new_inner(provider: &'a Provider, transient_ids: Option>>) -> Self { - let buffer = (0..provider.dim()).map(|_| 0.0).collect(); + fn new_inner( + provider: &'a Provider, + query: &[f32], + transient_ids: Option>>, + ) -> Self { + // FIXME: Return error. 
+ assert_eq!(query.len(), provider.dim()); + let distance = f32::query_distance(query, provider.distance_metric()); Self { provider, - buffer, + distance, get_vector: provider.get_vector.local(), transient_ids, } } - pub fn get(&mut self, id: u32) -> Result<&[f32], AccessError> { + fn check_flaky(&self, id: u32) -> Result<(), AccessError> { + if let Some(transient) = &self.transient_ids + && transient.contains(&id) + { + Err(AccessError::Transient(TransientAccessError::new(id))) + } else { + Ok(()) + } + } + + pub fn copy(&mut self, id: u32) -> Result, AccessError> { match self.provider.terms.get(&id) { Some(term) => { - if let Some(transient) = &self.transient_ids - && transient.contains(&id) - { - return Err(AccessError::Transient(TransientAccessError::new(id))); - } + self.check_flaky(id)?; + + self.get_vector.increment(); + Ok((*term.data).into()) + } + None => Err(AccessError::InvalidId(AccessedInvalidId(id))), + } + } + + pub fn get_distance(&mut self, id: u32) -> Result { + match self.provider.terms.get(&id) { + Some(term) => { + self.check_flaky(id)?; self.get_vector.increment(); - self.buffer.copy_from_slice(&term.data); - Ok(&*self.buffer) + Ok(self.distance.evaluate_similarity(&*term.data)) } None => Err(AccessError::InvalidId(AccessedInvalidId(id))), } @@ -1073,18 +1100,6 @@ impl<'a> provider::DelegateNeighbor<'a> for Accessor<'_> { } } -impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> { - type QueryComputerError = Infallible; - type QueryComputer = ::QueryDistance; - - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(f32::query_distance(from, self.provider.config.metric)) - } -} - impl provider::BuildDistanceComputer for Accessor<'_> { type DistanceComputerError = Infallible; type DistanceComputer = ::Distance; @@ -1103,14 +1118,13 @@ impl provider::BuildDistanceComputer for Accessor<'_> { // Glue // //------// -impl glue::Explore<&[f32]> for Accessor<'_> { +impl glue::Explore for Accessor<'_> { fn starting_points(&self) 
-> impl Future>> + Send { futures_util::future::ok(self.provider.config.start_points.keys().copied().collect()) } fn start_point_distances( &mut self, - computer: &Self::QueryComputer, mut f: F, ) -> impl std::future::Future> + Send where @@ -1118,10 +1132,7 @@ impl glue::Explore<&[f32]> for Accessor<'_> { { async move { for &i in self.provider.config.start_points.keys() { - f( - i, - computer.evaluate_similarity(self.get(i).escalate("start points must exist")?), - ) + f(i, self.get_distance(i).escalate("start points must exist")?) } Ok(()) } @@ -1130,7 +1141,6 @@ impl glue::Explore<&[f32]> for Accessor<'_> { fn expand_beam( &mut self, ids: Itr, - computer: &Self::QueryComputer, mut pred: P, mut on_neighbors: F, ) -> impl std::future::Future> + Send @@ -1144,8 +1154,11 @@ impl glue::Explore<&[f32]> for Accessor<'_> { for id in ids { self.provider.get_neighbors(id, &mut neighbors)?; for &n in neighbors.iter().filter(|i| pred.eval_mut(i)) { - if let Some(buf) = self.get(n).allow_transient("transient failures allowed")? { - on_neighbors(computer.evaluate_similarity(buf), n) + if let Some(distance) = self + .get_distance(n) + .allow_transient("transient failures allowed")? + { + on_neighbors(distance, n) } } } @@ -1183,9 +1196,9 @@ impl workingset::Fill for Accessor<'_> { Entry::Seeded(_) | Entry::Occupied(_) => { /* nothing to do */ } Entry::Vacant(vacant) => { if let Some(buf) = - self.get(i).allow_transient("transient failures allowed")? + self.copy(i).allow_transient("transient failures allowed")? 
{ - vacant.insert(buf.into()); + vacant.insert(buf); } } } @@ -1236,7 +1249,6 @@ impl Default for Strategy { } impl glue::SearchStrategy for Strategy { - type QueryComputer = ::QueryDistance; type SearchAccessorError = Infallible; type SearchAccessor<'a> = Accessor<'a>; @@ -1244,9 +1256,9 @@ impl glue::SearchStrategy for Strategy { &'a self, provider: &'a Provider, _context: &'a Context, - _query: &[f32], + query: &[f32], ) -> Result, Infallible> { - Ok(Accessor::new(provider)) + Ok(Accessor::new(provider, query)) } } @@ -1275,9 +1287,10 @@ impl glue::PruneStrategy for Strategy { provider: &'a Provider, _context: &'a Context, ) -> Result, Self::PruneAccessorError> { + // FIXME: Different Accessors? match &self.transient_ids { - Some(ids) => Ok(Accessor::flaky(provider, Cow::Borrowed(ids))), - None => Ok(Accessor::new(provider)), + Some(ids) => Ok(Accessor::flaky(provider, &[], Cow::Borrowed(ids))), + None => Ok(Accessor::new(provider, &[])), } } } @@ -1293,9 +1306,9 @@ impl glue::InsertStrategy for Strategy { &'a self, provider: &'a Provider, _context: &'a Context, - _vector: &[f32], + vector: &[f32], ) -> Result, Self::SearchAccessorError> { - Ok(Accessor::new(provider)) + Ok(Accessor::new(provider, vector)) } } @@ -1351,7 +1364,6 @@ impl<'a, 'b, O> glue::SearchPostProcessStep, &'b [f32], O> for Filt next: &Next, accessor: &mut Accessor<'a>, query: &'b [f32], - computer: &::QueryDistance, candidates: I, output: &mut B, ) -> impl std::future::Future>> + Send @@ -1364,7 +1376,6 @@ impl<'a, 'b, O> glue::SearchPostProcessStep, &'b [f32], O> for Filt next.post_process( accessor, query, - computer, candidates.filter(|n| !provider.is_deleted(n.id).unwrap_or(true)), output, ) @@ -1716,52 +1727,52 @@ mod tests { ); } - #[test] - fn test_set_element() { - use provider::{Guard, SetElement}; - - let provider = create_test_provider(); - let rt = current_thread_runtime(); - - let context = Context::new(); - let mut accessor = super::Accessor::new(&provider); - let id = 5; - - 
assert!(accessor.get(5).is_err()); - - // Setting with the wrong dimension is an error. - { - let v = vec![1.0f32; provider.dim() + 1]; - let err = rt - .block_on(provider.set_element(&context, &id, &v)) - .unwrap_err(); - let msg = err.to_string(); - assert_message_contains!(msg, "wrong dim"); - assert!(accessor.get(id).is_err()); - } - - // Setting with the correct dimension is successful. - { - let v = vec![1.0f32; provider.dim()]; - let guard = rt - .block_on(provider.set_element(&context, &id, &v)) - .unwrap(); - rt.block_on(guard.complete()); - - let element = accessor.get(id).unwrap(); - assert_eq!(v, element); - } - - // Setting again is an error. - { - let v = vec![1.0f32; provider.dim()]; - let err = rt - .block_on(provider.set_element(&context, &id, &v)) - .unwrap_err(); - let msg = err.to_string(); - assert_message_contains!(msg, "vector id 5 is already assigned"); - } - } + // #[test] + // fn test_set_element() { + // use provider::{Guard, SetElement}; + + // let provider = create_test_provider(); + // let rt = current_thread_runtime(); + + // let context = Context::new(); + // let mut accessor = super::Accessor::new(&provider); + // let id = 5; + + // assert!(accessor.get(5).is_err()); + + // // Setting with the wrong dimension is an error. + // { + // let v = vec![1.0f32; provider.dim() + 1]; + // let err = rt + // .block_on(provider.set_element(&context, &id, &v)) + // .unwrap_err(); + // let msg = err.to_string(); + // assert_message_contains!(msg, "wrong dim"); + // assert!(accessor.get(id).is_err()); + // } + + // // Setting with the correct dimension is successful. + // { + // let v = vec![1.0f32; provider.dim()]; + // let guard = rt + // .block_on(provider.set_element(&context, &id, &v)) + // .unwrap(); + // rt.block_on(guard.complete()); + + // let element = accessor.get(id).unwrap(); + // assert_eq!(v, element); + // } + + // // Setting again is an error. 
+ // { + // let v = vec![1.0f32; provider.dim()]; + // let err = rt + // .block_on(provider.set_element(&context, &id, &v)) + // .unwrap_err(); + // let msg = err.to_string(); + // assert_message_contains!(msg, "vector id 5 is already assigned"); + // } + // } #[test] fn test_neighbor_accessor() { diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index 21b3fbd73..3c00dc0a6 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -71,10 +71,6 @@ //! * [`BuildDistanceComputer`]: A sub-trait of [`Accessor`] that allows for random-access //! distance computations on the retrieved elements. //! -//! * [`BuildQueryComputer`]: A sub-trait of [`Accessor`] that allows for specialized query -//! based computations. This allows a query to be pre-processed in a way that allows -//! faster computations. -//! //! # Neighbor Delegation //! //! Index search requires that accessor types implement both the data-centric [`Accessor`] @@ -410,28 +406,6 @@ pub trait BuildDistanceComputer: HasElementRef { ) -> Result; } -/// A specialized [`Accessor`] that provides query computations for a query type `T`. -/// -/// Query computers are allowed to preprocess the query to enable more efficient distance -/// computations. -pub trait BuildQueryComputer { - /// The error type (if any) associated with distance computer construction. - type QueryComputerError: std::error::Error + Into + Send + Sync + 'static; - - /// The concrete type of the distance computer, which must be applicable for all - /// elements yielded by the [`Accessor`]. - type QueryComputer: Send + Sync; - - /// Build the query computer for this accessor. - /// - /// This method is encouraged to be as fast as possible, but will generally only be - /// invoked once per search or graph insert. 
- fn build_query_computer( - &self, - from: T, - ) -> Result; -} - ///////////////////////// // Neighbor Delegation // ///////////////////////// From 56122a592822a6707de7ac2fed0f73bb0645ef4a Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Wed, 20 May 2026 15:29:04 -0700 Subject: [PATCH 17/17] Well - this is something. --- .../src/search/provider/disk_provider.rs | 125 +--- diskann-providers/src/index/diskann_async.rs | 108 ++- diskann-providers/src/index/wrapped_async.rs | 2 +- .../graph/provider/async_/bf_tree/provider.rs | 5 +- .../provider/async_/inmem/full_precision.rs | 240 +++--- .../graph/provider/async_/inmem/product.rs | 690 ++++++++++++------ .../graph/provider/async_/inmem/scalar.rs | 245 +++---- .../graph/provider/async_/inmem/spherical.rs | 231 +++--- .../graph/provider/async_/postprocess.rs | 2 +- .../model/graph/provider/layers/betafilter.rs | 26 +- diskann/src/graph/test/provider.rs | 232 +++--- 11 files changed, 1021 insertions(+), 885 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index bf720a53f..33407b1c7 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -22,8 +22,8 @@ use diskann::{ search::Knn, search_output_buffer, DiskANNIndex, }, - neighbor::Neighbor, - provider::{BuildQueryComputer, DataProvider, DefaultContext, HasId, NoopGuard}, + neighbor::{Neighbor, NeighborPriorityQueue}, + provider::{DataProvider, DefaultContext, HasId, NoopGuard}, utils::{IntoUsize, VectorRepr}, ANNError, ANNResult, }; @@ -289,7 +289,6 @@ where &self, accessor: &mut DiskAccessor<'_, Data, VP>, query: &[Data::VectorDataType], - _computer: &DiskQueryComputer, candidates: I, output: &mut B, ) -> Result @@ -338,7 +337,6 @@ where Data: GraphDataType, ProviderFactory: VertexProviderFactory, { - type QueryComputer = DiskQueryComputer; type SearchAccessor<'a> = DiskAccessor<'a, Data, 
ProviderFactory::VertexProviderType>; type SearchAccessorError = ANNError; @@ -346,6 +344,7 @@ where &'a self, provider: &'a DiskProvider, _context: &DefaultContext, + query: &[Data::VectorDataType], ) -> Result, Self::SearchAccessorError> { DiskAccessor::new( provider, @@ -377,50 +376,6 @@ where } } -/// The query computer for the disk provider. This is used to compute the distance between the query vector and the PQ coordinates. -pub struct DiskQueryComputer { - num_pq_chunks: usize, - query_centroid_l2_distance: Vec, -} - -impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer { - fn evaluate_similarity(&self, changing: &[u8]) -> f32 { - let mut dist = 0.0f32; - #[allow(clippy::expect_used)] - compute_pq_distance_for_pq_coordinates( - changing, - self.num_pq_chunks, - &self.query_centroid_l2_distance, - std::slice::from_mut(&mut dist), - ) - .expect("PQ distance compute for PQ coordinates is expected to succeed"); - dist - } -} - -impl BuildQueryComputer<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - type QueryComputerError = ANNError; - type QueryComputer = DiskQueryComputer; - - fn build_query_computer( - &self, - _from: &[Data::VectorDataType], - ) -> Result { - Ok(DiskQueryComputer { - num_pq_chunks: self.provider.pq_data.get_num_chunks(), - query_centroid_l2_distance: self - .scratch - .pq_scratch - .aligned_pqtable_dist_scratch - .to_vec(), - }) - } -} - // Scratch space for disk search operations that need allocations. // These allocations are amortized across searches using the scratch pool. 
struct DiskSearchScratch @@ -522,7 +477,7 @@ where } } -impl Explore<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> +impl Explore for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, @@ -534,19 +489,19 @@ where async fn start_point_distances( &mut self, - computer: &Self::QueryComputer, mut f: F, ) -> ANNResult<()> where F: FnMut(Self::Id, f32) + Send, { let start_vertex_id = self.provider.graph_header.metadata().medoid as u32; - let vector = self - .provider - .pq_data - .get_compressed_vector(start_vertex_id.into_usize())?; - let distance = computer.evaluate_similarity(vector); - f(start_vertex_id, distance); + self.pq_distances(&[start_vertex_id], |distance, id| f(id, distance))?; + // let vector = self + // .provider + // .pq_data + // .get_compressed_vector(start_vertex_id.into_usize())?; + // let distance = computer.evaluate_similarity(vector); + // f(start_vertex_id, distance); Ok(()) } @@ -557,7 +512,6 @@ where fn expand_beam( &mut self, ids: Itr, - _computer: &Self::QueryComputer, mut pred: P, mut f: F, ) -> impl std::future::Future> + Send @@ -833,35 +787,36 @@ where { let provider = self.index.provider(); let mut accessor = strategy - .search_accessor(provider, &DefaultContext) + .search_accessor(provider, &DefaultContext, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - - let mut best = NeighborPriorityQueue::new(neighbors_before_reranking); - let mut cmps = 0u32; - - let num_points = provider.num_points as u32; - for id in 0..num_points { - if vector_filter(&id) { - let element = accessor.get_element(id).await.into_ann_result()?; - let dist = computer.evaluate_similarity(element); - best.insert(Neighbor::new(id, dist)); - cmps += 1; - } - } - - let result_count = strategy - .default_post_processor() - .post_process(&mut accessor, query, &computer, best.iter(), output) - .await - .into_ann_result()?; - - Ok(graph::index::SearchStats { - cmps, - hops: 0, - 
result_count: result_count as u32, - range_search_second_round: false, - }) + unimplemented!(); + // let computer = accessor.build_query_computer(query).into_ann_result()?; + + // let mut best = NeighborPriorityQueue::new(neighbors_before_reranking); + // let mut cmps = 0u32; + + // let num_points = provider.num_points as u32; + // for id in 0..num_points { + // if vector_filter(&id) { + // let element = accessor.get_element(id).await.into_ann_result()?; + // let dist = computer.evaluate_similarity(element); + // best.insert(Neighbor::new(id, dist)); + // cmps += 1; + // } + // } + + // let result_count = strategy + // .default_post_processor() + // .post_process(&mut accessor, query, &computer, best.iter(), output) + // .await + // .into_ann_result()?; + + // Ok(graph::index::SearchStats { + // cmps, + // hops: 0, + // result_count: result_count as u32, + // range_search_second_round: false, + // }) } /// Perform a search on the disk index. diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index ee2adcd4c..892372ed6 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -174,8 +174,8 @@ pub(crate) mod tests { }, neighbor::Neighbor, provider::{ - AsNeighbor, AsNeighborMut, DataProvider, DefaultContext, Delete, - ExecutionContext, Guard, NeighborAccessor, NeighborAccessorMut, SetElement, + AsNeighbor, AsNeighborMut, DataProvider, DefaultContext, Delete, ExecutionContext, + Guard, NeighborAccessor, NeighborAccessorMut, SetElement, }, utils::{IntoUsize, ONE}, }; @@ -434,7 +434,7 @@ pub(crate) mod tests { { assert!(max_candidates <= groundtruth.len()); let mut search = index - .paged_search(strategy, ¶meters.context, query, parameters.search_l) + .paged_search(&strategy, ¶meters.context, query, parameters.search_l) .await .unwrap(); @@ -2318,9 +2318,8 @@ pub(crate) mod tests { let accessor = strategy .search_accessor(index.provider(), ctx, data.row(0)) 
.unwrap(); - let computer = accessor.build_query_computer(data.row(0)).unwrap(); assert_eq!( - computer.layout(), + accessor.computer.layout(), diskann_quantization::spherical::iface::QueryLayout::SameAsData ); } @@ -2427,9 +2426,8 @@ pub(crate) mod tests { &[f32], >>::search_accessor(&strategy, index.provider(), ctx, data.row(0)) .unwrap(); - let computer = accessor.build_query_computer(data.row(0)).unwrap(); assert_eq!( - computer.layout(), + accessor.computer.layout(), diskann_quantization::spherical::iface::QueryLayout::SameAsData ); } @@ -3273,55 +3271,55 @@ pub(crate) mod tests { Ok(index) } - #[tokio::test] - async fn test_saturate_index() { - let index_sat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), true) - .await - .unwrap(); - let mut accessor_sat = inmem::FullAccessor::new(index_sat.provider()); - let res_sat = index_sat - .get_degree_stats(&mut accessor_sat, index_sat.provider().iter()) - .await - .unwrap(); - - let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false) - .await - .unwrap(); - let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider()); - let res_unsat = index_unsat - .get_degree_stats(&mut accessor_unsat, index_unsat.provider().iter()) - .await - .unwrap(); - assert!( - res_sat.avg_degree > res_unsat.avg_degree, - "Saturated index should have higher average degree than the unsaturated index" - ); - } - - #[tokio::test] - async fn test_retry_index() { - let index_sat = create_retry_saturated_index(NonZeroU32::new(3).unwrap(), false) - .await - .unwrap(); - let mut accessor_sat = inmem::FullAccessor::new(index_sat.provider()); - let res_sat = index_sat - .get_degree_stats(&mut accessor_sat, index_sat.provider().iter()) - .await - .unwrap(); + // #[tokio::test] + // async fn test_saturate_index() { + // let index_sat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), true) + // .await + // .unwrap(); + // let mut accessor_sat = inmem::Builder::new(index_sat.provider(), query); 
+ // let res_sat = index_sat + // .get_degree_stats(&mut accessor_sat, index_sat.provider().iter()) + // .await + // .unwrap(); + + // let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false) + // .await + // .unwrap(); + // let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider()); + // let res_unsat = index_unsat + // .get_degree_stats(&mut accessor_unsat, index_unsat.provider().iter()) + // .await + // .unwrap(); + // assert!( + // res_sat.avg_degree > res_unsat.avg_degree, + // "Saturated index should have higher average degree than the unsaturated index" + // ); + // } - let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false) - .await - .unwrap(); - let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider()); - let res_unsat = index_sat - .get_degree_stats(&mut accessor_unsat, index_unsat.provider().iter()) - .await - .unwrap(); - assert!( - res_sat.avg_degree > res_unsat.avg_degree, - "Saturated index should have higher average degree than the unsaturated index" - ); - } + // #[tokio::test] + // async fn test_retry_index() { + // let index_sat = create_retry_saturated_index(NonZeroU32::new(3).unwrap(), false) + // .await + // .unwrap(); + // let mut accessor_sat = inmem::FullAccessor::new(index_sat.provider()); + // let res_sat = index_sat + // .get_degree_stats(&mut accessor_sat, index_sat.provider().iter()) + // .await + // .unwrap(); + + // let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false) + // .await + // .unwrap(); + // let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider()); + // let res_unsat = index_sat + // .get_degree_stats(&mut accessor_unsat, index_unsat.provider().iter()) + // .await + // .unwrap(); + // assert!( + // res_sat.avg_degree > res_unsat.avg_degree, + // "Saturated index should have higher average degree than the unsaturated index" + // ); + // } #[cfg(feature = "experimental_diversity_search")] 
#[tokio::test] diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 40a9aeea7..fa0814ead 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -560,7 +560,7 @@ pub mod noawait { // The assumption of `noawait` is that this call will always resolve to // `Poll::Ready`. let mut state = match index - .paged_search(strategy, &context, query.reborrow(), l_value) + .paged_search(&strategy, &context, query.reborrow(), l_value) .await { Ok(state) => state, diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index f87d1e936..f8e48340e 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -27,9 +27,8 @@ use diskann::{ }, neighbor::Neighbor, provider::{ - Accessor, BuildDistanceComputer, DataProvider, DefaultContext, - DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, - NoopGuard, SetElement, + Accessor, BuildDistanceComputer, DataProvider, DefaultContext, DelegateNeighbor, Delete, + ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, utils::{IntoUsize, VectorRepr}, }; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 76e1167df..cfa27192d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -19,8 +19,8 @@ use diskann::{ }, neighbor::Neighbor, provider::{ - BuildDistanceComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasElementRef, HasId, + BuildDistanceComputer, DefaultContext, DelegateNeighbor, ExecutionContext, HasElementRef, + HasId, 
}, utils::{IntoUsize, VectorRepr}, }; @@ -102,6 +102,105 @@ where } } +///////////// +// Builder // +///////////// + +#[derive(Clone)] +pub struct Builder<'a, T> +where + T: VectorRepr, +{ + metric: Metric, + store: &'a FastMemoryVectorProviderAsync, + neighbors: &'a SimpleNeighborProviderAsync, +} + +impl HasId for Builder<'_, T> +where + T: VectorRepr, +{ + type Id = u32; +} + +impl HasElementRef for Builder<'_, T> +where + T: VectorRepr, +{ + type ElementRef<'a> = &'a [T]; +} + +impl<'a, T> DelegateNeighbor<'a> for Builder<'_, T> +where + T: VectorRepr, +{ + type Delegate = &'a SimpleNeighborProviderAsync; + + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.neighbors + } +} + +impl BuildDistanceComputer for Builder<'_, T> +where + T: VectorRepr, +{ + type DistanceComputerError = Panics; + type DistanceComputer = T::Distance; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(T::distance( + self.metric, + Some(self.store.dim()), + )) + } +} + +// All this does is return a `&Self` - which directly accesses the underlying provider. +impl workingset::Fill for Builder<'_, T> +where + T: VectorRepr, +{ + type Error = Infallible; + + type View<'a> + = Builder<'a, T> + where + Self: 'a; + + async fn fill<'a, Itr>( + &'a mut self, + _state: &'a mut PassThrough, + _itr: Itr, + ) -> Result, Self::Error> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + Ok(self.clone()) + } +} + +// Pass-through view. +impl workingset::View for Builder<'_, T> +where + T: VectorRepr, +{ + type ElementRef<'a> = &'a [T]; + type Element<'a> + = &'a [T] + where + Self: 'a; + + fn get(&self, id: u32) -> Option<&[T]> { + // SAFETY: This is unsound. We assume no concurrent writes to this slot, but + // this invariant is not enforced. See `get_vector_sync` for details + Some(unsafe { self.store.get_vector_sync(id.into_usize()) }) + } +} + ////////////////// // FullAccessor // ////////////////// @@ -119,6 +218,9 @@ where /// The host provider. 
provider: &'a FullPrecisionProvider, + /// The distance computer. + computer: ::QueryDistance, + /// A buffer for resolving iterators given during bulk operations. /// /// The accessor reuses this allocation to amortize allocation cost over multiple bulk @@ -165,7 +267,7 @@ where for i in self.provider.starting_points()? { // SAFETY: We're accepting the consequences of potential unsynchronized, // concurrent mutation. - let distance = computer.evaluate_similarity(unsafe { + let distance = self.computer.evaluate_similarity(unsafe { self.provider.base_vectors.get_vector_sync(i.into_usize()) }); @@ -215,7 +317,7 @@ where // SAFETY: We're accepting the consequences of potential unsynchronized, // concurrent mutation. let v = unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }; - let distance = computer.evaluate_similarity(v); + let distance = self.computer.evaluate_similarity(v); on_neighbors(distance, *id); } } @@ -233,76 +335,18 @@ where D: AsyncFriendly, Ctx: ExecutionContext, { - pub fn new(provider: &'a FullPrecisionProvider) -> Self { + pub fn new( + provider: &'a FullPrecisionProvider, + query: &[T], + ) -> Self { Self { provider, + computer: T::query_distance(query, provider.metric), id_buffer: AdjacencyList::new(), } } } -impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl HasElementRef for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type ElementRef<'a> = &'a [T]; -} - -impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = Panics; - type DistanceComputer = 
T::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(T::distance( - self.provider.metric, - Some(self.provider.base_vectors.dim()), - )) - } -} - -// impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D, Ctx> -// where -// T: VectorRepr, -// Q: AsyncFriendly, -// D: AsyncFriendly, -// Ctx: ExecutionContext, -// { -// type QueryComputerError = Panics; -// type QueryComputer = T::QueryDistance; -// -// fn build_query_computer( -// &self, -// from: &[T], -// ) -> Result { -// Ok(T::query_distance(from, self.provider.metric)) -// } -// } - //-------------------// // In-mem Extensions // //-------------------// @@ -404,9 +448,9 @@ where &'a self, provider: &'a FullPrecisionProvider, _context: &'a Ctx, - _query: &[T], + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) + Ok(FullAccessor::new(provider, query)) } } @@ -429,7 +473,7 @@ where Ctx: ExecutionContext, { type DistanceComputer<'a> = T::Distance; - type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type PruneAccessor<'a> = Builder<'a, T>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = PassThrough; @@ -438,7 +482,12 @@ where provider: &'a FullPrecisionProvider, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) + let builder = Builder { + metric: provider.metric, + store: &provider.base_vectors, + neighbors: provider.neighbors(), + }; + Ok(builder) } fn create_working_set(&self, _capacity: usize) -> Self::WorkingSet { @@ -446,55 +495,6 @@ where } } -// All this does is return a `&Self` - which directly accesses the underlying provider. 
-impl workingset::Fill for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Error = Infallible; - - type View<'a> - = &'a Self - where - Self: 'a; - - async fn fill<'a, Itr>( - &'a mut self, - _state: &'a mut PassThrough, - _itr: Itr, - ) -> Result, Self::Error> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - Ok(self) - } -} - -// Pass-through view. -impl workingset::View for &FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type ElementRef<'a> = &'a [T]; - type Element<'a> - = &'a [T] - where - Self: 'a; - - fn get(&self, id: u32) -> Option<&[T]> { - // SAFETY: This is unsound. We assume no concurrent writes to this slot, but - // this invariant is not enforced. See `get_vector_sync` for details - Some(unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }) - } -} - impl InsertStrategy, &[T]> for FullPrecision where T: VectorRepr, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 429ae02c6..5838e1d52 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -16,10 +16,7 @@ use diskann::{ }, workingset, }, - provider::{ - BuildDistanceComputer, DelegateNeighbor, ExecutionContext, - HasElementRef, HasId, - }, + provider::{BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId}, utils::{IntoUsize, VectorRepr}, }; use diskann_utils::future::AsyncFriendly; @@ -83,252 +80,118 @@ where } /////////////////// -// QuantAccessor // +// Quant Builder // /////////////////// -/// An accessor that retrieves the quantized portion of the [`DefaultProvider`]. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the `DefaultProvider`. 
-pub struct QuantAccessor<'a, V, D, Ctx> { - provider: &'a DefaultProvider, +#[derive(Clone)] +pub struct QuantBuilder<'a> { + provider: &'a FastMemoryQuantVectorProviderAsync, + neighbors: &'a SimpleNeighborProviderAsync, } -impl GetFullPrecision for QuantAccessor<'_, FullPrecisionStore, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync { - &self.provider.base_vectors - } -} - -impl HasId for QuantAccessor<'_, V, D, Ctx> { +impl HasId for QuantBuilder<'_> { type Id = u32; } -impl Explore for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } - - fn start_point_distances( - &mut self, - mut f: F, - ) -> impl std::future::Future> + Send - where - F: FnMut(Self::Id, f32) + Send, - { - async move { - for i in self.provider.starting_points()? { - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - let distance = computer.evaluate_similarity(unsafe { - self.provider.aux_vectors.get_vector_sync(i.into_usize()) - }); - - f(i, distance); - } - Ok(()) - } - } - - fn expand_beam( - &mut self, - ids: Itr, - mut pred: P, - mut on_neighbors: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: glue::HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - let f = move || -> ANNResult<()> { - let mut neighbors = AdjacencyList::new(); - for n in ids { - self.provider - .neighbor_provider - .get_neighbors_sync(n.into_usize(), &mut neighbors)?; - for i in neighbors.iter().filter(|i| pred.eval_mut(i)) { - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. 
- let distance = computer.evaluate_similarity(unsafe { - self.provider.aux_vectors.get_vector_sync(i.into_usize()) - }); - - on_neighbors(distance, *i); - } - } - Ok(()) - }; - - std::future::ready(f()) - } -} - -impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - pub(crate) fn new(provider: &'a DefaultProvider) -> Self { - Self { provider } - } +impl HasElementRef for QuantBuilder<'_> { + type ElementRef<'a> = &'a [u8]; } -impl<'a, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ +impl<'a> DelegateNeighbor<'a> for QuantBuilder<'_> { type Delegate = &'a SimpleNeighborProviderAsync; fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() + self.neighbors } } -impl HasElementRef for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type ElementRef<'a> = &'a [u8]; -} - -// impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> -// where -// T: VectorRepr, -// V: AsyncFriendly, -// D: AsyncFriendly, -// Ctx: ExecutionContext, -// { -// type QueryComputerError = ANNError; -// type QueryComputer = pq::distance::QueryComputer>; -// -// fn build_query_computer( -// &self, -// from: &[T], -// ) -> Result { -// self.provider.aux_vectors.query_computer(from) -// } -// } - -impl BuildDistanceComputer for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ +impl BuildDistanceComputer for QuantBuilder<'_> { type DistanceComputerError = ANNError; type DistanceComputer = pq::distance::DistanceComputer>; fn build_distance_computer( &self, ) -> Result { - Ok(self.provider.aux_vectors.distance_computer()) + Ok(self.provider.distance_computer()) } } -//-------------------// -// In-mem Extensions // -//-------------------// +impl workingset::Fill for QuantBuilder<'_> { + type Error = 
std::convert::Infallible; + type View<'a> + = QuantBuilder<'a> + where + Self: 'a; -impl<'a, V, D, Ctx> AsDeletionCheck for QuantAccessor<'a, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Checker = D; - fn as_deletion_check(&self) -> &D { - &self.provider.deleted + async fn fill<'a, Itr>( + &'a mut self, + _state: &'a mut PassThrough, + _itr: Itr, + ) -> Result, Self::Error> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + Ok(self.clone()) } } -///////////////////// -// Hybrid Accessor // -///////////////////// +// Pass-through view — reads PQ codes directly from the provider. +impl workingset::View for QuantBuilder<'_> { + type ElementRef<'a> = &'a [u8]; + type Element<'a> + = &'a [u8] + where + Self: 'a; -/// A hybrid accessor that fetches a mixture of full-precision and quantized vectors during -/// pruning. This allows the application to trade full-precision fetches for accuracy. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`BuildDistanceComputer`] for computing distances among [`distances::pq::Hybrid`] -/// element types. -/// * [`Fill`] for populating a mixture of full-precision and quant vectors. -pub struct HybridAccessor<'a, T, D, Ctx> + fn get(&self, id: u32) -> Option<&[u8]> { + // SAFETY: This is unsound. We assume no concurrent writes to this slot, but + // this invariant is not enforced. See `get_vector_sync` for details. + Some(unsafe { self.provider.get_vector_sync(id.into_usize()) }) + } +} + +/////////////////// +// HybridBuilder // +/////////////////// + +#[derive(Clone)] +pub struct HybridBuilder<'a, T> where T: VectorRepr, { - provider: &'a FullPrecisionProvider, - - /// Maximum number of full-precision vectors to use during pruning. - /// This field is ignored during search, where full-precision vectors are never used. 
+ full: &'a FastMemoryVectorProviderAsync, + quant: &'a FastMemoryQuantVectorProviderAsync, + neighbors: &'a SimpleNeighborProviderAsync, max_fp_vecs_per_prune: usize, } -impl<'a, T, D, Ctx> HybridAccessor<'a, T, D, Ctx> +impl HasId for HybridBuilder<'_, T> where T: VectorRepr, { - fn new( - provider: &'a FullPrecisionProvider, - max_fp_vecs_per_prune: usize, - ) -> Self { - Self { - provider, - max_fp_vecs_per_prune, - } - } + type Id = u32; } -impl HasId for HybridAccessor<'_, T, D, Ctx> +impl HasElementRef for HybridBuilder<'_, T> where T: VectorRepr, { - type Id = u32; + type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; } -impl<'a, T, D, Ctx> DelegateNeighbor<'a> for HybridAccessor<'_, T, D, Ctx> +impl<'a, T> DelegateNeighbor<'a> for HybridBuilder<'_, T> where T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, { type Delegate = &'a SimpleNeighborProviderAsync; fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() + self.neighbors } } -impl HasElementRef for HybridAccessor<'_, T, D, Ctx> +impl BuildDistanceComputer for HybridBuilder<'_, T> where T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; -} - -impl BuildDistanceComputer for HybridAccessor<'_, T, D, Ctx> -where - T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, { type DistanceComputerError = ANNError; type DistanceComputer = distances::pq::HybridComputer; @@ -336,10 +199,10 @@ where fn build_distance_computer( &self, ) -> Result { - let metric = self.provider.aux_vectors.metric(); + let metric = self.quant.metric(); Ok(distances::pq::HybridComputer::new( - self.provider.aux_vectors.distance_computer(), - T::distance(metric, Some(self.provider.base_vectors.dim())), + self.quant.distance_computer(), + T::distance(metric, Some(self.full.dim())), )) } } @@ -358,15 +221,13 @@ impl workingset::AsWorkingSet for Unseeded { // Selective fill — the first 
`max_fp_vecs_per_prune` candidates receive full-precision // vectors; the remainder are accessed through the pass-through `MaybeFullPrecision` view. -impl workingset::Fill for HybridAccessor<'_, T, D, Ctx> +impl workingset::Fill for HybridBuilder<'_, T> where T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, { type Error = std::convert::Infallible; type View<'a> - = MaybeFullPrecision<'a, T, D, Ctx> + = MaybeFullPrecision<'a, T> where Self: 'a; @@ -382,25 +243,23 @@ where state.0.clear(); state.0.extend(itr.take(self.max_fp_vecs_per_prune)); Ok(MaybeFullPrecision { - accessor: self, + builder: self, full: state, }) } } -pub struct MaybeFullPrecision<'a, T, D, Ctx> +pub struct MaybeFullPrecision<'a, T> where T: VectorRepr, { - accessor: &'a HybridAccessor<'a, T, D, Ctx>, + builder: &'a HybridBuilder<'a, T>, full: &'a FullPrecisionTracker, } -impl workingset::View for MaybeFullPrecision<'_, T, D, Ctx> +impl workingset::View for MaybeFullPrecision<'_, T> where T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, { type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; type Element<'a> @@ -409,50 +268,398 @@ where Self: 'a; fn get(&self, id: u32) -> Option> { - let provider = &self.accessor.provider; let element = if self.full.0.contains(&id) { // SAFETY: This is unsound. We assume no concurrent writes to this slot, but // this invariant is not enforced. See `get_vector_sync` for details. unsafe { - distances::pq::Hybrid::Full(provider.base_vectors.get_vector_sync(id.into_usize())) + distances::pq::Hybrid::Full(self.builder.full.get_vector_sync(id.into_usize())) } } else { // SAFETY: This is unsound. We assume no concurrent writes to this slot, but // this invariant is not enforced. See `get_vector_sync` for details. 
unsafe { - distances::pq::Hybrid::Quant(provider.aux_vectors.get_vector_sync(id.into_usize())) + distances::pq::Hybrid::Quant(self.builder.quant.get_vector_sync(id.into_usize())) } }; Some(element) } } -// Pass-through fill — returns `&Self` which directly accesses the underlying provider. -impl workingset::Fill for QuantAccessor<'_, V, D, Ctx> +/////////////////// +// QuantAccessor // +/////////////////// + +/// An accessor that retrieves the quantized portion of the [`DefaultProvider`]. +/// +/// This type implements the following traits: +/// +/// * [`Accessor`] for the `DefaultProvider`. +pub struct QuantAccessor<'a, V, D, Ctx> { + provider: &'a DefaultProvider, + computer: pq::distance::QueryComputer>, +} + +impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx> where V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { - type Error = std::convert::Infallible; - type View<'a> - = &'a Self + pub(crate) fn new( + provider: &'a DefaultProvider, + query: &[f32], + ) -> Self { + let computer = provider.aux_vectors.query_computer(query).unwrap(); + Self { provider, computer } + } +} + +impl GetFullPrecision for QuantAccessor<'_, FullPrecisionStore, D, Ctx> +where + T: VectorRepr, +{ + type Repr = T; + fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync { + &self.provider.base_vectors + } +} + +impl HasId for QuantAccessor<'_, V, D, Ctx> { + type Id = u32; +} + +impl Explore for QuantAccessor<'_, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + fn starting_points(&self) -> impl Future>> { + std::future::ready(self.provider.starting_points()) + } + + fn start_point_distances( + &mut self, + mut f: F, + ) -> impl std::future::Future> + Send where - Self: 'a; + F: FnMut(Self::Id, f32) + Send, + { + async move { + for i in self.provider.starting_points()? { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. 
+ let distance = self.computer.evaluate_similarity(unsafe { + self.provider.aux_vectors.get_vector_sync(i.into_usize()) + }); - async fn fill<'a, Itr>( - &'a mut self, - _state: &'a mut PassThrough, - _itr: Itr, - ) -> Result, Self::Error> + f(i, distance); + } + Ok(()) + } + } + + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, { - Ok(self) + let f = move || -> ANNResult<()> { + let mut neighbors = AdjacencyList::new(); + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut neighbors)?; + for i in neighbors.iter().filter(|i| pred.eval_mut(i)) { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = self.computer.evaluate_similarity(unsafe { + self.provider.aux_vectors.get_vector_sync(i.into_usize()) + }); + + on_neighbors(distance, *i); + } + } + Ok(()) + }; + + std::future::ready(f()) + } +} + +// impl<'a, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, V, D, Ctx> +// where +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type Delegate = &'a SimpleNeighborProviderAsync; +// fn delegate_neighbor(&'a mut self) -> Self::Delegate { +// self.provider.neighbors() +// } +// } +// +// impl HasElementRef for QuantAccessor<'_, V, D, Ctx> +// where +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type ElementRef<'a> = &'a [u8]; +// } + +// impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> +// where +// T: VectorRepr, +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type QueryComputerError = ANNError; +// type QueryComputer = pq::distance::QueryComputer>; +// +// fn build_query_computer( +// &self, +// from: &[T], +// ) -> 
Result { +// self.provider.aux_vectors.query_computer(from) +// } +// } + +// impl BuildDistanceComputer for QuantAccessor<'_, V, D, Ctx> +// where +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type DistanceComputerError = ANNError; +// type DistanceComputer = pq::distance::DistanceComputer>; +// +// fn build_distance_computer( +// &self, +// ) -> Result { +// Ok(self.provider.aux_vectors.distance_computer()) +// } +// } + +//-------------------// +// In-mem Extensions // +//-------------------// + +impl<'a, V, D, Ctx> AsDeletionCheck for QuantAccessor<'a, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type Checker = D; + fn as_deletion_check(&self) -> &D { + &self.provider.deleted } } +// ///////////////////// +// // Hybrid Accessor // +// ///////////////////// +// +// /// A hybrid accessor that fetches a mixture of full-precision and quantized vectors during +// /// pruning. This allows the application to trade full-precision fetches for accuracy. +// /// +// /// This type implements the following traits: +// /// +// /// * [`Accessor`] for the [`DefaultProvider`]. +// /// * [`BuildDistanceComputer`] for computing distances among [`distances::pq::Hybrid`] +// /// element types. +// /// * [`Fill`] for populating a mixture of full-precision and quant vectors. +// pub struct HybridAccessor<'a, T, D, Ctx> +// where +// T: VectorRepr, +// { +// provider: &'a FullPrecisionProvider, +// +// /// Maximum number of full-precision vectors to use during pruning. +// /// This field is ignored during search, where full-precision vectors are never used. 
+// max_fp_vecs_per_prune: usize, +// } +// +// impl<'a, T, D, Ctx> HybridAccessor<'a, T, D, Ctx> +// where +// T: VectorRepr, +// { +// fn new( +// provider: &'a FullPrecisionProvider, +// max_fp_vecs_per_prune: usize, +// ) -> Self { +// Self { +// provider, +// max_fp_vecs_per_prune, +// } +// } +// } +// +// impl HasId for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// { +// type Id = u32; +// } +// +// impl<'a, T, D, Ctx> DelegateNeighbor<'a> for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type Delegate = &'a SimpleNeighborProviderAsync; +// fn delegate_neighbor(&'a mut self) -> Self::Delegate { +// self.provider.neighbors() +// } +// } +// +// impl HasElementRef for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; +// } +// +// impl BuildDistanceComputer for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type DistanceComputerError = ANNError; +// type DistanceComputer = distances::pq::HybridComputer; +// +// fn build_distance_computer( +// &self, +// ) -> Result { +// let metric = self.provider.aux_vectors.metric(); +// Ok(distances::pq::HybridComputer::new( +// self.provider.aux_vectors.distance_computer(), +// T::distance(metric, Some(self.provider.base_vectors.dim())), +// )) +// } +// } +// +// /// Tracks which IDs should use full-precision vectors during hybrid pruning. +// /// +// /// IDs in the set get full-precision distance computations; all others fall back to +// /// quantized vectors. 
+// pub struct FullPrecisionTracker(hashbrown::HashSet); +// +// impl workingset::AsWorkingSet for Unseeded { +// fn as_working_set(&self, capacity: usize) -> FullPrecisionTracker { +// FullPrecisionTracker(hashbrown::HashSet::with_capacity(capacity)) +// } +// } +// +// // Selective fill — the first `max_fp_vecs_per_prune` candidates receive full-precision +// // vectors; the remainder are accessed through the pass-through `MaybeFullPrecision` view. +// impl workingset::Fill for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type Error = std::convert::Infallible; +// type View<'a> +// = MaybeFullPrecision<'a, T, D, Ctx> +// where +// Self: 'a; +// +// async fn fill<'a, Itr>( +// &'a mut self, +// state: &'a mut FullPrecisionTracker, +// itr: Itr, +// ) -> Result, Self::Error> +// where +// Itr: ExactSizeIterator + Clone + Send + Sync, +// Self: 'a, +// { +// state.0.clear(); +// state.0.extend(itr.take(self.max_fp_vecs_per_prune)); +// Ok(MaybeFullPrecision { +// accessor: self, +// full: state, +// }) +// } +// } +// +// pub struct MaybeFullPrecision<'a, T, D, Ctx> +// where +// T: VectorRepr, +// { +// accessor: &'a HybridAccessor<'a, T, D, Ctx>, +// full: &'a FullPrecisionTracker, +// } +// +// impl workingset::View for MaybeFullPrecision<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; +// type Element<'a> +// = distances::pq::Hybrid<&'a [T], &'a [u8]> +// where +// Self: 'a; +// +// fn get(&self, id: u32) -> Option> { +// let provider = &self.accessor.provider; +// let element = if self.full.0.contains(&id) { +// // SAFETY: This is unsound. We assume no concurrent writes to this slot, but +// // this invariant is not enforced. See `get_vector_sync` for details. 
+// unsafe { +// distances::pq::Hybrid::Full(provider.base_vectors.get_vector_sync(id.into_usize())) +// } +// } else { +// // SAFETY: This is unsound. We assume no concurrent writes to this slot, but +// // this invariant is not enforced. See `get_vector_sync` for details. +// unsafe { +// distances::pq::Hybrid::Quant(provider.aux_vectors.get_vector_sync(id.into_usize())) +// } +// }; +// Some(element) +// } +// } +// +// // Pass-through fill — returns `&Self` which directly accesses the underlying provider. +// impl workingset::Fill for QuantAccessor<'_, V, D, Ctx> +// where +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type Error = std::convert::Infallible; +// type View<'a> +// = &'a Self +// where +// Self: 'a; +// +// async fn fill<'a, Itr>( +// &'a mut self, +// _state: &'a mut PassThrough, +// _itr: Itr, +// ) -> Result, Self::Error> +// where +// Itr: ExactSizeIterator + Clone + Send + Sync, +// Self: 'a, +// { +// Ok(self) +// } +// } + // Pass-through view — reads PQ codes directly from the provider. impl workingset::View for &QuantAccessor<'_, V, D, Ctx> where @@ -496,14 +703,12 @@ where &'a self, provider: &'a FullPrecisionProvider, _context: &'a Ctx, - _query: &[T], + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap())) } } -/// Starting points are filtered out of the final results and results are reranked using -/// the full-precision data. 
impl DefaultPostProcessor, &[T]> for Hybrid where @@ -521,7 +726,7 @@ where Ctx: ExecutionContext, { type DistanceComputer<'a> = distances::pq::HybridComputer; - type PruneAccessor<'a> = HybridAccessor<'a, T, D, Ctx>; + type PruneAccessor<'a> = HybridBuilder<'a, T>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = FullPrecisionTracker; @@ -534,10 +739,13 @@ where provider: &'a FullPrecisionProvider, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - Ok(HybridAccessor::new( - provider, - self.max_fp_vecs_per_prune.unwrap_or(usize::MAX), - )) + let builder = HybridBuilder { + full: &provider.base_vectors, + quant: &provider.aux_vectors, + neighbors: provider.neighbors(), + max_fp_vecs_per_prune: self.max_fp_vecs_per_prune.unwrap_or(usize::MAX), + }; + Ok(builder) } } @@ -646,9 +854,9 @@ where &'a self, provider: &'a DefaultProvider, _context: &'a Ctx, - _query: &[T], + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap())) } } @@ -668,7 +876,7 @@ where Ctx: ExecutionContext, { type DistanceComputer<'a> = pq::distance::DistanceComputer>; - type PruneAccessor<'a> = QuantAccessor<'a, NoStore, D, Ctx>; + type PruneAccessor<'a> = QuantBuilder<'a>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = PassThrough; @@ -681,7 +889,11 @@ where provider: &'a DefaultProvider, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - Ok(QuantAccessor::new(provider)) + let builder = QuantBuilder { + provider: &provider.aux_vectors, + neighbors: provider.neighbors(), + }; + Ok(builder) } } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 888a44217..2f9c63424 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -16,10 +16,7 @@ 
use diskann::{ }, workingset, }, - provider::{ - BuildDistanceComputer, DelegateNeighbor, ExecutionContext, - HasElementRef, HasId, - }, + provider::{BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId}, utils::{IntoUsize, VectorRepr}, }; use diskann_quantization::{ @@ -365,21 +362,112 @@ where } } +///////////// +// Builder // +///////////// + +pub struct Builder<'a, const NBITS: usize> { + store: &'a SQStore, + neighbors: &'a SimpleNeighborProviderAsync, +} + +impl<'a, const NBITS: usize> Builder<'a, NBITS> { + fn new(store: &'a SQStore, neighbors: &'a SimpleNeighborProviderAsync) -> Self { + Self { store, neighbors } + } +} + +impl HasId for Builder<'_, NBITS> { + type Id = u32; +} + +impl HasElementRef for Builder<'_, NBITS> +where + Unsigned: Representation, +{ + type ElementRef<'a> = CVRef<'a, NBITS>; +} + +impl<'a, const NBITS: usize> DelegateNeighbor<'a> for Builder<'_, NBITS> { + type Delegate = &'a SimpleNeighborProviderAsync; + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.neighbors + } +} + +impl BuildDistanceComputer for Builder<'_, NBITS> +where + Unsigned: Representation, + DistanceComputer: for<'a, 'b> DistanceFunction, CVRef<'b, NBITS>, f32>, +{ + type DistanceComputerError = ANNError; + type DistanceComputer = DistanceComputer; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(self.store.distance_computer()?) + } +} + +impl workingset::Fill for Builder<'_, NBITS> +where + Unsigned: Representation, +{ + type Error = std::convert::Infallible; + type View<'a> + = &'a Self + where + Self: 'a; + + async fn fill<'a, Itr>( + &'a mut self, + _state: &'a mut PassThrough, + _itr: Itr, + ) -> Result, Self::Error> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + Ok(self) + } +} + +// Pass-through view — reads scalar-quantized vectors directly from the provider. 
+impl workingset::View for &Builder<'_, NBITS> +where + Unsigned: Representation, +{ + type ElementRef<'a> = CVRef<'a, NBITS>; + type Element<'a> + = CVRef<'a, NBITS> + where + Self: 'a; + + fn get(&self, id: u32) -> Option> { + self.store.get_vector(id.into_usize()).ok() + } +} + ////////////// // Accessor // ////////////// /// The accessor for SQ. -pub struct QuantAccessor<'a, const NBITS: usize, V, D, Ctx> { +pub struct QuantAccessor<'a, const NBITS: usize, V, D, Ctx> +where + Unsigned: Representation, +{ provider: &'a DefaultProvider, D, Ctx>, + computer: QueryComputer, id_buffer: AdjacencyList, - is_search: bool, } impl GetFullPrecision for QuantAccessor<'_, NBITS, FullPrecisionStore, D, Ctx> where T: VectorRepr, + Unsigned: Representation, { type Repr = T; fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync { @@ -387,13 +475,15 @@ where } } -impl HasId for QuantAccessor<'_, NBITS, V, D, Ctx> { +impl HasId for QuantAccessor<'_, NBITS, V, D, Ctx> +where + Unsigned: Representation, +{ type Id = u32; } -impl Explore for QuantAccessor<'_, NBITS, V, D, Ctx> +impl Explore for QuantAccessor<'_, NBITS, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, @@ -414,7 +504,7 @@ where async move { for i in self.provider.starting_points()? { let vector = self.provider.aux_vectors.get_vector(i.into_usize())?; - f(i, computer.evaluate_similarity(vector)); + f(i, self.computer.evaluate_similarity(vector)); } Ok(()) } @@ -457,9 +547,7 @@ where } let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; - - // Invoke the passed closure on the vector. 
- let distance = computer.evaluate_similarity(vector); + let distance = self.computer.evaluate_similarity(vector); on_neighbors(distance, *id); } } @@ -472,87 +560,27 @@ where impl<'a, const NBITS: usize, V, D, Ctx> QuantAccessor<'a, NBITS, V, D, Ctx> where + Unsigned: Representation, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { pub(crate) fn new( provider: &'a DefaultProvider, D, Ctx>, + query: &[f32], is_search: bool, ) -> Self { Self { provider, + computer: provider + .aux_vectors + .query_computer(query, is_search) + .unwrap(), id_buffer: AdjacencyList::with_capacity(32), - is_search, } } } -impl HasElementRef for QuantAccessor<'_, NBITS, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, - Unsigned: Representation, -{ - type ElementRef<'a> = CVRef<'a, NBITS>; -} - -impl<'a, const NBITS: usize, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, NBITS, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -// impl BuildQueryComputer<&[T]> -// for QuantAccessor<'_, NBITS, V, D, Ctx> -// where -// T: VectorRepr, -// V: AsyncFriendly, -// D: AsyncFriendly, -// Ctx: ExecutionContext, -// Unsigned: Representation, -// QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, -// { -// type QueryComputerError = ANNError; -// type QueryComputer = QueryComputer; -// -// fn build_query_computer( -// &self, -// from: &[T], -// ) -> Result { -// // Allow rescaling if this is search. -// Ok(self -// .provider -// .aux_vectors -// .query_computer(from, self.is_search)?) 
-// } -// } - -impl BuildDistanceComputer for QuantAccessor<'_, NBITS, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, - Unsigned: Representation, - DistanceComputer: for<'a, 'b> DistanceFunction, CVRef<'b, NBITS>, f32>, -{ - type DistanceComputerError = ANNError; - type DistanceComputer = DistanceComputer; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(self.provider.aux_vectors.distance_computer()?) - } -} - impl AsDeletionCheck for QuantAccessor<'_, NBITS, V, D, Ctx> where V: AsyncFriendly, @@ -582,7 +610,6 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - // type QueryComputer = QueryComputer; type SearchAccessor<'a> = QuantAccessor<'a, NBITS, FullPrecisionStore, D, Ctx>; type SearchAccessorError = ANNError; @@ -590,9 +617,9 @@ where &'a self, provider: &'a FullPrecisionProvider, D, Ctx>, _context: &'a Ctx, - _query: &[T], + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider, true)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap(), true)) } } @@ -628,9 +655,9 @@ where &'a self, provider: &'a DefaultProvider, D, Ctx>, _context: &'a Ctx, - _query: &[T], + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider, true)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap(), true)) } } @@ -656,7 +683,7 @@ where DistanceComputer: for<'a, 'b> DistanceFunction, CVRef<'b, NBITS>, f32>, { type DistanceComputer<'a> = DistanceComputer; - type PruneAccessor<'a> = QuantAccessor<'a, NBITS, V, D, Ctx>; + type PruneAccessor<'a> = Builder<'a, NBITS>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = PassThrough; @@ -669,57 +696,11 @@ where provider: &'a DefaultProvider, D, Ctx>, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - Ok(QuantAccessor::new(provider, false)) + Ok(Builder::new(&provider.aux_vectors, provider.neighbors())) } } // Pass-through fill — 
returns `&Self` which directly accesses the underlying provider. -impl workingset::Fill - for QuantAccessor<'_, NBITS, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, - Unsigned: Representation, -{ - type Error = std::convert::Infallible; - type View<'a> - = &'a Self - where - Self: 'a; - - async fn fill<'a, Itr>( - &'a mut self, - _state: &'a mut PassThrough, - _itr: Itr, - ) -> Result, Self::Error> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - Ok(self) - } -} - -// Pass-through view — reads scalar-quantized vectors directly from the provider. -impl workingset::View for &QuantAccessor<'_, NBITS, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, - Unsigned: Representation, -{ - type ElementRef<'a> = CVRef<'a, NBITS>; - type Element<'a> - = CVRef<'a, NBITS> - where - Self: 'a; - - fn get(&self, id: u32) -> Option> { - self.provider.aux_vectors.get_vector(id.into_usize()).ok() - } -} - impl InsertStrategy, D, Ctx>, &[T]> for Quantized where diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 565e6214f..02cb46e5a 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -17,10 +17,7 @@ use diskann::{ }, workingset, }, - provider::{ - BuildDistanceComputer, DelegateNeighbor, ExecutionContext, - HasElementRef, HasId, - }, + provider::{BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId}, utils::{IntoUsize, VectorRepr}, }; use diskann_quantization::{ @@ -284,15 +281,85 @@ where } } +///////////// +// Builder // +///////////// + +#[derive(Clone)] +pub struct Builder<'a> { + store: &'a SphericalStore, + neighbors: &'a SimpleNeighborProviderAsync, +} + +impl HasId for Builder<'_> { + type Id = u32; +} + +impl 
HasElementRef for Builder<'_> { + type ElementRef<'a> = spherical::iface::Opaque<'a>; +} + +impl<'a> DelegateNeighbor<'a> for Builder<'_> { + type Delegate = &'a SimpleNeighborProviderAsync; + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.neighbors + } +} + +impl BuildDistanceComputer for Builder<'_> { + type DistanceComputerError = AllocatorError; + type DistanceComputer = + UnwrapErr; + + fn build_distance_computer( + &self, + ) -> Result { + self.store.distance_computer().map(UnwrapErr::new) + } +} + +// Pass-through fill — returns `&Self` which directly accesses the underlying provider. +impl workingset::Fill for Builder<'_> { + type Error = std::convert::Infallible; + type View<'a> + = Self + where + Self: 'a; + + async fn fill<'a, Itr>( + &'a mut self, + _state: &'a mut PassThrough, + _itr: Itr, + ) -> Result, Self::Error> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + Ok(self.clone()) + } +} + +// Pass-through view — reads spherical vectors directly from the provider. 
+impl workingset::View for Builder<'_> { + type ElementRef<'a> = spherical::iface::Opaque<'a>; + type Element<'a> + = spherical::iface::Opaque<'a> + where + Self: 'a; + + fn get(&self, id: u32) -> Option> { + self.store.get_vector(id.into_usize()).ok() + } +} + ////////////// // Accessor // ////////////// pub struct QuantAccessor<'a, V, D, Ctx> { provider: &'a DefaultProvider, + pub(crate) computer: spherical::iface::QueryComputer, id_buffer: AdjacencyList, - layout: spherical::iface::QueryLayout, - is_search: bool, } impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx> @@ -303,14 +370,19 @@ where { pub(crate) fn new( provider: &'a DefaultProvider, + query: &[f32], layout: spherical::iface::QueryLayout, is_search: bool, ) -> Self { + let computer = provider + .aux_vectors + .query_computer(query, layout, is_search) + .unwrap(); + Self { provider, + computer, id_buffer: AdjacencyList::with_capacity(32), - layout, - is_search, } } } @@ -329,9 +401,8 @@ impl HasId for QuantAccessor<'_, V, D, Ctx> { type Id = u32; } -impl Explore for QuantAccessor<'_, V, D, Ctx> +impl Explore for QuantAccessor<'_, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, @@ -351,8 +422,8 @@ where for i in self.provider.starting_points()? { // SAFETY: We're accepting the consequences of potential unsynchronized, // concurrent mutation. 
- let distance = computer - .evaluate_similarity(self.provider.aux_vectors.get_vector(i.into_usize())?); + let distance = self.computer + .evaluate_similarity(self.provider.aux_vectors.get_vector(i.into_usize())?).unwrap(); f(i, distance); } @@ -396,7 +467,7 @@ where } let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; - let distance = computer.evaluate_similarity(vector); + let distance = self.computer.evaluate_similarity(vector).unwrap(); on_neighbors(distance, *id); } } @@ -407,80 +478,6 @@ where } } -impl HasElementRef for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type ElementRef<'a> = spherical::iface::Opaque<'a>; -} - -impl<'a, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -// impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> -// where -// T: VectorRepr, -// V: AsyncFriendly, -// D: AsyncFriendly, -// Ctx: ExecutionContext, -// { -// type QueryComputerError = Bridge; -// type QueryComputer = -// UnwrapErr; -// -// fn build_query_computer( -// &self, -// query: &[T], -// ) -> Result { -// self.provider -// .aux_vectors -// .query_computer(query, self.layout, self.is_search) -// .bridge_err() -// .map(UnwrapErr::new) -// } -// } - -#[derive(Debug, Error)] -#[error("unconstructible")] -pub enum Infallible {} - -impl From for ANNError { - fn from(_: Infallible) -> Self { - unreachable!("Infallible is an unconstructible enum") - } -} - -impl BuildDistanceComputer for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = AllocatorError; - type DistanceComputer = - UnwrapErr; - - fn build_distance_computer( - &self, - ) -> Result { - self.provider - 
.aux_vectors - .distance_computer() - .map(UnwrapErr::new) - } -} - impl AsDeletionCheck for QuantAccessor<'_, V, D, Ctx> where V: AsyncFriendly, @@ -544,9 +541,9 @@ where &'a self, provider: &'a FullPrecisionProvider, _context: &'a Ctx, - _query: &[T], + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider, self.layout, self.is_search)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap(), self.layout, self.is_search)) } } @@ -578,9 +575,9 @@ where &'a self, provider: &'a DefaultProvider, _context: &'a Ctx, - _query: &[T], + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider, self.layout, self.is_search)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap(), self.layout, self.is_search)) } } @@ -602,7 +599,7 @@ where { type DistanceComputer<'a> = UnwrapErr; - type PruneAccessor<'a> = QuantAccessor<'a, V, D, Ctx>; + type PruneAccessor<'a> = Builder<'a>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = PassThrough; @@ -615,52 +612,12 @@ where provider: &'a DefaultProvider, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - let build = Self::build(); - Ok(QuantAccessor::new(provider, build.layout, build.is_search)) - } -} - -// Pass-through fill — returns `&Self` which directly accesses the underlying provider. -impl workingset::Fill for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Error = std::convert::Infallible; - type View<'a> - = &'a Self - where - Self: 'a; - - async fn fill<'a, Itr>( - &'a mut self, - _state: &'a mut PassThrough, - _itr: Itr, - ) -> Result, Self::Error> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - Ok(self) - } -} - -// Pass-through view — reads spherical vectors directly from the provider. 
-impl workingset::View for &QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type ElementRef<'a> = spherical::iface::Opaque<'a>; - type Element<'a> - = spherical::iface::Opaque<'a> - where - Self: 'a; + let builder = Builder { + store: &provider.aux_vectors, + neighbors: provider.neighbors(), + }; - fn get(&self, id: u32) -> Option> { - self.provider.aux_vectors.get_vector(id.into_usize()).ok() + Ok(builder) } } diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index ff62f1ea0..94dce287e 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -8,7 +8,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, - provider::{HasId}, + provider::HasId, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 49b75c527..f9c113aef 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -178,9 +178,9 @@ where } } -impl Explore for BetaAccessor +impl Explore for BetaAccessor where - Inner: Explore, + Inner: Explore, { fn starting_points(&self) -> impl Future>> + Send { self.inner.starting_points() @@ -188,23 +188,20 @@ where fn start_point_distances( &mut self, - computer: &Self::QueryComputer, mut f: F, ) -> impl std::future::Future> + Send where F: FnMut(Self::Id, f32) + Send, { let filter = &self.filter; - self.inner - .start_point_distances(computer, move |id, distance| { - f(id, filter.apply(id, distance)); - }) + self.inner.start_point_distances(move |id, distance| { + f(id, filter.apply(id, distance)); + }) } fn 
expand_beam( &mut self, ids: Itr, - computer: &Self::QueryComputer, pred: P, mut on_neighbors: F, ) -> impl std::future::Future> + Send @@ -214,10 +211,9 @@ where F: FnMut(f32, Self::Id) + Send, { let filter = &self.filter; - self.inner - .expand_beam(ids, computer, pred, move |distance, id| { - on_neighbors(filter.apply(id, distance), id) - }) + self.inner.expand_beam(ids, pred, move |distance, id| { + on_neighbors(filter.apply(id, distance), id) + }) } } @@ -312,13 +308,9 @@ mod tests { "the underlying start points should match", ); - // Build a query computer for the point 0, 0. - let computer = accessor.build_query_computer(&[0.0, 0.0]).unwrap(); - accessor .expand_beam( [0, 5, 10, 15].into_iter(), - &computer, NotIn(&mut visited), |distance, id| buf.push((distance, id)), ) @@ -344,7 +336,7 @@ mod tests { buf.clear(); accessor - .start_point_distances(&computer, |id, distance| buf.push((distance, id))) + .start_point_distances(|id, distance| buf.push((distance, id))) .await .unwrap(); diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index a3432d23f..fa446958b 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -999,52 +999,26 @@ impl provider::NeighborAccessorMut for NeighborAccessor<'_> { } } -////////////// -// Accessor // -////////////// +///////////// +// Builder // +///////////// #[derive(Debug)] -pub struct Accessor<'a> { +pub struct Builder<'a> { provider: &'a Provider, - distance: ::QueryDistance, get_vector: LocalCounter<'a>, - /// IDs that will produce transient errors when accessed. transient_ids: Option>>, } -impl<'a> Accessor<'a> { +impl<'a> Builder<'a> { /// Return the underlying [`Provider`] reference. pub fn provider(&self) -> &'a Provider { self.provider } - /// Creates an accessor with no flaky behavior (backward-compatible). 
- pub fn new(provider: &'a Provider, query: &[f32]) -> Self { - Self::new_inner(provider, query, None) - } - - /// Creates an accessor where `get_element` returns a transient error for - /// any ID in `transient_ids`. The ID must still exist in the provider — - /// accessing a truly missing ID remains a critical `InvalidId` error. - pub fn flaky( - provider: &'a Provider, - query: &[f32], - transient_ids: Cow<'a, HashSet>, - ) -> Self { - Self::new_inner(provider, query, Some(transient_ids)) - } - - fn new_inner( - provider: &'a Provider, - query: &[f32], - transient_ids: Option>>, - ) -> Self { - // FIXME: Return error. - assert_eq!(query.len(), provider.dim()); - let distance = f32::query_distance(query, provider.distance_metric()); + fn new(provider: &'a Provider, transient_ids: Option>>) -> Self { Self { provider, - distance, get_vector: provider.get_vector.local(), transient_ids, } @@ -1060,7 +1034,7 @@ impl<'a> Accessor<'a> { } } - pub fn copy(&mut self, id: u32) -> Result, AccessError> { + fn get(&mut self, id: u32) -> Result, AccessError> { match self.provider.terms.get(&id) { Some(term) => { self.check_flaky(id)?; @@ -1072,35 +1046,27 @@ impl<'a> Accessor<'a> { } } - pub fn get_distance(&mut self, id: u32) -> Result { - match self.provider.terms.get(&id) { - Some(term) => { - self.check_flaky(id)?; - - self.get_vector.increment(); - Ok(self.distance.evaluate_similarity(&*term.data)) - } - None => Err(AccessError::InvalidId(AccessedInvalidId(id))), - } + fn record_get(&mut self) { + self.get_vector.increment() } } -impl provider::HasId for Accessor<'_> { +impl provider::HasId for Builder<'_> { type Id = u32; } -impl provider::HasElementRef for Accessor<'_> { +impl provider::HasElementRef for Builder<'_> { type ElementRef<'a> = &'a [f32]; } -impl<'a> provider::DelegateNeighbor<'a> for Accessor<'_> { +impl<'a> provider::DelegateNeighbor<'a> for Builder<'_> { type Delegate = NeighborAccessor<'a>; fn delegate_neighbor(&'a mut self) -> Self::Delegate { 
NeighborAccessor::new(self.provider) } } -impl provider::BuildDistanceComputer for Accessor<'_> { +impl provider::BuildDistanceComputer for Builder<'_> { type DistanceComputerError = Infallible; type DistanceComputer = ::Distance; @@ -1114,13 +1080,127 @@ impl provider::BuildDistanceComputer for Accessor<'_> { } } +//------// +// glue // +//------// + +type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; +type View<'a> = workingset::map::View<'a, u32, Box<[f32]>, workingset::map::Ref<[f32]>>; + +impl workingset::Fill for Builder<'_> { + type Error = ANNError; + type View<'a> + = View<'a> + where + Self: 'a, + WorkingSet: 'a; + + fn fill<'a, Itr>( + &'a mut self, + set: &'a mut WorkingSet, + itr: Itr, + ) -> impl SendFuture, Self::Error>> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + async { + use workingset::map::Entry; + + set.prepare(itr.clone()); + for i in itr { + match set.entry(i) { + Entry::Seeded(_) | Entry::Occupied(_) => { /* nothing to do */ } + Entry::Vacant(vacant) => { + if let Some(buf) = + self.get(i).allow_transient("transient failures allowed")? + { + vacant.insert(buf); + } + } + } + } + Ok(set.view()) + } + } +} + +////////////// +// Accessor // +////////////// + +#[derive(Debug)] +pub struct Accessor<'a> { + builder: Builder<'a>, + distance: ::QueryDistance, +} + +impl<'a> Accessor<'a> { + /// Return the underlying [`Provider`] reference. + pub fn provider(&self) -> &'a Provider { + self.builder.provider() + } + + /// Creates an accessor with no flaky behavior (backward-compatible). + pub fn new(provider: &'a Provider, query: &[f32]) -> Self { + Self::new_inner(provider, query, None) + } + + /// Creates an accessor where `get_element` returns a transient error for + /// any ID in `transient_ids`. The ID must still exist in the provider — + /// accessing a truly missing ID remains a critical `InvalidId` error. 
+ pub fn flaky( + provider: &'a Provider, + query: &[f32], + transient_ids: Cow<'a, HashSet>, + ) -> Self { + Self::new_inner(provider, query, Some(transient_ids)) + } + + fn new_inner( + provider: &'a Provider, + query: &[f32], + transient_ids: Option>>, + ) -> Self { + // FIXME: Return error. + assert_eq!(query.len(), provider.dim()); + let distance = f32::query_distance(query, provider.distance_metric()); + let builder = Builder::new(provider, transient_ids); + + Self { builder, distance } + } + + pub fn get_distance(&mut self, id: u32) -> Result { + match self.provider().terms.get(&id) { + Some(term) => { + self.builder.check_flaky(id)?; + self.builder.record_get(); + + Ok(self.distance.evaluate_similarity(&*term.data)) + } + None => Err(AccessError::InvalidId(AccessedInvalidId(id))), + } + } +} + +impl provider::HasId for Accessor<'_> { + type Id = u32; +} + //------// // Glue // //------// impl glue::Explore for Accessor<'_> { fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(self.provider.config.start_points.keys().copied().collect()) + futures_util::future::ok( + self.provider() + .config + .start_points + .keys() + .copied() + .collect(), + ) } fn start_point_distances( @@ -1131,7 +1211,7 @@ impl glue::Explore for Accessor<'_> { F: FnMut(Self::Id, f32) + Send, { async move { - for &i in self.provider.config.start_points.keys() { + for &i in self.provider().config.start_points.keys() { f(i, self.get_distance(i).escalate("start points must exist")?) 
} Ok(()) @@ -1152,7 +1232,7 @@ impl glue::Explore for Accessor<'_> { async move { let mut neighbors = AdjacencyList::new(); for id in ids { - self.provider.get_neighbors(id, &mut neighbors)?; + self.provider().get_neighbors(id, &mut neighbors)?; for &n in neighbors.iter().filter(|i| pred.eval_mut(i)) { if let Some(distance) = self .get_distance(n) @@ -1167,46 +1247,9 @@ impl glue::Explore for Accessor<'_> { } } -type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; -type View<'a> = workingset::map::View<'a, u32, Box<[f32]>, workingset::map::Ref<[f32]>>; - -impl workingset::Fill for Accessor<'_> { - type Error = ANNError; - type View<'a> - = View<'a> - where - Self: 'a, - WorkingSet: 'a; - - fn fill<'a, Itr>( - &'a mut self, - set: &'a mut WorkingSet, - itr: Itr, - ) -> impl SendFuture, Self::Error>> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - async { - use workingset::map::Entry; - - set.prepare(itr.clone()); - for i in itr { - match set.entry(i) { - Entry::Seeded(_) | Entry::Occupied(_) => { /* nothing to do */ } - Entry::Vacant(vacant) => { - if let Some(buf) = - self.copy(i).allow_transient("transient failures allowed")? - { - vacant.insert(buf); - } - } - } - } - Ok(set.view()) - } - } -} +////////////// +// Strategy // +////////////// #[derive(Debug, Clone)] pub struct Strategy { @@ -1269,7 +1312,7 @@ impl glue::DefaultPostProcessor for Strategy { impl glue::PruneStrategy for Strategy { type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; type DistanceComputer<'a> = ::Distance; - type PruneAccessor<'a> = Accessor<'a>; + type PruneAccessor<'a> = Builder<'a>; type PruneAccessorError = Infallible; fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { @@ -1287,10 +1330,9 @@ impl glue::PruneStrategy for Strategy { provider: &'a Provider, _context: &'a Context, ) -> Result, Self::PruneAccessorError> { - // FIXME: Different Accessors? 
match &self.transient_ids { - Some(ids) => Ok(Accessor::flaky(provider, &[], Cow::Borrowed(ids))), - None => Ok(Accessor::new(provider, &[])), + Some(ids) => Ok(Builder::new(provider, Some(Cow::Borrowed(ids)))), + None => Ok(Builder::new(provider, None)), } } } @@ -1372,7 +1414,7 @@ impl<'a, 'b, O> glue::SearchPostProcessStep, &'b [f32], O> for Filt B: SearchOutputBuffer + Send + ?Sized, Next: glue::SearchPostProcess + Sync, { - let provider = accessor.provider; + let provider = accessor.provider(); next.post_process( accessor, query,