diff --git a/Cargo.lock b/Cargo.lock index 43e9027bf..ae3b4f94a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1236,9 +1236,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" dependencies = [ "futures-channel", "futures-core", @@ -1251,9 +1251,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -1261,15 +1261,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -1278,15 +1278,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", @@ -1295,15 +1295,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" @@ -1313,9 +1313,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-channel", "futures-core", @@ -1325,7 +1325,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index bebaf4b8e..59c8591fe 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -28,7 +28,6 @@ thiserror.workspace = true tokio = { workspace = true, features = ["rt-multi-thread"] } diskann-vector.workspace = true diskann-wide.workspace = true -diskann-label-filter.workspace = true diskann-tools = { workspace = true } diskann-disk = { workspace = true, optional = true } cfg-if.workspace = true @@ -38,6 +37,7 @@ opentelemetry_sdk = { workspace = true, optional = true } scopeguard = { version = "1.2", optional = true } diskann-benchmark-core = { workspace = true, features = ["bigann"] } itertools.workspace = true +diskann-label-filter.workspace = true [lints] clippy.undocumented_unsafe_blocks = "warn" diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 1dd518781..33407b1c7 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -5,7 +5,6 @@ use std::{ collections::HashMap, - future::Future, num::NonZeroUsize, sync::{ atomic::{AtomicU64, AtomicUsize}, @@ -19,17 +18,12 @@ use diskann::{ error::IntoANNResult, graph::{ self, - glue::{ - self, DefaultPostProcessor, ExpandBeam, SearchExt, SearchPostProcess, SearchStrategy, - }, + glue::{self, DefaultPostProcessor, Explore, SearchPostProcess, SearchStrategy}, search::Knn, - search_output_buffer, AdjacencyList, DiskANNIndex, + search_output_buffer, DiskANNIndex, }, neighbor::{Neighbor, NeighborPriorityQueue}, - provider::{ - Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId, - NeighborAccessor, NoopGuard, - }, + provider::{DataProvider, DefaultContext, HasId, NoopGuard}, utils::{IntoUsize, VectorRepr}, ANNError, ANNResult, }; @@ -42,7 +36,6 @@ use diskann_utils::object_pool::{ObjectPool, PoolOption, TryAsPooled}; use crate::search::pq::{quantizer_preprocess, PQData, PQScratch}; use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction}; -use futures_util::future; use tokio::runtime::Runtime; use tracing::debug; @@ -296,7 +289,6 @@ where &self, accessor: &mut DiskAccessor<'_, Data, VP>, query: &[Data::VectorDataType], - _computer: &DiskQueryComputer, candidates: I, output: &mut B, ) -> Result @@ -345,7 +337,6 @@ where Data: GraphDataType, ProviderFactory: VertexProviderFactory, { - type QueryComputer = DiskQueryComputer; type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>; type SearchAccessorError = ANNError; @@ -353,6 +344,7 @@ where &'a self, provider: &'a DiskProvider, _context: &DefaultContext, + query: &[Data::VectorDataType], ) -> Result, Self::SearchAccessorError> { DiskAccessor::new( provider, @@ -384,107 +376,6 @@ where } } -/// The query computer for the disk provider. This is used to compute the distance between the query vector and the PQ coordinates. -pub struct DiskQueryComputer { - num_pq_chunks: usize, - query_centroid_l2_distance: Vec, -} - -impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer { - fn evaluate_similarity(&self, changing: &[u8]) -> f32 { - let mut dist = 0.0f32; - #[allow(clippy::expect_used)] - compute_pq_distance_for_pq_coordinates( - changing, - self.num_pq_chunks, - &self.query_centroid_l2_distance, - std::slice::from_mut(&mut dist), - ) - .expect("PQ distance compute for PQ coordinates is expected to succeed"); - dist - } -} - -impl BuildQueryComputer<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - type QueryComputerError = ANNError; - type QueryComputer = DiskQueryComputer; - - fn build_query_computer( - &self, - _from: &[Data::VectorDataType], - ) -> Result { - Ok(DiskQueryComputer { - num_pq_chunks: self.provider.pq_data.get_num_chunks(), - query_centroid_l2_distance: self - .scratch - .pq_scratch - .aligned_pqtable_dist_scratch - .to_vec(), - }) - } - - async fn distances_unordered( - &mut self, - vec_id_itr: Itr, - _computer: &Self::QueryComputer, - f: F, - ) -> Result<(), Self::GetError> - where - F: Send + FnMut(f32, Self::Id), - Itr: Iterator, - { - self.pq_distances(&vec_id_itr.collect::>(), f) - } -} - -impl ExpandBeam<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - fn expand_beam( - &mut self, - ids: Itr, - _computer: &Self::QueryComputer, - mut pred: P, - mut f: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: glue::HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - let result = (|| { - let io_limit = self.provider.search_io_limit - self.io_tracker.io_count(); - let load_ids: Box<[_]> = ids.take(io_limit).collect(); - - self.ensure_loaded(&load_ids)?; - let mut ids = Vec::new(); - for i in load_ids { - ids.clear(); - ids.extend( - self.scratch - .vertex_provider - .get_adjacency_list(&i)? - .iter() - .copied() - .filter(|id| pred.eval_mut(id)), - ); - - self.pq_distances(&ids, &mut f)?; - } - - Ok(()) - })(); - - std::future::ready(result) - } -} - // Scratch space for disk search operations that need allocations. // These allocations are amortized across searches using the scratch pool. struct DiskSearchScratch @@ -586,7 +477,7 @@ where } } -impl SearchExt for DiskAccessor<'_, Data, VP> +impl Explore for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, @@ -596,9 +487,64 @@ where Ok(vec![start_vertex_id]) } + async fn start_point_distances( + &mut self, + mut f: F, + ) -> ANNResult<()> + where + F: FnMut(Self::Id, f32) + Send, + { + let start_vertex_id = self.provider.graph_header.metadata().medoid as u32; + self.pq_distances(&[start_vertex_id], |distance, id| f(id, distance))?; + // let vector = self + // .provider + // .pq_data + // .get_compressed_vector(start_vertex_id.into_usize())?; + // let distance = computer.evaluate_similarity(vector); + // f(start_vertex_id, distance); + Ok(()) + } + fn terminate_early(&mut self) -> bool { self.io_tracker.io_count() > self.provider.search_io_limit } + + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut f: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let result = (|| { + let io_limit = self.provider.search_io_limit - self.io_tracker.io_count(); + let load_ids: Box<[_]> = ids.take(io_limit).collect(); + + self.ensure_loaded(&load_ids)?; + let mut ids = Vec::new(); + for i in load_ids { + ids.clear(); + ids.extend( + self.scratch + .vertex_provider + .get_adjacency_list(&i)? + .iter() + .copied() + .filter(|id| pred.eval_mut(id)), + ); + + self.pq_distances(&ids, &mut f)?; + } + + Ok(()) + })(); + + std::future::ready(result) + } } impl<'a, Data, VP> DiskAccessor<'a, Data, VP> @@ -687,87 +633,6 @@ where type Id = u32; } -impl Accessor for DiskAccessor<'_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = &'a [u8] - where - Self: 'a; - - /// `ElementRef` can have arbitrary lifetimes. - type ElementRef<'a> = &'a [u8]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = ANNError; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize)) - } -} - -impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - AsNeighborAccessor(self) - } -} - -/// A light-weight wrapper around `&mut DiskAccessor` used to tailor the semantics of -/// [`NeighborAccessor`]. -/// -/// This implementation ensures that the vector data for adjacency lists is also retrieved -/// and cached to enhance reranking. -pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>) -where - Data: GraphDataType, - VP: VertexProvider; - -impl HasId for AsNeighborAccessor<'_, '_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - type Id = u32; -} - -impl NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - fn get_neighbors( - self, - id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - if self.0.io_tracker.io_count() > self.0.provider.search_io_limit { - return future::ok(self); // Returning empty results in `neighbors` out param if IO limit is reached. - } - - if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) { - return future::err(e); - } - let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) { - Ok(list) => list, - Err(e) => return future::err(e), - }; - neighbors.overwrite_trusted(list); - future::ok(self) - } -} - /// [`DiskIndexSearcher`] is a helper class to make it easy to construct index /// and do repeated search operations. It is a wrapper around the index. /// This is useful for drivers such as search_disk_index.exe in tools. @@ -922,35 +787,36 @@ where { let provider = self.index.provider(); let mut accessor = strategy - .search_accessor(provider, &DefaultContext) + .search_accessor(provider, &DefaultContext, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - - let mut best = NeighborPriorityQueue::new(neighbors_before_reranking); - let mut cmps = 0u32; - - let num_points = provider.num_points as u32; - for id in 0..num_points { - if vector_filter(&id) { - let element = accessor.get_element(id).await.into_ann_result()?; - let dist = computer.evaluate_similarity(element); - best.insert(Neighbor::new(id, dist)); - cmps += 1; - } - } - - let result_count = strategy - .default_post_processor() - .post_process(&mut accessor, query, &computer, best.iter(), output) - .await - .into_ann_result()?; - - Ok(graph::index::SearchStats { - cmps, - hops: 0, - result_count: result_count as u32, - range_search_second_round: false, - }) + unimplemented!(); + // let computer = accessor.build_query_computer(query).into_ann_result()?; + + // let mut best = NeighborPriorityQueue::new(neighbors_before_reranking); + // let mut cmps = 0u32; + + // let num_points = provider.num_points as u32; + // for id in 0..num_points { + // if vector_filter(&id) { + // let element = accessor.get_element(id).await.into_ann_result()?; + // let dist = computer.evaluate_similarity(element); + // best.insert(Neighbor::new(id, dist)); + // cmps += 1; + // } + // } + + // let result_count = strategy + // .default_post_processor() + // .post_process(&mut accessor, query, &computer, best.iter(), output) + // .await + // .into_ann_result()?; + + // Ok(graph::index::SearchStats { + // cmps, + // hops: 0, + // result_count: result_count as u32, + // range_search_second_round: false, + // }) } /// Perform a search on the disk index. diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 0be5be9a4..695650f43 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -10,8 +10,8 @@ use diskann::{ AdjacencyList, SearchOutputBuffer, config::defaults::MAX_OCCLUSION_SIZE, glue::{ - self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + self, DefaultPostProcessor, ExpandBeam, Explore, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchPostProcess, SearchStrategy, }, workingset::{self, map::Entry}, }, @@ -448,7 +448,7 @@ impl HasId for FullAccessor<'_, T> { type Id = u32; } -impl SearchExt for FullAccessor<'_, T> { +impl Explore for FullAccessor<'_, T> { fn starting_points(&self) -> impl Future>> + Send { let points = if self.provider.start_points_exist() { vec![0] diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 0d16248dd..64c81620b 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -7,7 +7,7 @@ use std::sync::{Arc, RwLock}; use diskann::{ error::{ErrorExt, IntoANNResult}, - graph::glue::{ExpandBeam, SearchExt}, + graph::glue::Explore, provider::{Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, HasId}, ANNError, ANNErrorKind, }; @@ -139,9 +139,9 @@ where } } -impl SearchExt for EncodedDocumentAccessor +impl Explore for EncodedDocumentAccessor where - IA: SearchExt, + IA: Explore, { fn starting_points( &self, diff --git a/diskann-providers/Cargo.toml b/diskann-providers/Cargo.toml index c5447dca2..ef8d79182 100644 --- a/diskann-providers/Cargo.toml +++ b/diskann-providers/Cargo.toml @@ -36,7 +36,7 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tempfile = { workspace = true, optional = true } bf-tree = { workspace = true, optional = true } prost = "0.14.1" -futures-util.workspace = true +futures-util = { workspace = true, features = ["async-await"] } serde_json = { workspace = true, optional = true } vfs = { workspace = true, optional = true } diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index f4adf5940..892372ed6 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -174,8 +174,8 @@ pub(crate) mod tests { }, neighbor::Neighbor, provider::{ - AsNeighbor, AsNeighborMut, BuildQueryComputer, DataProvider, DefaultContext, Delete, - ExecutionContext, Guard, NeighborAccessor, NeighborAccessorMut, SetElement, + AsNeighbor, AsNeighborMut, DataProvider, DefaultContext, Delete, ExecutionContext, + Guard, NeighborAccessor, NeighborAccessorMut, SetElement, }, utils::{IntoUsize, ONE}, }; @@ -433,25 +433,16 @@ pub(crate) mod tests { Q: Copy + std::fmt::Debug + Send + Sync, { assert!(max_candidates <= groundtruth.len()); - let mut state = index - .start_paged_search(strategy, ¶meters.context, query, parameters.search_l) + let mut search = index + .paged_search(&strategy, ¶meters.context, query, parameters.search_l) .await .unwrap(); - let mut buffer = vec![Neighbor::::default(); parameters.search_k]; let mut iter = 0; let mut seen = 0; while !groundtruth.is_empty() { - let count = index - .next_search_results::( - ¶meters.context, - &mut state, - parameters.search_k, - &mut buffer, - ) - .await - .unwrap(); - for (i, b) in buffer.iter().enumerate().take(count) { + let page = search.next_page(parameters.search_k).await.unwrap(); + for (i, b) in page.iter().enumerate() { let m = is_match(groundtruth, *b, 0.01); match m { None => { @@ -466,7 +457,7 @@ pub(crate) mod tests { b, iter, i, - &buffer[i..], + &page[i..], ); } Some(j) => groundtruth.remove(j), @@ -2324,10 +2315,11 @@ pub(crate) mod tests { // Ensure that the query computer used for insertion uses the `SameAsData` layout. let strategy = inmem::spherical::Quantized::build(); - let accessor = strategy.search_accessor(index.provider(), ctx).unwrap(); - let computer = accessor.build_query_computer(data.row(0)).unwrap(); + let accessor = strategy + .search_accessor(index.provider(), ctx, data.row(0)) + .unwrap(); assert_eq!( - computer.layout(), + accessor.computer.layout(), diskann_quantization::spherical::iface::QueryLayout::SameAsData ); } @@ -2432,11 +2424,10 @@ pub(crate) mod tests { let accessor = , &[f32], - >>::search_accessor(&strategy, index.provider(), ctx) + >>::search_accessor(&strategy, index.provider(), ctx, data.row(0)) .unwrap(); - let computer = accessor.build_query_computer(data.row(0)).unwrap(); assert_eq!( - computer.layout(), + accessor.computer.layout(), diskann_quantization::spherical::iface::QueryLayout::SameAsData ); } @@ -3172,78 +3163,78 @@ pub(crate) mod tests { // Flaky Provider Handling // ///////////////////////////// - // This test uses a "Flaky" accessor that spuriously fails with non-critical errors to - // check that such errors are not propagated by DiskANN. - #[tokio::test] - async fn test_flaky_build() { - let parameters = InitParams { - l_build: 64, - max_degree: 16, - metric: Metric::L2, - batchsize: NonZeroUsize::new(1).unwrap(), - }; - - let start_point = StartPointStrategy::RandomSamples { - nsamples: ONE, - seed: 0xb4de0a1298a86eea, - }; - - // This is the two level index. - let (index, data) = init_from_file( - inmem::test::Flaky::new(9), - parameters, - SIFTSMALL, - 8, - start_point, - ) - .await; - - // There should be one more reachable node than points in the dataset to account for - // the start point. - let neighbor_accessor = &mut index.provider().neighbors(); - assert_eq!( - index - .count_reachable_nodes( - &index.provider().starting_points().unwrap(), - neighbor_accessor - ) - .await - .unwrap(), - data.nrows() + 1, - ); - - let top_k = 10; - let search_l = 32; - let mut ids = vec![0; top_k]; - let mut distances = vec![0.0; top_k]; - - // Here, we use elements of the dataset to search the dataset itself. - // - // We do this for each query, computing the expected ground truth and verifying - // that our simple graph search matches. - // - // Because this dataset is small, we can expect exact equality. - let ctx = &DefaultContext; - for (q, query) in data.row_iter().enumerate() { - let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); - let mut result_output_buffer = - search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); - // Full Precision Search. - index - .search( - graph_search, - &FullPrecision, - ctx, - query, - &mut result_output_buffer, - ) - .await - .unwrap(); - - assert_top_k_exactly_match(q, >, &ids, &distances, top_k); - } - } + // // This test uses a "Flaky" accessor that spuriously fails with non-critical errors to + // // check that such errors are not propagated by DiskANN. + // #[tokio::test] + // async fn test_flaky_build() { + // let parameters = InitParams { + // l_build: 64, + // max_degree: 16, + // metric: Metric::L2, + // batchsize: NonZeroUsize::new(1).unwrap(), + // }; + + // let start_point = StartPointStrategy::RandomSamples { + // nsamples: ONE, + // seed: 0xb4de0a1298a86eea, + // }; + + // // This is the two level index. + // let (index, data) = init_from_file( + // inmem::test::Flaky::new(9), + // parameters, + // SIFTSMALL, + // 8, + // start_point, + // ) + // .await; + + // // There should be one more reachable node than points in the dataset to account for + // // the start point. + // let neighbor_accessor = &mut index.provider().neighbors(); + // assert_eq!( + // index + // .count_reachable_nodes( + // &index.provider().starting_points().unwrap(), + // neighbor_accessor + // ) + // .await + // .unwrap(), + // data.nrows() + 1, + // ); + + // let top_k = 10; + // let search_l = 32; + // let mut ids = vec![0; top_k]; + // let mut distances = vec![0.0; top_k]; + + // // Here, we use elements of the dataset to search the dataset itself. + // // + // // We do this for each query, computing the expected ground truth and verifying + // // that our simple graph search matches. + // // + // // Because this dataset is small, we can expect exact equality. + // let ctx = &DefaultContext; + // for (q, query) in data.row_iter().enumerate() { + // let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); + // let mut result_output_buffer = + // search_output_buffer::IdDistance::new(&mut ids, &mut distances); + // let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + // // Full Precision Search. + // index + // .search( + // graph_search, + // &FullPrecision, + // ctx, + // query, + // &mut result_output_buffer, + // ) + // .await + // .unwrap(); + + // assert_top_k_exactly_match(q, >, &ids, &distances, top_k); + // } + // } async fn create_retry_saturated_index( retry: NonZeroU32, @@ -3280,55 +3271,55 @@ pub(crate) mod tests { Ok(index) } - #[tokio::test] - async fn test_saturate_index() { - let index_sat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), true) - .await - .unwrap(); - let mut accessor_sat = inmem::FullAccessor::new(index_sat.provider()); - let res_sat = index_sat - .get_degree_stats(&mut accessor_sat, index_sat.provider().iter()) - .await - .unwrap(); - - let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false) - .await - .unwrap(); - let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider()); - let res_unsat = index_unsat - .get_degree_stats(&mut accessor_unsat, index_unsat.provider().iter()) - .await - .unwrap(); - assert!( - res_sat.avg_degree > res_unsat.avg_degree, - "Saturated index should have higher average degree than the unsaturated index" - ); - } - - #[tokio::test] - async fn test_retry_index() { - let index_sat = create_retry_saturated_index(NonZeroU32::new(3).unwrap(), false) - .await - .unwrap(); - let mut accessor_sat = inmem::FullAccessor::new(index_sat.provider()); - let res_sat = index_sat - .get_degree_stats(&mut accessor_sat, index_sat.provider().iter()) - .await - .unwrap(); - - let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false) - .await - .unwrap(); - let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider()); - let res_unsat = index_sat - .get_degree_stats(&mut accessor_unsat, index_unsat.provider().iter()) - .await - .unwrap(); - assert!( - res_sat.avg_degree > res_unsat.avg_degree, - "Saturated index should have higher average degree than the unsaturated index" - ); - } + // #[tokio::test] + // async fn test_saturate_index() { + // let index_sat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), true) + // .await + // .unwrap(); + // let mut accessor_sat = inmem::Builder::new(index_sat.provider(), query); + // let res_sat = index_sat + // .get_degree_stats(&mut accessor_sat, index_sat.provider().iter()) + // .await + // .unwrap(); + + // let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false) + // .await + // .unwrap(); + // let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider()); + // let res_unsat = index_unsat + // .get_degree_stats(&mut accessor_unsat, index_unsat.provider().iter()) + // .await + // .unwrap(); + // assert!( + // res_sat.avg_degree > res_unsat.avg_degree, + // "Saturated index should have higher average degree than the unsaturated index" + // ); + // } + + // #[tokio::test] + // async fn test_retry_index() { + // let index_sat = create_retry_saturated_index(NonZeroU32::new(3).unwrap(), false) + // .await + // .unwrap(); + // let mut accessor_sat = inmem::FullAccessor::new(index_sat.provider()); + // let res_sat = index_sat + // .get_degree_stats(&mut accessor_sat, index_sat.provider().iter()) + // .await + // .unwrap(); + + // let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false) + // .await + // .unwrap(); + // let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider()); + // let res_unsat = index_sat + // .get_degree_stats(&mut accessor_unsat, index_unsat.provider().iter()) + // .await + // .unwrap(); + // assert!( + // res_sat.avg_degree > res_unsat.avg_degree, + // "Saturated index should have higher average degree than the unsaturated index" + // ); + // } #[cfg(feature = "experimental_diversity_search")] #[tokio::test] diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index e8d26a609..fa0814ead 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -13,13 +13,14 @@ use diskann::{ Batch, DefaultSearchStrategy, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, PruneStrategy, SearchStrategy, }, - index::{DegreeStats, PagedSearchState, PartitionedNeighbors, SearchState}, + index::{DegreeStats, PartitionedNeighbors}, search_output_buffer, }, neighbor::Neighbor, provider::{AsNeighbor, AsNeighborMut, DataProvider, Delete, SetElement}, utils::ONE, }; +use diskann_utils::Reborrow; use crate::storage::{LoadWith, StorageReadProvider}; @@ -339,59 +340,73 @@ where ) } - #[allow(clippy::type_complexity)] - pub fn start_paged_search( - &self, - strategy: S, - context: &DP::Context, - query: T, - l_value: usize, - ) -> ANNResult> - where - S: SearchStrategy + 'static, - T: Copy + Send, - { - self.handle.block_on( - self.inner - .start_paged_search(strategy, context, query, l_value), - ) - } - - #[allow(clippy::type_complexity)] - pub fn start_paged_search_with_init_ids( + // /// Begin a paged search over the index (synchronous wrapper). + // /// + // /// Returns a [`PagedSearch`] handle. See + // /// [`PagedSearch::next_page`] for retrieving results. + // pub fn paged_search<'a, S, T>( + // &'a self, + // strategy: S, + // context: &'a DP::Context, + // query: T, + // l_value: usize, + // ) -> ANNResult> + // where + // S: SearchStrategy + 'static, + // T: Copy + Send + 'a, + // { + // let inner = self + // .handle + // .block_on(self.inner.paged_search(strategy, context, query, l_value))?; + // Ok(PagedSearch { + // handle: self.handle.clone(), + // inner, + // }) + // } + + // /// Begin a paged search with explicit initial seed IDs (synchronous wrapper). + // pub fn paged_search_with_init_ids<'a, S, T>( + // &'a self, + // strategy: S, + // context: &'a DP::Context, + // query: T, + // l_value: usize, + // init_ids: Option<&'a [DP::InternalId]>, + // ) -> ANNResult> + // where + // S: SearchStrategy + 'static, + // T: Copy + Send + 'a, + // { + // let inner = self.handle.block_on( + // self.inner + // .paged_search_with_init_ids(strategy, context, query, l_value, init_ids), + // )?; + // Ok(PagedSearch { + // handle: self.handle.clone(), + // inner, + // }) + // } + + /// Begin a synchronous paged search over the index. + /// + /// This will construct a [`noawait::PagedSearch`] and initialize search with the + /// providers start points. Pages can be retrieved with [`noawait::PagedSearch::next`]. + /// + /// **Caution**: This method should only be used if is known that all functions reachable + /// via the implementation of [`SearchStrategy`] are known to be synchronous and never + /// truly await. This allows [`noawait::PagedSearch`] to be much more efficient. + pub fn paged_search_no_await( &self, strategy: S, - context: &DP::Context, + context: DP::Context, query: T, l_value: usize, - init_ids: Option<&[DP::InternalId]>, - ) -> ANNResult> - where - S: SearchStrategy + 'static, - T: Copy + Send, - { - self.handle.block_on( - self.inner - .start_paged_search_with_init_ids(strategy, context, query, l_value, init_ids), - ) - } - - pub fn next_search_results( - &self, - context: &DP::Context, - search_state: &mut SearchState, - k: usize, - result_output: &mut [Neighbor], - ) -> ANNResult + ) -> ANNResult> where - S: SearchStrategy, + T: for<'a> Reborrow<'a, Target: Copy + Send> + 'static, + S: for<'a> SearchStrategy>::Target> + 'static, { - self.handle.block_on(self.inner.next_search_results( - context, - search_state, - k, - result_output, - )) + noawait::PagedSearch::new(self.inner.clone(), strategy, context, query, l_value) } pub fn count_reachable_nodes( @@ -416,6 +431,233 @@ where } } +// /// Synchronous wrapper around [`graph::search::PagedSearch`] that owns a tokio runtime handle. +// /// +// /// Created by [`DiskANNIndex::paged_search`]. Each call to [`next_page`](Self::next_page) +// /// blocks the current thread to drive the underlying async search forward. +// pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { +// handle: tokio::runtime::Handle, +// inner: graph::search::PagedSearch<'a, DP, S, T>, +// } +// +// impl<'a, DP, S, T> PagedSearch<'a, DP, S, T> +// where +// DP: DataProvider, +// S: SearchStrategy, +// T: Copy + Send, +// { +// /// Returns the next page of at most `k` nearest-neighbor results. +// /// +// /// Blocks the current thread. Returns an empty `Vec` when the search is exhausted. +// pub fn next_page(&mut self, k: usize) -> ANNResult>> { +// self.handle.block_on(self.inner.next_page(k)) +// } +// } + +pub mod noawait { + //! Implementations of a synchronous wrapper around [`diskann::graph::DiskANNIndex`] that + //! assume the [`Accessor`] and associated implementations never truly `await` and are + //! in fact synchronous. + //! + //! With this assumption, we can perform lighter-weight communication with the index + //! by assuming that each `poll` returns ready. + //! + //! **Do not use this if your index ever actually await**: Doing so will lead to deadlock! + + use super::*; + + use std::{ + cell::RefCell, + pin::Pin, + rc::Rc, + task::{Context, Poll, Waker}, + }; + + use diskann::{ANNErrorKind, utils::VectorId}; + use diskann_utils::Reborrow; + use thiserror::Error; + + type Input = Rc>>; + type Output = Rc>>>>; + + fn channel() -> (Input, Output) + where + I: VectorId, + { + let input = Rc::new(RefCell::new(None)); + let output = Rc::new(RefCell::new(None)); + (input, output) + } + + fn step(fut: Pin<&mut dyn Future>) -> Option { + let mut cx = Context::from_waker(Waker::noop()); + match fut.poll(&mut cx) { + Poll::Ready(v) => Some(v), + Poll::Pending => None, + } + } + + /// A synchronous wrapper for [`graph::search::PagedSearch`] + /// + /// See: [`super::DiskANNIndex::paged_search_no_await`]. + pub struct PagedSearch { + // The `input` is wrapped in an `Option` so we can fuse `searcher` if it exits + // with an error. Polling a completed future risk panicking. + // + // We construct `searcher` to pull its next-page size from this input. + input: Option, + + // Output yielded from polling `searcher`. + output: Output, + + // We shut down the future by running `Drop`. Thus, the only way it can actually + // finish is if it returns with an error. + searcher: Pin>>, + } + + impl PagedSearch + where + I: VectorId, + { + /// Construct a new [`PagedSearch`]. + /// + /// This works by creating a small async task using [`graph::search::PagedSearch`] + /// internally. The requested k-nearest neighors are sent using a `Rc>` + /// channel and the actual neighbors are retrieved from a similar data structure. + /// + /// Under the assumption that the implementation of [`graph::search::PagedSearch`]'s + /// implementations are fully synchronous, we can directly poll this task instead + /// of going through a runtime since we (theoretically) control the only suspension + /// point. + /// + /// Doing so allows stepping the task state machine to be done with a single function + /// call to `Future::poll`. + /// + /// Obviously, if the "noawait" assumption is broken, then the inner async job may + /// yield before our control point, but we can detect this situation since no output + /// will be generated on the output channel. + /// + /// We rely on `Drop` to clean up the paged search resources. + pub(super) fn new( + index: Arc>, + strategy: S, + context: DP::Context, + query: T, + l_value: usize, + ) -> ANNResult + where + DP: DataProvider, + T: for<'a> Reborrow<'a, Target: Copy + Send> + 'static, + S: for<'a> SearchStrategy>::Target> + 'static, + { + // Prepare the input and output channels used to communicate with the search task. + let (input, output) = channel::(); + let input_clone = input.clone(); + let output_clone = output.clone(); + + // Create the search task. + let mut searcher: Pin>> = Box::pin(async move { + // The assumption of `noawait` is that this call will always resolve to + // `Poll::Ready`. + let mut state = match index + .paged_search(&strategy, &context, query.reborrow(), l_value) + .await + { + Ok(state) => state, + Err(err) => return err, + }; + + loop { + // This is the await point that pauses the future. + // + // Under the "noawait" assumption, this should be the only point where + // this future ever yields `Pending` and is where we expect the future + // to stop every time we poll it. + futures_util::pending!(); + + // We control the invocation of poll and should always ensure that + // input is available. + let k_value = match input_clone.take() { + Some(value) => value, + None => return InternalInvariantViolated::MissingInput.into(), + }; + + // Step paged search and propagate any errors. + let page = match state.next_page(k_value).await { + Ok(page) => page, + Err(err) => return err, + }; + + // Send output to the caller. + output_clone.replace(Some(page)); + } + }); + + // Drive the inner future one step to initialize paged search. + if let Some(err) = step(searcher.as_mut()) { + return Err(err); + } + + let this = Self { + input: Some(input), + output, + searcher, + }; + Ok(this) + } + + /// Retrieve the next results from paged search, returning any errors. + /// + /// If [`next`](Self::next) previously returned with an error, it will continue + /// to do so. + pub fn next(&mut self, k: usize) -> ANNResult>> { + // Prepare input. We use the presence of the input channel to decide whether + // or not it is safe to poll search task. + match self.input.as_ref() { + Some(input) => input.replace(Some(k)), + None => { + return Err(ANNError::message( + ANNErrorKind::Opaque, + "paged searcher errored and is no longer runnable", + )); + } + }; + + // Progress the future. + // + // The only reason to return return `Some` is if the inner future aborts with + // an error. Here, we fuse the searcher to prevent panics on re-enters and + // forward the error. + if let Some(result) = step(self.searcher.as_mut()) { + self.input = None; + return Err(result); + } + + // Profit! + match self.output.take() { + Some(v) => Ok(v), + None => Err(InternalInvariantViolated::MissingOutput.into()), + } + } + } + + #[derive(Debug, Clone, Copy, Error)] + enum InternalInvariantViolated { + #[error("INTERNAL: input channel was not configured")] + MissingInput, + #[error("noawait contract violated: future suspended before expected yield point")] + MissingOutput, + } + + impl From for ANNError { + #[track_caller] + #[cold] + fn from(err: InternalInvariantViolated) -> Self { + Self::new(ANNErrorKind::Opaque, err) + } + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -538,4 +780,120 @@ mod tests { assert_eq!(ids[0], 0); assert_eq!(distances[0], 0.0); } + + fn wrapped_test_provider() -> DiskANNIndex { + let provider = + graph::test::provider::Provider::grid(graph::test::synthetic::Grid::One, 100).unwrap(); + + DiskANNIndex::new_with_current_thread_runtime( + graph::config::Builder::new( + provider.max_degree(), + diskann::graph::config::MaxDegree::same(), + 100, + (Metric::L2).into(), + ) + .build() + .unwrap(), + provider, + ) + } + + // Test the `noawait` paged searcher. + // + // This relies on the test-provider being no-await. + #[test] + fn test_paged_search_noawait() { + let index = wrapped_test_provider(); + + for page_size in [1, 5, 9, 12] { + let mut paged = index + .paged_search_no_await::<_, Vec>( + graph::test::provider::Strategy::new(), + graph::test::provider::Context::new(), + vec![0.0], + 10.max(page_size), + ) + .unwrap(); + + let mut i = 0u32; + loop { + let v = paged.next(page_size).unwrap(); + assert!( + v.len() <= page_size, + "candidates returned ({}) exceeded page size ({})", + v.len(), + page_size, + ); + + if v.is_empty() { + break; + } + + for neighbor in v { + assert_ne!( + neighbor.id, + u32::MAX, + "paged search should not return start point", + ); + assert_eq!( + neighbor.id, i, + "monotonicity should at least hold for the 1d grid" + ); + assert_eq!( + neighbor.distance, + (i as f32) * (i as f32), + "distance was computed incorrectly!", + ); + i += 1; + } + } + + // Search is exhausted - make sure that subsequent searches yield empty vectors. + let exhausted = paged.next(5).unwrap(); + assert!( + exhausted.is_empty(), + "expected an empty vector when exhausted - instead got {:?}", + exhausted + ); + } + } + + // Verify that the searcher is properly fused when it returns with an error. + #[test] + fn test_paged_search_noawait_fuse() { + let index = wrapped_test_provider(); + + // To do this test, we request more neighbors than the search-L, which triggers + // an inner error. + let search_l = 10; + let bigger_than_search_l = 20; + + let mut paged = index + .paged_search_no_await::<_, Vec>( + graph::test::provider::Strategy::new(), + graph::test::provider::Context::new(), + vec![0.0], + search_l, + ) + .unwrap(); + + let expected = "search_param_l"; + let err = paged.next(bigger_than_search_l).unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains(expected), + "expected error message to contain \"{}\" - instead got\n\n{}", + expected, + msg, + ); + + // Now that we've yielded an error - the next time we request pages should also error. + let err = paged.next(10).unwrap_err(); + let err_msg = err.to_string(); + assert!( + err_msg.contains("paged searcher errored"), + "unexpected error message:\n\n{}", + err_msg + ); + } } diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 7838d403d..f8e48340e 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -20,16 +20,15 @@ use diskann::{ graph::{ AdjacencyList, DiskANNIndex, SearchOutputBuffer, glue::{ - self, Batch, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - MultiInsertStrategy, PruneStrategy, SearchExt, SearchStrategy, + self, Batch, DefaultPostProcessor, ExpandBeam, Explore, InplaceDeleteStrategy, + InsertStrategy, MultiInsertStrategy, PruneStrategy, SearchStrategy, }, workingset::{self, map}, }, neighbor::Neighbor, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext, - DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, - NoopGuard, SetElement, + Accessor, BuildDistanceComputer, DataProvider, DefaultContext, DelegateNeighbor, Delete, + ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, utils::{IntoUsize, VectorRepr}, }; @@ -911,7 +910,6 @@ where /// /// * [`Accessor`] for the [`BfTreeProvider`]. /// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. /// pub struct FullAccessor<'a, T, Q, D> where @@ -934,7 +932,7 @@ where type Id = u32; } -impl SearchExt for FullAccessor<'_, T, Q, D> +impl Explore for FullAccessor<'_, T, Q, D> where T: VectorRepr, Q: AsyncFriendly, @@ -1034,22 +1032,22 @@ where } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, -{ - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) - } -} +// impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D> +// where +// T: VectorRepr, +// Q: AsyncFriendly, +// D: AsyncFriendly, +// { +// type QueryComputerError = Panics; +// type QueryComputer = T::QueryDistance; +// +// fn build_query_computer( +// &self, +// from: &[T], +// ) -> Result { +// Ok(T::query_distance(from, self.provider.metric)) +// } +// } impl ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D> where T: VectorRepr, @@ -1098,7 +1096,7 @@ where type Id = u32; } -impl SearchExt for QuantAccessor<'_, T, D> +impl Explore for QuantAccessor<'_, T, D> where T: VectorRepr, D: AsyncFriendly, @@ -1415,7 +1413,7 @@ where Q: AsyncFriendly, D: AsyncFriendly + DeletionCheck, { - type QueryComputer = T::QueryDistance; + // type QueryComputer = T::QueryDistance; type SearchAccessor<'a> = FullAccessor<'a, T, Q, D>; type SearchAccessorError = Panics; @@ -1494,7 +1492,7 @@ where T: VectorRepr, D: AsyncFriendly + DeletionCheck, { - type QueryComputer = pq::distance::QueryComputer>; + // type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, T, D>; type SearchAccessorError = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index e72f323ef..cfa27192d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -10,23 +10,23 @@ use diskann::{ ANNError, ANNResult, error::Infallible, graph::{ - SearchOutputBuffer, + AdjacencyList, SearchOutputBuffer, glue::{ - self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + self, DefaultPostProcessor, Explore, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchStrategy, }, workingset, }, neighbor::Neighbor, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, + BuildDistanceComputer, DefaultContext, DelegateNeighbor, ExecutionContext, HasElementRef, + HasId, }, utils::{IntoUsize, VectorRepr}, }; use diskann_utils::future::AsyncFriendly; -use diskann_vector::{DistanceFunction, distance::Metric}; +use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; use crate::model::graph::provider::async_::{ FastMemoryVectorProviderAsync, SimpleNeighborProviderAsync, @@ -102,218 +102,249 @@ where } } -////////////////// -// FullAccessor // -////////////////// +///////////// +// Builder // +///////////// -/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`ComputerAccessor`] for comparing full-precision distances. -/// * [`BuildQueryComputer`]. -pub struct FullAccessor<'a, T, Q, D, Ctx> +#[derive(Clone)] +pub struct Builder<'a, T> where T: VectorRepr, { - /// The host provider. - provider: &'a FullPrecisionProvider, - - /// A buffer for resolving iterators given during bulk operations. - /// - /// The accessor reuses this allocation to amortize allocation cost over multiple bulk - /// operations. - id_buffer: Vec, + metric: Metric, + store: &'a FastMemoryVectorProviderAsync, + neighbors: &'a SimpleNeighborProviderAsync, } -impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> +impl HasId for Builder<'_, T> where T: VectorRepr, { - type Repr = T; - fn as_full_precision(&self) -> &FullPrecisionStore { - &self.provider.base_vectors - } + type Id = u32; } -impl HasId for FullAccessor<'_, T, Q, D, Ctx> +impl HasElementRef for Builder<'_, T> where T: VectorRepr, { - type Id = u32; + type ElementRef<'a> = &'a [T]; } -impl SearchExt for FullAccessor<'_, T, Q, D, Ctx> +impl<'a, T> DelegateNeighbor<'a> for Builder<'_, T> where T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, { - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) + type Delegate = &'a SimpleNeighborProviderAsync; + + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.neighbors } } -impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> +impl BuildDistanceComputer for Builder<'_, T> where T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, { - pub fn new(provider: &'a FullPrecisionProvider) -> Self { - Self { - provider, - id_buffer: Vec::new(), - } + type DistanceComputerError = Panics; + type DistanceComputer = T::Distance; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(T::distance( + self.metric, + Some(self.store.dim()), + )) } } -impl<'a, T, Q, D, Ctx> DelegateNeighbor<'a> for FullAccessor<'_, T, Q, D, Ctx> +// All this does is return a `&Self` - which directly accesses the underlying provider. +impl workingset::Fill for Builder<'_, T> where T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, { - type Delegate = &'a SimpleNeighborProviderAsync; + type Error = Infallible; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() + type View<'a> + = Builder<'a, T> + where + Self: 'a; + + async fn fill<'a, Itr>( + &'a mut self, + _state: &'a mut PassThrough, + _itr: Itr, + ) -> Result, Self::Error> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + Ok(self.clone()) } } -impl Accessor for FullAccessor<'_, T, Q, D, Ctx> +// Pass-through view. +impl workingset::View for Builder<'_, T> where T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. + type ElementRef<'a> = &'a [T]; type Element<'a> = &'a [T] where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = &'a [T]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) + fn get(&self, id: u32) -> Option<&[T]> { + // SAFETY: This is unsound. We assume no concurrent writes to this slot, but + // this invariant is not enforced. See `get_vector_sync` for details + Some(unsafe { self.store.get_vector_sync(id.into_usize()) }) } +} - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.base_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.base_vectors.prefetch_hint(id.into_usize()); - } +////////////////// +// FullAccessor // +////////////////// - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .base_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } +/// An accessor for retrieving full-precision vectors from the `DefaultProvider`. +/// +/// This type implements the following traits: +/// +/// * [`Accessor`] for the [`DefaultProvider`]. +/// * [`ComputerAccessor`] for comparing full-precision distances. +pub struct FullAccessor<'a, T, Q, D, Ctx> +where + T: VectorRepr, +{ + /// The host provider. + provider: &'a FullPrecisionProvider, - // Invoke the passed closure on the full-precision vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }, - *id, - ) - } + /// The distance computer. + computer: ::QueryDistance, - std::future::ready(Ok(())) - } + /// A buffer for resolving iterators given during bulk operations. + /// + /// The accessor reuses this allocation to amortize allocation cost over multiple bulk + /// operations. + id_buffer: AdjacencyList, } -impl BuildDistanceComputer for FullAccessor<'_, T, Q, D, Ctx> +impl GetFullPrecision for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, { - type DistanceComputerError = Panics; - type DistanceComputer = T::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(T::distance( - self.provider.metric, - Some(self.provider.base_vectors.dim()), - )) + type Repr = T; + fn as_full_precision(&self) -> &FullPrecisionStore { + &self.provider.base_vectors } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +impl HasId for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, +{ + type Id = u32; +} + +impl Explore for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { - type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; + fn starting_points(&self) -> impl Future>> { + std::future::ready(self.provider.starting_points()) + } - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - Ok(T::query_distance(from, self.provider.metric)) + fn start_point_distances( + &mut self, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + async move { + for i in self.provider.starting_points()? { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = self.computer.evaluate_similarity(unsafe { + self.provider.base_vectors.get_vector_sync(i.into_usize()) + }); + + f(i, distance); + } + Ok(()) + } + } + + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let id_buffer = &mut self.id_buffer; + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), id_buffer)?; + + id_buffer.retain(|i| pred.eval_mut(i)); + let len = id_buffer.len(); + let lookahead = self.provider.base_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.base_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .base_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + // Invoke the passed closure on the full-precision vector. + // + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let v = unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }; + let distance = self.computer.evaluate_similarity(v); + on_neighbors(distance, *id); + } + } + Ok(()) + }; + + std::future::ready(f()) } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +impl<'a, T, Q, D, Ctx> FullAccessor<'a, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { + pub fn new( + provider: &'a FullPrecisionProvider, + query: &[T], + ) -> Self { + Self { + provider, + computer: T::query_distance(query, provider.metric), + id_buffer: AdjacencyList::new(), + } + } } //-------------------// @@ -353,7 +384,7 @@ pub struct Rerank; impl<'a, A, T> glue::SearchPostProcess for Rerank where T: VectorRepr, - A: BuildQueryComputer<&'a [T], Id = u32> + GetFullPrecision + AsDeletionCheck, + A: HasId + GetFullPrecision + AsDeletionCheck, { type Error = Panics; @@ -361,7 +392,6 @@ where &self, accessor: &mut A, query: &'a [T], - _computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send @@ -410,7 +440,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = T::QueryDistance; + // type QueryComputer = T::QueryDistance; type SearchAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; type SearchAccessorError = Panics; @@ -418,8 +448,9 @@ where &'a self, provider: &'a FullPrecisionProvider, _context: &'a Ctx, + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) + Ok(FullAccessor::new(provider, query)) } } @@ -442,7 +473,7 @@ where Ctx: ExecutionContext, { type DistanceComputer<'a> = T::Distance; - type PruneAccessor<'a> = FullAccessor<'a, T, Q, D, Ctx>; + type PruneAccessor<'a> = Builder<'a, T>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = PassThrough; @@ -451,7 +482,12 @@ where provider: &'a FullPrecisionProvider, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) + let builder = Builder { + metric: provider.metric, + store: &provider.base_vectors, + neighbors: provider.neighbors(), + }; + Ok(builder) } fn create_working_set(&self, _capacity: usize) -> Self::WorkingSet { @@ -459,55 +495,6 @@ where } } -// All this does is return a `&Self` - which directly accesses the underlying provider. -impl workingset::Fill for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Error = Infallible; - - type View<'a> - = &'a Self - where - Self: 'a; - - async fn fill<'a, Itr>( - &'a mut self, - _state: &'a mut PassThrough, - _itr: Itr, - ) -> Result, Self::Error> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - Ok(self) - } -} - -// Pass-through view. -impl workingset::View for &FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type ElementRef<'a> = &'a [T]; - type Element<'a> - = &'a [T] - where - Self: 'a; - - fn get(&self, id: u32) -> Option<&[T]> { - // SAFETY: This is unsound. We assume no concurrent writes to this slot, but - // this invariant is not enforced. See `get_vector_sync` for details - Some(unsafe { self.provider.base_vectors.get_vector_sync(id.into_usize()) }) - } -} - impl InsertStrategy, &[T]> for FullPrecision where T: VectorRepr, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index b3867ad85..5838e1d52 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -9,20 +9,18 @@ use diskann::default_post_processor; use diskann::{ ANNError, ANNResult, graph::{ + AdjacencyList, glue::{ - self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, + self, DefaultPostProcessor, Explore, InplaceDeleteStrategy, InsertStrategy, + PruneStrategy, SearchStrategy, }, workingset, }, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, - }, + provider::{BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId}, utils::{IntoUsize, VectorRepr}, }; use diskann_utils::future::AsyncFriendly; -use diskann_vector::distance::Metric; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; use crate::model::{ graph::provider::async_::{ @@ -82,280 +80,118 @@ where } /////////////////// -// QuantAccessor // +// Quant Builder // /////////////////// -/// An accessor that retrieves the quantized portion of the [`DefaultProvider`]. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the `DefaultProvider`. -/// * [`BuildQueryComputer`]. -pub struct QuantAccessor<'a, V, D, Ctx> { - provider: &'a DefaultProvider, +#[derive(Clone)] +pub struct QuantBuilder<'a> { + provider: &'a FastMemoryQuantVectorProviderAsync, + neighbors: &'a SimpleNeighborProviderAsync, } -impl GetFullPrecision for QuantAccessor<'_, FullPrecisionStore, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync { - &self.provider.base_vectors - } -} - -impl HasId for QuantAccessor<'_, V, D, Ctx> { +impl HasId for QuantBuilder<'_> { type Id = u32; } -impl SearchExt for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - pub(crate) fn new(provider: &'a DefaultProvider) -> Self { - Self { provider } - } +impl HasElementRef for QuantBuilder<'_> { + type ElementRef<'a> = &'a [u8]; } -impl<'a, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ +impl<'a> DelegateNeighbor<'a> for QuantBuilder<'_> { type Delegate = &'a SimpleNeighborProviderAsync; fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() + self.neighbors } } -impl Accessor for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = &'a [u8] - where - Self: 'a; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = &'a [u8]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// Return the quantized vector stored at index `i`. - /// - /// This function always completes synchronously. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB that can result from potentially mixing - // unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.aux_vectors.get_vector_sync(id.into_usize()) - })) - } - - /// Perform a bulk operation. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), - { - for i in itr { - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f( - unsafe { self.provider.aux_vectors.get_vector_sync(i.into_usize()) }, - i, - ) - } - std::future::ready(Ok(())) - } -} - -impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = ANNError; - type QueryComputer = pq::distance::QueryComputer>; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - self.provider.aux_vectors.query_computer(from) - } -} - -impl BuildDistanceComputer for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ +impl BuildDistanceComputer for QuantBuilder<'_> { type DistanceComputerError = ANNError; type DistanceComputer = pq::distance::DistanceComputer>; fn build_distance_computer( &self, ) -> Result { - Ok(self.provider.aux_vectors.distance_computer()) + Ok(self.provider.distance_computer()) } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ +impl workingset::Fill for QuantBuilder<'_> { + type Error = std::convert::Infallible; + type View<'a> + = QuantBuilder<'a> + where + Self: 'a; + + async fn fill<'a, Itr>( + &'a mut self, + _state: &'a mut PassThrough, + _itr: Itr, + ) -> Result, Self::Error> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + Ok(self.clone()) + } } -//-------------------// -// In-mem Extensions // -//-------------------// +// Pass-through view — reads PQ codes directly from the provider. +impl workingset::View for QuantBuilder<'_> { + type ElementRef<'a> = &'a [u8]; + type Element<'a> + = &'a [u8] + where + Self: 'a; -impl<'a, V, D, Ctx> AsDeletionCheck for QuantAccessor<'a, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Checker = D; - fn as_deletion_check(&self) -> &D { - &self.provider.deleted + fn get(&self, id: u32) -> Option<&[u8]> { + // SAFETY: This is unsound. We assume no concurrent writes to this slot, but + // this invariant is not enforced. See `get_vector_sync` for details. + Some(unsafe { self.provider.get_vector_sync(id.into_usize()) }) } } -///////////////////// -// Hybrid Accessor // -///////////////////// +/////////////////// +// HybridBuilder // +/////////////////// -/// A hybrid accessor that fetches a mixture of full-precision and quantized vectors during -/// pruning. This allows the application to trade full-precision fetches for accuracy. -/// -/// This type implements the following traits: -/// -/// * [`Accessor`] for the [`DefaultProvider`]. -/// * [`BuildDistanceComputer`] for computing distances among [`distances::pq::Hybrid`] -/// element types. -/// * [`Fill`] for populating a mixture of full-precision and quant vectors. -pub struct HybridAccessor<'a, T, D, Ctx> +#[derive(Clone)] +pub struct HybridBuilder<'a, T> where T: VectorRepr, { - provider: &'a FullPrecisionProvider, - - /// Maximum number of full-precision vectors to use during pruning. - /// This field is ignored during search, where full-precision vectors are never used. + full: &'a FastMemoryVectorProviderAsync, + quant: &'a FastMemoryQuantVectorProviderAsync, + neighbors: &'a SimpleNeighborProviderAsync, max_fp_vecs_per_prune: usize, } -impl<'a, T, D, Ctx> HybridAccessor<'a, T, D, Ctx> +impl HasId for HybridBuilder<'_, T> where T: VectorRepr, { - fn new( - provider: &'a FullPrecisionProvider, - max_fp_vecs_per_prune: usize, - ) -> Self { - Self { - provider, - max_fp_vecs_per_prune, - } - } + type Id = u32; } -impl HasId for HybridAccessor<'_, T, D, Ctx> +impl HasElementRef for HybridBuilder<'_, T> where T: VectorRepr, { - type Id = u32; + type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; } -impl<'a, T, D, Ctx> DelegateNeighbor<'a> for HybridAccessor<'_, T, D, Ctx> +impl<'a, T> DelegateNeighbor<'a> for HybridBuilder<'_, T> where T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, { type Delegate = &'a SimpleNeighborProviderAsync; fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() + self.neighbors } } -impl Accessor for HybridAccessor<'_, T, D, Ctx> +impl BuildDistanceComputer for HybridBuilder<'_, T> where T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// The [`distances::pq::Hybrid`] is an enum consisting of either a full-precision - /// vector or a quantized vector. - /// - /// This accessor can return either. - type Element<'a> - = distances::pq::Hybrid<&'a [T], &'a [u8]> - where - Self: 'a; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = Panics; - - /// The default behavior of `get_element` returns a full-precision vector. The - /// implementation of [`Fill`] is how the `max_fp_vecs_per_fill` is used. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB that can result from potentially mixing - // unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - distances::pq::Hybrid::Full(self.provider.base_vectors.get_vector_sync(id.into_usize())) - })) - } -} - -impl BuildDistanceComputer for HybridAccessor<'_, T, D, Ctx> -where - T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, { type DistanceComputerError = ANNError; type DistanceComputer = distances::pq::HybridComputer; @@ -363,10 +199,10 @@ where fn build_distance_computer( &self, ) -> Result { - let metric = self.provider.aux_vectors.metric(); + let metric = self.quant.metric(); Ok(distances::pq::HybridComputer::new( - self.provider.aux_vectors.distance_computer(), - T::distance(metric, Some(self.provider.base_vectors.dim())), + self.quant.distance_computer(), + T::distance(metric, Some(self.full.dim())), )) } } @@ -385,15 +221,13 @@ impl workingset::AsWorkingSet for Unseeded { // Selective fill — the first `max_fp_vecs_per_prune` candidates receive full-precision // vectors; the remainder are accessed through the pass-through `MaybeFullPrecision` view. -impl workingset::Fill for HybridAccessor<'_, T, D, Ctx> +impl workingset::Fill for HybridBuilder<'_, T> where T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, { type Error = std::convert::Infallible; type View<'a> - = MaybeFullPrecision<'a, T, D, Ctx> + = MaybeFullPrecision<'a, T> where Self: 'a; @@ -409,25 +243,23 @@ where state.0.clear(); state.0.extend(itr.take(self.max_fp_vecs_per_prune)); Ok(MaybeFullPrecision { - accessor: self, + builder: self, full: state, }) } } -pub struct MaybeFullPrecision<'a, T, D, Ctx> +pub struct MaybeFullPrecision<'a, T> where T: VectorRepr, { - accessor: &'a HybridAccessor<'a, T, D, Ctx>, + builder: &'a HybridBuilder<'a, T>, full: &'a FullPrecisionTracker, } -impl workingset::View for MaybeFullPrecision<'_, T, D, Ctx> +impl workingset::View for MaybeFullPrecision<'_, T> where T: VectorRepr, - D: AsyncFriendly, - Ctx: ExecutionContext, { type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; type Element<'a> @@ -436,50 +268,398 @@ where Self: 'a; fn get(&self, id: u32) -> Option> { - let provider = &self.accessor.provider; let element = if self.full.0.contains(&id) { // SAFETY: This is unsound. We assume no concurrent writes to this slot, but // this invariant is not enforced. See `get_vector_sync` for details. unsafe { - distances::pq::Hybrid::Full(provider.base_vectors.get_vector_sync(id.into_usize())) + distances::pq::Hybrid::Full(self.builder.full.get_vector_sync(id.into_usize())) } } else { // SAFETY: This is unsound. We assume no concurrent writes to this slot, but // this invariant is not enforced. See `get_vector_sync` for details. unsafe { - distances::pq::Hybrid::Quant(provider.aux_vectors.get_vector_sync(id.into_usize())) + distances::pq::Hybrid::Quant(self.builder.quant.get_vector_sync(id.into_usize())) } }; Some(element) } } -// Pass-through fill — returns `&Self` which directly accesses the underlying provider. -impl workingset::Fill for QuantAccessor<'_, V, D, Ctx> +/////////////////// +// QuantAccessor // +/////////////////// + +/// An accessor that retrieves the quantized portion of the [`DefaultProvider`]. +/// +/// This type implements the following traits: +/// +/// * [`Accessor`] for the `DefaultProvider`. +pub struct QuantAccessor<'a, V, D, Ctx> { + provider: &'a DefaultProvider, + computer: pq::distance::QueryComputer>, +} + +impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx> where V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { - type Error = std::convert::Infallible; - type View<'a> - = &'a Self + pub(crate) fn new( + provider: &'a DefaultProvider, + query: &[f32], + ) -> Self { + let computer = provider.aux_vectors.query_computer(query).unwrap(); + Self { provider, computer } + } +} + +impl GetFullPrecision for QuantAccessor<'_, FullPrecisionStore, D, Ctx> +where + T: VectorRepr, +{ + type Repr = T; + fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync { + &self.provider.base_vectors + } +} + +impl HasId for QuantAccessor<'_, V, D, Ctx> { + type Id = u32; +} + +impl Explore for QuantAccessor<'_, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + fn starting_points(&self) -> impl Future>> { + std::future::ready(self.provider.starting_points()) + } + + fn start_point_distances( + &mut self, + mut f: F, + ) -> impl std::future::Future> + Send where - Self: 'a; + F: FnMut(Self::Id, f32) + Send, + { + async move { + for i in self.provider.starting_points()? { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = self.computer.evaluate_similarity(unsafe { + self.provider.aux_vectors.get_vector_sync(i.into_usize()) + }); + + f(i, distance); + } + Ok(()) + } + } - async fn fill<'a, Itr>( - &'a mut self, - _state: &'a mut PassThrough, - _itr: Itr, - ) -> Result, Self::Error> + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, { - Ok(self) + let f = move || -> ANNResult<()> { + let mut neighbors = AdjacencyList::new(); + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut neighbors)?; + for i in neighbors.iter().filter(|i| pred.eval_mut(i)) { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = self.computer.evaluate_similarity(unsafe { + self.provider.aux_vectors.get_vector_sync(i.into_usize()) + }); + + on_neighbors(distance, *i); + } + } + Ok(()) + }; + + std::future::ready(f()) + } +} + +// impl<'a, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, V, D, Ctx> +// where +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type Delegate = &'a SimpleNeighborProviderAsync; +// fn delegate_neighbor(&'a mut self) -> Self::Delegate { +// self.provider.neighbors() +// } +// } +// +// impl HasElementRef for QuantAccessor<'_, V, D, Ctx> +// where +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type ElementRef<'a> = &'a [u8]; +// } + +// impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> +// where +// T: VectorRepr, +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type QueryComputerError = ANNError; +// type QueryComputer = pq::distance::QueryComputer>; +// +// fn build_query_computer( +// &self, +// from: &[T], +// ) -> Result { +// self.provider.aux_vectors.query_computer(from) +// } +// } + +// impl BuildDistanceComputer for QuantAccessor<'_, V, D, Ctx> +// where +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type DistanceComputerError = ANNError; +// type DistanceComputer = pq::distance::DistanceComputer>; +// +// fn build_distance_computer( +// &self, +// ) -> Result { +// Ok(self.provider.aux_vectors.distance_computer()) +// } +// } + +//-------------------// +// In-mem Extensions // +//-------------------// + +impl<'a, V, D, Ctx> AsDeletionCheck for QuantAccessor<'a, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly + DeletionCheck, + Ctx: ExecutionContext, +{ + type Checker = D; + fn as_deletion_check(&self) -> &D { + &self.provider.deleted } } +// ///////////////////// +// // Hybrid Accessor // +// ///////////////////// +// +// /// A hybrid accessor that fetches a mixture of full-precision and quantized vectors during +// /// pruning. This allows the application to trade full-precision fetches for accuracy. +// /// +// /// This type implements the following traits: +// /// +// /// * [`Accessor`] for the [`DefaultProvider`]. +// /// * [`BuildDistanceComputer`] for computing distances among [`distances::pq::Hybrid`] +// /// element types. +// /// * [`Fill`] for populating a mixture of full-precision and quant vectors. +// pub struct HybridAccessor<'a, T, D, Ctx> +// where +// T: VectorRepr, +// { +// provider: &'a FullPrecisionProvider, +// +// /// Maximum number of full-precision vectors to use during pruning. +// /// This field is ignored during search, where full-precision vectors are never used. +// max_fp_vecs_per_prune: usize, +// } +// +// impl<'a, T, D, Ctx> HybridAccessor<'a, T, D, Ctx> +// where +// T: VectorRepr, +// { +// fn new( +// provider: &'a FullPrecisionProvider, +// max_fp_vecs_per_prune: usize, +// ) -> Self { +// Self { +// provider, +// max_fp_vecs_per_prune, +// } +// } +// } +// +// impl HasId for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// { +// type Id = u32; +// } +// +// impl<'a, T, D, Ctx> DelegateNeighbor<'a> for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type Delegate = &'a SimpleNeighborProviderAsync; +// fn delegate_neighbor(&'a mut self) -> Self::Delegate { +// self.provider.neighbors() +// } +// } +// +// impl HasElementRef for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; +// } +// +// impl BuildDistanceComputer for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type DistanceComputerError = ANNError; +// type DistanceComputer = distances::pq::HybridComputer; +// +// fn build_distance_computer( +// &self, +// ) -> Result { +// let metric = self.provider.aux_vectors.metric(); +// Ok(distances::pq::HybridComputer::new( +// self.provider.aux_vectors.distance_computer(), +// T::distance(metric, Some(self.provider.base_vectors.dim())), +// )) +// } +// } +// +// /// Tracks which IDs should use full-precision vectors during hybrid pruning. +// /// +// /// IDs in the set get full-precision distance computations; all others fall back to +// /// quantized vectors. +// pub struct FullPrecisionTracker(hashbrown::HashSet); +// +// impl workingset::AsWorkingSet for Unseeded { +// fn as_working_set(&self, capacity: usize) -> FullPrecisionTracker { +// FullPrecisionTracker(hashbrown::HashSet::with_capacity(capacity)) +// } +// } +// +// // Selective fill — the first `max_fp_vecs_per_prune` candidates receive full-precision +// // vectors; the remainder are accessed through the pass-through `MaybeFullPrecision` view. +// impl workingset::Fill for HybridAccessor<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type Error = std::convert::Infallible; +// type View<'a> +// = MaybeFullPrecision<'a, T, D, Ctx> +// where +// Self: 'a; +// +// async fn fill<'a, Itr>( +// &'a mut self, +// state: &'a mut FullPrecisionTracker, +// itr: Itr, +// ) -> Result, Self::Error> +// where +// Itr: ExactSizeIterator + Clone + Send + Sync, +// Self: 'a, +// { +// state.0.clear(); +// state.0.extend(itr.take(self.max_fp_vecs_per_prune)); +// Ok(MaybeFullPrecision { +// accessor: self, +// full: state, +// }) +// } +// } +// +// pub struct MaybeFullPrecision<'a, T, D, Ctx> +// where +// T: VectorRepr, +// { +// accessor: &'a HybridAccessor<'a, T, D, Ctx>, +// full: &'a FullPrecisionTracker, +// } +// +// impl workingset::View for MaybeFullPrecision<'_, T, D, Ctx> +// where +// T: VectorRepr, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; +// type Element<'a> +// = distances::pq::Hybrid<&'a [T], &'a [u8]> +// where +// Self: 'a; +// +// fn get(&self, id: u32) -> Option> { +// let provider = &self.accessor.provider; +// let element = if self.full.0.contains(&id) { +// // SAFETY: This is unsound. We assume no concurrent writes to this slot, but +// // this invariant is not enforced. See `get_vector_sync` for details. +// unsafe { +// distances::pq::Hybrid::Full(provider.base_vectors.get_vector_sync(id.into_usize())) +// } +// } else { +// // SAFETY: This is unsound. We assume no concurrent writes to this slot, but +// // this invariant is not enforced. See `get_vector_sync` for details. +// unsafe { +// distances::pq::Hybrid::Quant(provider.aux_vectors.get_vector_sync(id.into_usize())) +// } +// }; +// Some(element) +// } +// } +// +// // Pass-through fill — returns `&Self` which directly accesses the underlying provider. +// impl workingset::Fill for QuantAccessor<'_, V, D, Ctx> +// where +// V: AsyncFriendly, +// D: AsyncFriendly, +// Ctx: ExecutionContext, +// { +// type Error = std::convert::Infallible; +// type View<'a> +// = &'a Self +// where +// Self: 'a; +// +// async fn fill<'a, Itr>( +// &'a mut self, +// _state: &'a mut PassThrough, +// _itr: Itr, +// ) -> Result, Self::Error> +// where +// Itr: ExactSizeIterator + Clone + Send + Sync, +// Self: 'a, +// { +// Ok(self) +// } +// } + // Pass-through view — reads PQ codes directly from the provider. impl workingset::View for &QuantAccessor<'_, V, D, Ctx> where @@ -515,7 +695,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = pq::distance::QueryComputer>; + // type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; type SearchAccessorError = Panics; @@ -523,13 +703,12 @@ where &'a self, provider: &'a FullPrecisionProvider, _context: &'a Ctx, + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap())) } } -/// Starting points are filtered out of the final results and results are reranked using -/// the full-precision data. impl DefaultPostProcessor, &[T]> for Hybrid where @@ -547,7 +726,7 @@ where Ctx: ExecutionContext, { type DistanceComputer<'a> = distances::pq::HybridComputer; - type PruneAccessor<'a> = HybridAccessor<'a, T, D, Ctx>; + type PruneAccessor<'a> = HybridBuilder<'a, T>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = FullPrecisionTracker; @@ -560,10 +739,13 @@ where provider: &'a FullPrecisionProvider, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - Ok(HybridAccessor::new( - provider, - self.max_fp_vecs_per_prune.unwrap_or(usize::MAX), - )) + let builder = HybridBuilder { + full: &provider.base_vectors, + quant: &provider.aux_vectors, + neighbors: provider.neighbors(), + max_fp_vecs_per_prune: self.max_fp_vecs_per_prune.unwrap_or(usize::MAX), + }; + Ok(builder) } } @@ -664,7 +846,7 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = pq::distance::QueryComputer>; + // type QueryComputer = pq::distance::QueryComputer>; type SearchAccessor<'a> = QuantAccessor<'a, NoStore, D, Ctx>; type SearchAccessorError = Panics; @@ -672,8 +854,9 @@ where &'a self, provider: &'a DefaultProvider, _context: &'a Ctx, + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap())) } } @@ -693,7 +876,7 @@ where Ctx: ExecutionContext, { type DistanceComputer<'a> = pq::distance::DistanceComputer>; - type PruneAccessor<'a> = QuantAccessor<'a, NoStore, D, Ctx>; + type PruneAccessor<'a> = QuantBuilder<'a>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = PassThrough; @@ -706,7 +889,11 @@ where provider: &'a DefaultProvider, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - Ok(QuantAccessor::new(provider)) + let builder = QuantBuilder { + provider: &provider.aux_vectors, + neighbors: provider.neighbors(), + }; + Ok(builder) } } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index a69344691..2f9c63424 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -9,16 +9,14 @@ use crate::storage::{StorageReadProvider, StorageWriteProvider}; use diskann::{ ANNError, ANNResult, default_post_processor, graph::{ + AdjacencyList, glue::{ - self, DefaultPostProcessor, ExpandBeam, FilterStartPoints, InsertStrategy, Pipeline, - PruneStrategy, SearchExt, SearchStrategy, + self, DefaultPostProcessor, Explore, FilterStartPoints, InsertStrategy, Pipeline, + PruneStrategy, SearchStrategy, }, workingset, }, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, - }, + provider::{BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId}, utils::{IntoUsize, VectorRepr}, }; use diskann_quantization::{ @@ -364,213 +362,222 @@ where } } -////////////// -// Accessor // -////////////// +///////////// +// Builder // +///////////// -/// The accessor for SQ. -pub struct QuantAccessor<'a, const NBITS: usize, V, D, Ctx> { - provider: &'a DefaultProvider, D, Ctx>, - id_buffer: Vec, - is_search: bool, +pub struct Builder<'a, const NBITS: usize> { + store: &'a SQStore, + neighbors: &'a SimpleNeighborProviderAsync, } -impl GetFullPrecision - for QuantAccessor<'_, NBITS, FullPrecisionStore, D, Ctx> -where - T: VectorRepr, -{ - type Repr = T; - fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync { - &self.provider.base_vectors +impl<'a, const NBITS: usize> Builder<'a, NBITS> { + fn new(store: &'a SQStore, neighbors: &'a SimpleNeighborProviderAsync) -> Self { + Self { store, neighbors } } } -impl HasId for QuantAccessor<'_, NBITS, V, D, Ctx> { +impl HasId for Builder<'_, NBITS> { type Id = u32; } -impl SearchExt for QuantAccessor<'_, NBITS, V, D, Ctx> +impl HasElementRef for Builder<'_, NBITS> where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, Unsigned: Representation, { - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) + type ElementRef<'a> = CVRef<'a, NBITS>; +} + +impl<'a, const NBITS: usize> DelegateNeighbor<'a> for Builder<'_, NBITS> { + type Delegate = &'a SimpleNeighborProviderAsync; + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.neighbors } } -impl<'a, const NBITS: usize, V, D, Ctx> QuantAccessor<'a, NBITS, V, D, Ctx> +impl BuildDistanceComputer for Builder<'_, NBITS> where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, + Unsigned: Representation, + DistanceComputer: for<'a, 'b> DistanceFunction, CVRef<'b, NBITS>, f32>, { - pub(crate) fn new( - provider: &'a DefaultProvider, D, Ctx>, - is_search: bool, - ) -> Self { - Self { - provider, - id_buffer: Vec::with_capacity(32), - is_search, - } + type DistanceComputerError = ANNError; + type DistanceComputer = DistanceComputer; + + fn build_distance_computer( + &self, + ) -> Result { + Ok(self.store.distance_computer()?) } } -impl Accessor for QuantAccessor<'_, NBITS, V, D, Ctx> +impl workingset::Fill for Builder<'_, NBITS> where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, Unsigned: Representation, { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = CVRef<'a, NBITS> + type Error = std::convert::Infallible; + type View<'a> + = &'a Self where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = CVRef<'a, NBITS>; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = ANNError; - - /// Return the quantized vector stored at index `i`. - /// - /// This function always completes synchronously. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB that can result from potentially mixing - // unsynchronized reads and writes on the underlying memory. - std::future::ready( - match self.provider.aux_vectors.get_vector(id.into_usize()) { - Ok(v) => Ok(v), - Err(err) => Err(err.into()), - }, - ) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl Future> + Send + async fn fill<'a, Itr>( + &'a mut self, + _state: &'a mut PassThrough, + _itr: Itr, + ) -> Result, Self::Error> where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.aux_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.aux_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .aux_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); - } - - let vector = match self.provider.aux_vectors.get_vector(id.into_usize()) { - Ok(v) => v, - Err(e) => return std::future::ready(Err(e.into())), - }; + Ok(self) + } +} - // Invoke the passed closure on the vector. - // - // SAFETY: We're accepting the consequences of potential unsynchronized, - // concurrent mutation. - f(vector, *id) - } +// Pass-through view — reads scalar-quantized vectors directly from the provider. +impl workingset::View for &Builder<'_, NBITS> +where + Unsigned: Representation, +{ + type ElementRef<'a> = CVRef<'a, NBITS>; + type Element<'a> + = CVRef<'a, NBITS> + where + Self: 'a; - std::future::ready(Ok(())) + fn get(&self, id: u32) -> Option> { + self.store.get_vector(id.into_usize()).ok() } } -impl<'a, const NBITS: usize, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, NBITS, V, D, Ctx> +////////////// +// Accessor // +////////////// + +/// The accessor for SQ. +pub struct QuantAccessor<'a, const NBITS: usize, V, D, Ctx> where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, + Unsigned: Representation, { - type Delegate = &'a SimpleNeighborProviderAsync; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } + provider: &'a DefaultProvider, D, Ctx>, + computer: QueryComputer, + id_buffer: AdjacencyList, } -impl BuildQueryComputer<&[T]> - for QuantAccessor<'_, NBITS, V, D, Ctx> +impl GetFullPrecision + for QuantAccessor<'_, NBITS, FullPrecisionStore, D, Ctx> where T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, Unsigned: Representation, - QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - type QueryComputerError = ANNError; - type QueryComputer = QueryComputer; - - fn build_query_computer( - &self, - from: &[T], - ) -> Result { - // Allow rescaling if this is search. - Ok(self - .provider - .aux_vectors - .query_computer(from, self.is_search)?) + type Repr = T; + fn as_full_precision(&self) -> &FastMemoryVectorProviderAsync { + &self.provider.base_vectors } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> +impl HasId for QuantAccessor<'_, NBITS, V, D, Ctx> +where + Unsigned: Representation, +{ + type Id = u32; +} + +impl Explore for QuantAccessor<'_, NBITS, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { + fn starting_points(&self) -> impl Future>> { + std::future::ready(self.provider.starting_points()) + } + + fn start_point_distances( + &mut self, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + async move { + for i in self.provider.starting_points()? { + let vector = self.provider.aux_vectors.get_vector(i.into_usize())?; + f(i, self.computer.evaluate_similarity(vector)); + } + Ok(()) + } + } + + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let id_buffer = &mut self.id_buffer; + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), id_buffer)?; + + id_buffer.retain(|i| pred.eval_mut(i)); + + let len = id_buffer.len(); + let lookahead = self.provider.aux_vectors.prefetch_lookahead(); + + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.aux_vectors.prefetch_hint(id.into_usize()); + } + + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .aux_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } + + let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; + let distance = self.computer.evaluate_similarity(vector); + on_neighbors(distance, *id); + } + } + Ok(()) + }; + + std::future::ready(f()) + } } -impl BuildDistanceComputer for QuantAccessor<'_, NBITS, V, D, Ctx> +impl<'a, const NBITS: usize, V, D, Ctx> QuantAccessor<'a, NBITS, V, D, Ctx> where + Unsigned: Representation, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, - Unsigned: Representation, - DistanceComputer: for<'a, 'b> DistanceFunction, CVRef<'b, NBITS>, f32>, { - type DistanceComputerError = ANNError; - type DistanceComputer = DistanceComputer; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(self.provider.aux_vectors.distance_computer()?) + pub(crate) fn new( + provider: &'a DefaultProvider, D, Ctx>, + query: &[f32], + is_search: bool, + ) -> Self { + Self { + provider, + computer: provider + .aux_vectors + .query_computer(query, is_search) + .unwrap(), + id_buffer: AdjacencyList::with_capacity(32), + } } } @@ -603,7 +610,6 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - type QueryComputer = QueryComputer; type SearchAccessor<'a> = QuantAccessor<'a, NBITS, FullPrecisionStore, D, Ctx>; type SearchAccessorError = ANNError; @@ -611,8 +617,9 @@ where &'a self, provider: &'a FullPrecisionProvider, D, Ctx>, _context: &'a Ctx, + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider, true)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap(), true)) } } @@ -640,7 +647,7 @@ where Unsigned: Representation, QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { - type QueryComputer = QueryComputer; + // type QueryComputer = QueryComputer; type SearchAccessor<'a> = QuantAccessor<'a, NBITS, NoStore, D, Ctx>; type SearchAccessorError = ANNError; @@ -648,8 +655,9 @@ where &'a self, provider: &'a DefaultProvider, D, Ctx>, _context: &'a Ctx, + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider, true)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap(), true)) } } @@ -675,7 +683,7 @@ where DistanceComputer: for<'a, 'b> DistanceFunction, CVRef<'b, NBITS>, f32>, { type DistanceComputer<'a> = DistanceComputer; - type PruneAccessor<'a> = QuantAccessor<'a, NBITS, V, D, Ctx>; + type PruneAccessor<'a> = Builder<'a, NBITS>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = PassThrough; @@ -688,57 +696,11 @@ where provider: &'a DefaultProvider, D, Ctx>, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - Ok(QuantAccessor::new(provider, false)) + Ok(Builder::new(&provider.aux_vectors, provider.neighbors())) } } // Pass-through fill — returns `&Self` which directly accesses the underlying provider. -impl workingset::Fill - for QuantAccessor<'_, NBITS, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, - Unsigned: Representation, -{ - type Error = std::convert::Infallible; - type View<'a> - = &'a Self - where - Self: 'a; - - async fn fill<'a, Itr>( - &'a mut self, - _state: &'a mut PassThrough, - _itr: Itr, - ) -> Result, Self::Error> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - Ok(self) - } -} - -// Pass-through view — reads scalar-quantized vectors directly from the provider. -impl workingset::View for &QuantAccessor<'_, NBITS, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, - Unsigned: Representation, -{ - type ElementRef<'a> = CVRef<'a, NBITS>; - type Element<'a> - = CVRef<'a, NBITS> - where - Self: 'a; - - fn get(&self, id: u32) -> Option> { - self.provider.aux_vectors.get_vector(id.into_usize()).ok() - } -} - impl InsertStrategy, D, Ctx>, &[T]> for Quantized where diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 11f489aa4..02cb46e5a 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -9,18 +9,15 @@ use std::{future::Future, sync::Mutex}; use diskann::{ ANNError, ANNErrorKind, ANNResult, default_post_processor, - error::IntoANNResult, graph::{ + AdjacencyList, glue::{ - self, DefaultPostProcessor, ExpandBeam, FilterStartPoints, InsertStrategy, Pipeline, - PruneStrategy, SearchExt, SearchStrategy, + self, DefaultPostProcessor, Explore, FilterStartPoints, InsertStrategy, Pipeline, + PruneStrategy, SearchStrategy, }, workingset, }, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, - }, + provider::{BuildDistanceComputer, DelegateNeighbor, ExecutionContext, HasElementRef, HasId}, utils::{IntoUsize, VectorRepr}, }; use diskann_quantization::{ @@ -29,7 +26,7 @@ use diskann_quantization::{ spherical, }; use diskann_utils::future::AsyncFriendly; -use diskann_vector::distance::Metric; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; use thiserror::Error; use super::{GetFullPrecision, PassThrough, Rerank}; @@ -284,15 +281,85 @@ where } } +///////////// +// Builder // +///////////// + +#[derive(Clone)] +pub struct Builder<'a> { + store: &'a SphericalStore, + neighbors: &'a SimpleNeighborProviderAsync, +} + +impl HasId for Builder<'_> { + type Id = u32; +} + +impl HasElementRef for Builder<'_> { + type ElementRef<'a> = spherical::iface::Opaque<'a>; +} + +impl<'a> DelegateNeighbor<'a> for Builder<'_> { + type Delegate = &'a SimpleNeighborProviderAsync; + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + self.neighbors + } +} + +impl BuildDistanceComputer for Builder<'_> { + type DistanceComputerError = AllocatorError; + type DistanceComputer = + UnwrapErr; + + fn build_distance_computer( + &self, + ) -> Result { + self.store.distance_computer().map(UnwrapErr::new) + } +} + +// Pass-through fill — returns `&Self` which directly accesses the underlying provider. +impl workingset::Fill for Builder<'_> { + type Error = std::convert::Infallible; + type View<'a> + = Self + where + Self: 'a; + + async fn fill<'a, Itr>( + &'a mut self, + _state: &'a mut PassThrough, + _itr: Itr, + ) -> Result, Self::Error> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + Ok(self.clone()) + } +} + +// Pass-through view — reads spherical vectors directly from the provider. +impl workingset::View for Builder<'_> { + type ElementRef<'a> = spherical::iface::Opaque<'a>; + type Element<'a> + = spherical::iface::Opaque<'a> + where + Self: 'a; + + fn get(&self, id: u32) -> Option> { + self.store.get_vector(id.into_usize()).ok() + } +} + ////////////// // Accessor // ////////////// pub struct QuantAccessor<'a, V, D, Ctx> { provider: &'a DefaultProvider, - id_buffer: Vec, - layout: spherical::iface::QueryLayout, - is_search: bool, + pub(crate) computer: spherical::iface::QueryComputer, + id_buffer: AdjacencyList, } impl<'a, V, D, Ctx> QuantAccessor<'a, V, D, Ctx> @@ -303,14 +370,19 @@ where { pub(crate) fn new( provider: &'a DefaultProvider, + query: &[f32], layout: spherical::iface::QueryLayout, is_search: bool, ) -> Self { + let computer = provider + .aux_vectors + .query_computer(query, layout, is_search) + .unwrap(); + Self { provider, - id_buffer: Vec::with_capacity(32), - layout, - is_search, + computer, + id_buffer: AdjacencyList::with_capacity(32), } } } @@ -329,7 +401,7 @@ impl HasId for QuantAccessor<'_, V, D, Ctx> { type Id = u32; } -impl SearchExt for QuantAccessor<'_, V, D, Ctx> +impl Explore for QuantAccessor<'_, V, D, Ctx> where V: AsyncFriendly, D: AsyncFriendly, @@ -338,162 +410,71 @@ where fn starting_points(&self) -> impl Future>> { std::future::ready(self.provider.starting_points()) } -} -impl Accessor for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = spherical::iface::Opaque<'a> - where - Self: 'a; - - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = spherical::iface::Opaque<'a>; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = ANNError; - - /// Return the quantized vector stored at index `i`. - /// - /// This function always completes synchronously. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // SAFETY: We've decided to live with UB that can result from potentially mixing - // unsynchronized reads and writes on the underlying memory. - std::future::ready( - self.provider - .aux_vectors - .get_vector(id.into_usize()) - .into_ann_result(), - ) - } - - /// Perform a bulk operation. - /// - /// This implementation uses prefetching. - fn on_elements_unordered( + fn start_point_distances( &mut self, - itr: Itr, mut f: F, - ) -> impl Future> + Send + ) -> impl std::future::Future> + Send where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + F: FnMut(Self::Id, f32) + Send, { - // Reuse the internal buffer to collect the results and give us random access - // capabilities. - let id_buffer = &mut self.id_buffer; - id_buffer.clear(); - id_buffer.extend(itr); - - let len = id_buffer.len(); - let lookahead = self.provider.aux_vectors.prefetch_lookahead(); - - // Prefetch the first few vectors. - for id in id_buffer.iter().take(lookahead) { - self.provider.aux_vectors.prefetch_hint(id.into_usize()); - } - - for (i, id) in id_buffer.iter().enumerate() { - // Prefetch `lookahead` iterations ahead as long as it is safe. - if lookahead > 0 && i + lookahead < len { - self.provider - .aux_vectors - .prefetch_hint(id_buffer[i + lookahead].into_usize()); + async move { + for i in self.provider.starting_points()? { + // SAFETY: We're accepting the consequences of potential unsynchronized, + // concurrent mutation. + let distance = self.computer + .evaluate_similarity(self.provider.aux_vectors.get_vector(i.into_usize())?).unwrap(); + + f(i, distance); } - - let vector = match self.provider.aux_vectors.get_vector(id.into_usize()) { - Ok(v) => v, - Err(e) => return std::future::ready(Err(e.into())), - }; - - f(vector, *id) + Ok(()) } - - std::future::ready(Ok(())) - } -} - -impl<'a, V, D, Ctx> DelegateNeighbor<'a> for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type Delegate = &'a SimpleNeighborProviderAsync; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() } -} -impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputerError = Bridge; - type QueryComputer = - UnwrapErr; - - fn build_query_computer( - &self, - query: &[T], - ) -> Result { - self.provider - .aux_vectors - .query_computer(query, self.layout, self.is_search) - .bridge_err() - .map(UnwrapErr::new) - } -} + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let f = move || -> ANNResult<()> { + let mut id_buffer = &mut self.id_buffer; + for n in ids { + self.provider + .neighbor_provider + .get_neighbors_sync(n.into_usize(), &mut id_buffer)?; -impl ExpandBeam<&[T]> for QuantAccessor<'_, V, D, Ctx> -where - T: VectorRepr, - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ -} + id_buffer.retain(|i| pred.eval_mut(i)); + let len = id_buffer.len(); + let lookahead = self.provider.aux_vectors.prefetch_lookahead(); -#[derive(Debug, Error)] -#[error("unconstructible")] -pub enum Infallible {} + // Prefetch the first few vectors. + for id in id_buffer.iter().take(lookahead) { + self.provider.aux_vectors.prefetch_hint(id.into_usize()); + } -impl From for ANNError { - fn from(_: Infallible) -> Self { - unreachable!("Infallible is an unconstructible enum") - } -} + for (i, id) in id_buffer.iter().enumerate() { + // Prefetch `lookahead` iterations ahead as long as it is safe. + if lookahead > 0 && i + lookahead < len { + self.provider + .aux_vectors + .prefetch_hint(id_buffer[i + lookahead].into_usize()); + } -impl BuildDistanceComputer for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type DistanceComputerError = AllocatorError; - type DistanceComputer = - UnwrapErr; + let vector = self.provider.aux_vectors.get_vector(id.into_usize())?; + let distance = self.computer.evaluate_similarity(vector).unwrap(); + on_neighbors(distance, *id); + } + } + Ok(()) + }; - fn build_distance_computer( - &self, - ) -> Result { - self.provider - .aux_vectors - .distance_computer() - .map(UnwrapErr::new) + std::future::ready(f()) } } @@ -551,8 +532,8 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = - UnwrapErr; + // type QueryComputer = + // UnwrapErr; type SearchAccessor<'a> = QuantAccessor<'a, FullPrecisionStore, D, Ctx>; type SearchAccessorError = ANNError; @@ -560,8 +541,9 @@ where &'a self, provider: &'a FullPrecisionProvider, _context: &'a Ctx, + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider, self.layout, self.is_search)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap(), self.layout, self.is_search)) } } @@ -584,8 +566,8 @@ where D: AsyncFriendly + DeletionCheck, Ctx: ExecutionContext, { - type QueryComputer = - UnwrapErr; + // type QueryComputer = + // UnwrapErr; type SearchAccessor<'a> = QuantAccessor<'a, NoStore, D, Ctx>; type SearchAccessorError = ANNError; @@ -593,8 +575,9 @@ where &'a self, provider: &'a DefaultProvider, _context: &'a Ctx, + query: &[T], ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider, self.layout, self.is_search)) + Ok(QuantAccessor::new(provider, &*T::as_f32(query).unwrap(), self.layout, self.is_search)) } } @@ -616,7 +599,7 @@ where { type DistanceComputer<'a> = UnwrapErr; - type PruneAccessor<'a> = QuantAccessor<'a, V, D, Ctx>; + type PruneAccessor<'a> = Builder<'a>; type PruneAccessorError = diskann::error::Infallible; type WorkingSet = PassThrough; @@ -629,52 +612,12 @@ where provider: &'a DefaultProvider, _context: &'a Ctx, ) -> Result, Self::PruneAccessorError> { - let build = Self::build(); - Ok(QuantAccessor::new(provider, build.layout, build.is_search)) - } -} + let builder = Builder { + store: &provider.aux_vectors, + neighbors: provider.neighbors(), + }; -// Pass-through fill — returns `&Self` which directly accesses the underlying provider. -impl workingset::Fill for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type Error = std::convert::Infallible; - type View<'a> - = &'a Self - where - Self: 'a; - - async fn fill<'a, Itr>( - &'a mut self, - _state: &'a mut PassThrough, - _itr: Itr, - ) -> Result, Self::Error> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - Ok(self) - } -} - -// Pass-through view — reads spherical vectors directly from the provider. -impl workingset::View for &QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly + DeletionCheck, - Ctx: ExecutionContext, -{ - type ElementRef<'a> = spherical::iface::Opaque<'a>; - type Element<'a> - = spherical::iface::Opaque<'a> - where - Self: 'a; - - fn get(&self, id: u32) -> Option> { - self.provider.aux_vectors.get_vector(id.into_usize()).ok() + Ok(builder) } } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 2493a5ce2..79aaf307b 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -1,327 +1,327 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{ - future::Future, - sync::{Arc, Mutex}, -}; - -use diskann::{ - ANNError, ANNResult, default_post_processor, - error::{RankedError, ToRanked, TransientError}, - graph::{ - glue::{ - CopyIds, DefaultPostProcessor, ExpandBeam, InsertStrategy, MultiInsertStrategy, - PruneStrategy, SearchExt, SearchStrategy, - }, - workingset::map, - }, - neighbor::Neighbor, - provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - HasId, - }, - utils::IntoUsize, -}; -use diskann_utils::views::Matrix; - -use super::{DefaultProvider, DefaultQuant}; -use crate::model::graph::provider::async_::{ - SimpleNeighborProviderAsync, TableDeleteProviderAsync, inmem::FullPrecisionStore, -}; - -/// A full-precision accessor that spuriously fails for non-start points with a controllable -/// frequency. -/// -/// This is meant to test the non-critical error handling of index operations. -#[derive(Debug, Clone, Copy)] -pub struct Flaky { - fail_every: usize, -} - -impl Flaky { - pub(crate) fn new(fail_every: usize) -> Self { - Self { fail_every } - } -} - -#[derive(Debug)] -pub struct TestError { - is_transient: bool, - handled: bool, -} - -impl TestError { - fn transient() -> Self { - Self { - is_transient: true, - handled: false, - } - } -} - -impl std::fmt::Display for TestError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -impl TransientError for TestError { - fn acknowledge(mut self, _why: D) - where - D: std::fmt::Display, - { - self.handled = true; - } - - fn escalate(mut self, _why: D) -> Self - where - D: std::fmt::Display, - { - assert!(self.is_transient); - self.handled = true; - self.is_transient = false; - self - } -} - -impl Drop for TestError { - fn drop(&mut self) { - if self.is_transient { - assert!(self.handled, "dropping an unhandled transient error!"); - } - } -} - -impl From for ANNError { - fn from(value: TestError) -> Self { - assert!( - !value.is_transient, - "transient errors should not be converted!" - ); - ANNError::log_async_error(value) - } -} - -impl ToRanked for TestError { - type Transient = Self; - type Error = Self; - - fn to_ranked(self) -> RankedError { - if self.is_transient { - RankedError::Transient(self) - } else { - RankedError::Error(self) - } - } - - fn from_transient(transient: Self) -> Self { - assert!(transient.is_transient); - transient - } - - fn from_error(error: Self) -> Self { - assert!(!error.is_transient); - error - } -} - -type Tda = TableDeleteProviderAsync; -type TestProvider = DefaultProvider, DefaultQuant, Tda>; - -pub struct FlakyAccessor<'a> { - provider: &'a TestProvider, - fail_every: usize, - get_count: usize, -} - -type FullAccessor<'a> = super::FullAccessor<'a, f32, DefaultQuant, Tda, DefaultContext>; - -impl SearchExt for FlakyAccessor<'_> { - fn starting_points(&self) -> impl Future>> { - std::future::ready(self.provider.starting_points()) - } -} - -impl<'a> FlakyAccessor<'a> { - fn new(provider: &'a TestProvider, fail_every: usize, get_count: usize) -> Self { - assert_ne!(get_count, 0); - Self { - provider, - get_count, - fail_every, - } - } - - fn as_full(&self) -> FullAccessor<'a> { - FullAccessor::new(self.provider) - } -} - -impl HasId for FlakyAccessor<'_> { - type Id = u32; -} - -impl Accessor for FlakyAccessor<'_> { - /// This accessor returns raw slices. There *is* a chance of racing when the fast - /// providers are used. We just have to live with it. - type Element<'a> - = &'a [f32] - where - Self: 'a; - - type ElementRef<'a> = &'a [f32]; - - /// Choose to panic on an out-of-bounds access rather than propagate an error. - type GetError = TestError; - - /// Return the full-precision vector stored at index `i`. - /// - /// This function always completes synchronously. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // Do not fail when retrieving start points. - // - // NOTE: `is_not_start_point` takes a neighbor, but only looks at the `ID` portion, - // so we can use a dummy neighbor. - if self.provider.is_not_start_point()(&Neighbor::new(id, 0.0)) { - self.get_count -= 1; - if self.get_count == 0 { - self.get_count = self.fail_every; - return std::future::ready(Err(TestError::transient())); - } - } - - // SAFETY: We've decided to live with UB (undefined behavior) that can result from - // potentially mixing unsynchronized reads and writes on the underlying memory. - std::future::ready(Ok(unsafe { - self.provider.base_vectors.get_vector_sync(id.into_usize()) - })) - } -} - -impl<'a> BuildDistanceComputer for FlakyAccessor<'a> { - type DistanceComputerError = as BuildDistanceComputer>::DistanceComputerError; - type DistanceComputer = as BuildDistanceComputer>::DistanceComputer; - - fn build_distance_computer( - &self, - ) -> Result { - self.as_full().build_distance_computer() - } -} - -impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { - type QueryComputerError = - as BuildQueryComputer<&'a [f32]>>::QueryComputerError; - type QueryComputer = as BuildQueryComputer<&'a [f32]>>::QueryComputer; - - fn build_query_computer( - &self, - from: &'a [f32], - ) -> Result { - self.as_full().build_query_computer(from) - } -} - -impl ExpandBeam<&[f32]> for FlakyAccessor<'_> {} - -impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { - type Delegate = &'a SimpleNeighborProviderAsync; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.provider.neighbors() - } -} - -impl<'x> SearchStrategy for Flaky { - type QueryComputer = as BuildQueryComputer<&'x [f32]>>::QueryComputer; - type SearchAccessor<'a> = FlakyAccessor<'a>; - type SearchAccessorError = ANNError; - - fn search_accessor<'a>( - &'a self, - provider: &'a TestProvider, - _context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - Ok(FlakyAccessor::new( - provider, - self.fail_every, - self.fail_every, - )) - } -} - -impl DefaultPostProcessor for Flaky { - default_post_processor!(CopyIds); -} - -const STATIC_PRUNE_THRESHOLD: usize = 5; -/// We need to tune the flakiness of the `Prune` accessor so that occasionally, the first -/// item retrieved is a failure. -static START_COUNT: Mutex = Mutex::new(STATIC_PRUNE_THRESHOLD); - -type WorkingSet = map::Map, map::Ref<[f32]>>; - -impl PruneStrategy for Flaky { - type DistanceComputer<'a> = as BuildDistanceComputer>::DistanceComputer; - type PruneAccessor<'a> = FlakyAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; - type WorkingSet = WorkingSet; - - fn prune_accessor<'a>( - &'a self, - provider: &'a TestProvider, - _context: &'a DefaultContext, - ) -> Result, Self::PruneAccessorError> { - let mut guard = START_COUNT.lock().unwrap(); - let start = *guard; - *guard -= 1; - if *guard == 0 { - *guard = STATIC_PRUNE_THRESHOLD; - } - - Ok(FlakyAccessor::new(provider, STATIC_PRUNE_THRESHOLD, start)) - } - - fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { - map::Builder::new(map::Capacity::Default).build(capacity) - } -} - -impl InsertStrategy for Flaky { - type PruneStrategy = Self; - - fn prune_strategy(&self) -> Self { - *self - } -} - -impl MultiInsertStrategy> for Flaky { - type Seed = map::Builder>; - type WorkingSet = WorkingSet; - type FinishError = diskann::error::Infallible; - type InsertStrategy = Self; - - fn insert_strategy(&self) -> Self::InsertStrategy { - *self - } - - fn finish( - &self, - _provider: &TestProvider, - _ctx: &DefaultContext, - batch: &Arc>, - ids: Itr, - ) -> impl std::future::Future> + Send - where - Itr: ExactSizeIterator + Send, - { - std::future::ready(Ok(map::Builder::new(map::Capacity::Default) - .with_overlay(map::Overlay::from_batch(batch.clone(), ids)))) - } -} +// /* +// * Copyright (c) Microsoft Corporation. +// * Licensed under the MIT license. +// */ +// +// use std::{ +// future::Future, +// sync::{Arc, Mutex}, +// }; +// +// use diskann::{ +// ANNError, ANNResult, default_post_processor, +// error::{RankedError, ToRanked, TransientError}, +// graph::{ +// glue::{ +// CopyIds, DefaultPostProcessor, ExpandBeam, InsertStrategy, MultiInsertStrategy, +// PruneStrategy, Explore, SearchStrategy, +// }, +// workingset::map, +// }, +// neighbor::Neighbor, +// provider::{ +// Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, +// HasId, +// }, +// utils::IntoUsize, +// }; +// use diskann_utils::views::Matrix; +// +// use super::{DefaultProvider, DefaultQuant}; +// use crate::model::graph::provider::async_::{ +// SimpleNeighborProviderAsync, TableDeleteProviderAsync, inmem::FullPrecisionStore, +// }; +// +// /// A full-precision accessor that spuriously fails for non-start points with a controllable +// /// frequency. +// /// +// /// This is meant to test the non-critical error handling of index operations. +// #[derive(Debug, Clone, Copy)] +// pub struct Flaky { +// fail_every: usize, +// } +// +// impl Flaky { +// pub(crate) fn new(fail_every: usize) -> Self { +// Self { fail_every } +// } +// } +// +// #[derive(Debug)] +// pub struct TestError { +// is_transient: bool, +// handled: bool, +// } +// +// impl TestError { +// fn transient() -> Self { +// Self { +// is_transient: true, +// handled: false, +// } +// } +// } +// +// impl std::fmt::Display for TestError { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// write!(f, "{:?}", self) +// } +// } +// +// impl TransientError for TestError { +// fn acknowledge(mut self, _why: D) +// where +// D: std::fmt::Display, +// { +// self.handled = true; +// } +// +// fn escalate(mut self, _why: D) -> Self +// where +// D: std::fmt::Display, +// { +// assert!(self.is_transient); +// self.handled = true; +// self.is_transient = false; +// self +// } +// } +// +// impl Drop for TestError { +// fn drop(&mut self) { +// if self.is_transient { +// assert!(self.handled, "dropping an unhandled transient error!"); +// } +// } +// } +// +// impl From for ANNError { +// fn from(value: TestError) -> Self { +// assert!( +// !value.is_transient, +// "transient errors should not be converted!" +// ); +// ANNError::log_async_error(value) +// } +// } +// +// impl ToRanked for TestError { +// type Transient = Self; +// type Error = Self; +// +// fn to_ranked(self) -> RankedError { +// if self.is_transient { +// RankedError::Transient(self) +// } else { +// RankedError::Error(self) +// } +// } +// +// fn from_transient(transient: Self) -> Self { +// assert!(transient.is_transient); +// transient +// } +// +// fn from_error(error: Self) -> Self { +// assert!(!error.is_transient); +// error +// } +// } +// +// type Tda = TableDeleteProviderAsync; +// type TestProvider = DefaultProvider, DefaultQuant, Tda>; +// +// pub struct FlakyAccessor<'a> { +// provider: &'a TestProvider, +// fail_every: usize, +// get_count: usize, +// } +// +// type FullAccessor<'a> = super::FullAccessor<'a, f32, DefaultQuant, Tda, DefaultContext>; +// +// impl Explore for FlakyAccessor<'_> { +// fn starting_points(&self) -> impl Future>> { +// std::future::ready(self.provider.starting_points()) +// } +// } +// +// impl<'a> FlakyAccessor<'a> { +// fn new(provider: &'a TestProvider, fail_every: usize, get_count: usize) -> Self { +// assert_ne!(get_count, 0); +// Self { +// provider, +// get_count, +// fail_every, +// } +// } +// +// fn as_full(&self) -> FullAccessor<'a> { +// FullAccessor::new(self.provider) +// } +// } +// +// impl HasId for FlakyAccessor<'_> { +// type Id = u32; +// } +// +// impl Accessor for FlakyAccessor<'_> { +// // /// This accessor returns raw slices. There *is* a chance of racing when the fast +// // /// providers are used. We just have to live with it. +// // type Element<'a> +// // = &'a [f32] +// // where +// // Self: 'a; +// +// type ElementRef<'a> = &'a [f32]; +// +// // /// Choose to panic on an out-of-bounds access rather than propagate an error. +// // type GetError = TestError; +// +// // /// Return the full-precision vector stored at index `i`. +// // /// +// // /// This function always completes synchronously. +// // #[inline(always)] +// // fn get_element( +// // &mut self, +// // id: Self::Id, +// // ) -> impl Future, Self::GetError>> + Send { +// // // Do not fail when retrieving start points. +// // // +// // // NOTE: `is_not_start_point` takes a neighbor, but only looks at the `ID` portion, +// // // so we can use a dummy neighbor. +// // if self.provider.is_not_start_point()(&Neighbor::new(id, 0.0)) { +// // self.get_count -= 1; +// // if self.get_count == 0 { +// // self.get_count = self.fail_every; +// // return std::future::ready(Err(TestError::transient())); +// // } +// // } +// +// // // SAFETY: We've decided to live with UB (undefined behavior) that can result from +// // // potentially mixing unsynchronized reads and writes on the underlying memory. +// // std::future::ready(Ok(unsafe { +// // self.provider.base_vectors.get_vector_sync(id.into_usize()) +// // })) +// // } +// } +// +// impl<'a> BuildDistanceComputer for FlakyAccessor<'a> { +// type DistanceComputerError = as BuildDistanceComputer>::DistanceComputerError; +// type DistanceComputer = as BuildDistanceComputer>::DistanceComputer; +// +// fn build_distance_computer( +// &self, +// ) -> Result { +// self.as_full().build_distance_computer() +// } +// } +// +// impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { +// type QueryComputerError = +// as BuildQueryComputer<&'a [f32]>>::QueryComputerError; +// type QueryComputer = as BuildQueryComputer<&'a [f32]>>::QueryComputer; +// +// fn build_query_computer( +// &self, +// from: &'a [f32], +// ) -> Result { +// self.as_full().build_query_computer(from) +// } +// } +// +// impl ExpandBeam<&[f32]> for FlakyAccessor<'_> {} +// +// impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { +// type Delegate = &'a SimpleNeighborProviderAsync; +// fn delegate_neighbor(&'a mut self) -> Self::Delegate { +// self.provider.neighbors() +// } +// } +// +// impl<'x> SearchStrategy for Flaky { +// type QueryComputer = as BuildQueryComputer<&'x [f32]>>::QueryComputer; +// type SearchAccessor<'a> = FlakyAccessor<'a>; +// type SearchAccessorError = ANNError; +// +// fn search_accessor<'a>( +// &'a self, +// provider: &'a TestProvider, +// _context: &'a DefaultContext, +// ) -> Result, Self::SearchAccessorError> { +// Ok(FlakyAccessor::new( +// provider, +// self.fail_every, +// self.fail_every, +// )) +// } +// } +// +// impl DefaultPostProcessor for Flaky { +// default_post_processor!(CopyIds); +// } +// +// const STATIC_PRUNE_THRESHOLD: usize = 5; +// /// We need to tune the flakiness of the `Prune` accessor so that occasionally, the first +// /// item retrieved is a failure. +// static START_COUNT: Mutex = Mutex::new(STATIC_PRUNE_THRESHOLD); +// +// type WorkingSet = map::Map, map::Ref<[f32]>>; +// +// impl PruneStrategy for Flaky { +// type DistanceComputer<'a> = as BuildDistanceComputer>::DistanceComputer; +// type PruneAccessor<'a> = FlakyAccessor<'a>; +// type PruneAccessorError = diskann::error::Infallible; +// type WorkingSet = WorkingSet; +// +// fn prune_accessor<'a>( +// &'a self, +// provider: &'a TestProvider, +// _context: &'a DefaultContext, +// ) -> Result, Self::PruneAccessorError> { +// let mut guard = START_COUNT.lock().unwrap(); +// let start = *guard; +// *guard -= 1; +// if *guard == 0 { +// *guard = STATIC_PRUNE_THRESHOLD; +// } +// +// Ok(FlakyAccessor::new(provider, STATIC_PRUNE_THRESHOLD, start)) +// } +// +// fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { +// map::Builder::new(map::Capacity::Default).build(capacity) +// } +// } +// +// impl InsertStrategy for Flaky { +// type PruneStrategy = Self; +// +// fn prune_strategy(&self) -> Self { +// *self +// } +// } +// +// impl MultiInsertStrategy> for Flaky { +// type Seed = map::Builder>; +// type WorkingSet = WorkingSet; +// type FinishError = diskann::error::Infallible; +// type InsertStrategy = Self; +// +// fn insert_strategy(&self) -> Self::InsertStrategy { +// *self +// } +// +// fn finish( +// &self, +// _provider: &TestProvider, +// _ctx: &DefaultContext, +// batch: &Arc>, +// ids: Itr, +// ) -> impl std::future::Future> + Send +// where +// Itr: ExactSizeIterator + Send, +// { +// std::future::ready(Ok(map::Builder::new(map::Capacity::Default) +// .with_overlay(map::Overlay::from_batch(batch.clone(), ids)))) +// } +// } diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index 3e1849bb0..94dce287e 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -8,7 +8,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, - provider::BuildQueryComputer, + provider::HasId, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to @@ -39,7 +39,7 @@ pub struct RemoveDeletedIdsAndCopy; impl glue::SearchPostProcess for RemoveDeletedIdsAndCopy where - A: BuildQueryComputer + AsDeletionCheck, + A: HasId + AsDeletionCheck, { type Error = std::convert::Infallible; @@ -47,7 +47,6 @@ where &self, accessor: &mut A, _query: T, - _computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index f0ffc0451..f9c113aef 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -20,16 +20,13 @@ use diskann::{ error::StandardError, graph::{ SearchOutputBuffer, - glue::{self, ExpandBeam, SearchExt, SearchPostProcessStep, SearchStrategy}, + glue::{self, Explore, SearchPostProcessStep, SearchStrategy}, index::QueryLabelProvider, }, neighbor::Neighbor, - provider::{Accessor, AsNeighbor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + provider::{DataProvider, HasId}, utils::VectorId, }; -use diskann_utils::Reborrow; -use diskann_vector::PreprocessedDistanceFunction; -use futures_util::FutureExt; /// A [`SearchStrategy`] type that composes the inner distance computer with beta filtering. /// @@ -70,7 +67,7 @@ pub struct Unwrap; /// Delegate post-processing to the inner strategy's post-processing routine. impl SearchPostProcessStep, T, O> for Unwrap where - A: BuildQueryComputer, + A: HasId, { type Error = NextError @@ -84,7 +81,6 @@ where next: &Next, accessor: &mut BetaAccessor, query: T, - computer: &BetaComputer, candidates: I, output: &mut B, ) -> impl Future> + Send @@ -93,13 +89,7 @@ where B: SearchOutputBuffer + Send + ?Sized, Next: glue::SearchPostProcess, { - next.post_process( - &mut accessor.inner, - query, - computer.inner(), - candidates, - output, - ) + next.post_process(&mut accessor.inner, query, candidates, output) } } @@ -121,9 +111,9 @@ where /// accessor. type SearchAccessor<'a> = BetaAccessor>; - /// A [`PreprocessedDistanceFunction`] that combines applies the beta filtering factor - /// if the vector ID portion of `Element` satisfies the filter predicate. - type QueryComputer = BetaComputer; + // /// A [`PreprocessedDistanceFunction`] that combines applies the beta filtering factor + // /// if the vector ID portion of `Element` satisfies the filter predicate. + // type QueryComputer = Strategy::QueryComputer; type SearchAccessorError = Strategy::SearchAccessorError; @@ -132,11 +122,14 @@ where &'a self, provider: &'a Provider, context: &'a Provider::Context, + query: T, ) -> Result, Self::SearchAccessorError> { Ok(BetaAccessor { - inner: self.strategy.search_accessor(provider, context)?, - labels: self.labels.clone(), - beta: self.beta, + inner: self.strategy.search_accessor(provider, context, query)?, + filter: Filter { + labels: self.labels.clone(), + beta: self.beta, + }, }) } } @@ -158,191 +151,77 @@ where } } -///////////// -// Helpers // -///////////// - -/// The `Element` and `ElementRef` types used by the [`BetaAccessor`]. -#[derive(Debug, Clone, PartialEq)] -pub struct Pair { - id: I, - element: E, -} - -impl Pair { - fn new(id: I, element: E) -> Self { - Self { id, element } - } -} - -/// `Reborrow` is implemented in terms of a full `Reborrow` of `E` while leaving the id -/// untouched. -impl<'a, I, E> Reborrow<'a> for Pair -where - E: Reborrow<'a>, - I: Copy, -{ - type Target = Pair; - fn reborrow(&'a self) -> Self::Target { - Pair { - id: self.id, - element: self.element.reborrow(), - } - } -} - /// An [`Accessor`] that composes with an `Inner` accessor to provide beta-filtering. pub struct BetaAccessor where Inner: HasId, { inner: Inner, - labels: Arc>, - beta: f32, + filter: Filter, } -impl SearchExt for BetaAccessor -where - Inner: SearchExt, -{ - fn starting_points(&self) -> impl Future>> + Send { - self.inner.starting_points() - } +struct Filter { + labels: Arc>, + beta: f32, } -impl<'a, Inner> DelegateNeighbor<'a> for BetaAccessor +impl Filter where - Inner: DelegateNeighbor<'a>, + I: VectorId, { - type Delegate = Inner::Delegate; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - self.inner.delegate_neighbor() + fn apply(&self, id: I, distance: f32) -> f32 { + if self.labels.is_match(id) { + distance * self.beta + } else { + distance + } } } -impl HasId for BetaAccessor -where - Inner: HasId, -{ - type Id = Inner::Id; -} - -impl Accessor for BetaAccessor +impl Explore for BetaAccessor where - Inner: Accessor, + Inner: Explore, { - /// Modify `Element` to retain the vector ID. - type Element<'a> - = Pair> - where - Self: 'a; - type ElementRef<'a> = Pair>; - - /// Use the same error type as `Inner`. - type GetError = Inner::GetError; - - /// Invoke `get_element` on the inner accessor and return a tuple consisting of the - /// retrieved element and `id`. - #[inline(always)] - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - // The first `map` applies to `Future`. - // The second `map` applies to the `Result`. - self.inner - .get_element(id) - .map(move |result| result.map(move |v| Pair::new(id, v))) + fn starting_points(&self) -> impl Future>> + Send { + self.inner.starting_points() } - /// Method `on_elements_unordered` is implemented by invoking - /// `inner.on_elements_unordered` with a decorated version of `f`. - async fn on_elements_unordered( + fn start_point_distances( &mut self, - itr: Itr, mut f: F, - ) -> Result<(), Self::GetError> + ) -> impl std::future::Future> + Send where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), + F: FnMut(Self::Id, f32) + Send, { - self.inner - .on_elements_unordered( - itr, - #[inline] - move |element, id| f(Pair::new(id, element), id), - ) - .await - } -} - -impl BuildQueryComputer for BetaAccessor -where - Inner: BuildQueryComputer, -{ - /// Use a [`BetaComputer`] to apply filtering. - type QueryComputer = BetaComputer; - /// Use the same error as `Inner`. - type QueryComputerError = Inner::QueryComputerError; - - fn build_query_computer( - &self, - from: T, - ) -> Result { - self.inner - .build_query_computer(from) - .map(|computer| BetaComputer::new(computer, self.labels.clone(), self.beta)) - } -} - -impl ExpandBeam for BetaAccessor where Inner: BuildQueryComputer + AsNeighbor {} - -/// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. -pub struct BetaComputer { - inner: Inner, - labels: Arc>, - beta: f32, -} - -impl BetaComputer -where - I: VectorId, -{ - /// Construct a new `BetaComputer` around `Inner`. - pub fn new(inner: Inner, labels: Arc>, beta: f32) -> Self { - Self { - inner, - labels, - beta, - } + let filter = &self.filter; + self.inner.start_point_distances(move |id, distance| { + f(id, filter.apply(id, distance)); + }) } - /// Return a reference to the inner computer. - pub fn inner(&self) -> &Inner { - &self.inner + fn expand_beam( + &mut self, + ids: Itr, + pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + let filter = &self.filter; + self.inner.expand_beam(ids, pred, move |distance, id| { + on_neighbors(filter.apply(id, distance), id) + }) } } -impl PreprocessedDistanceFunction, f32> for BetaComputer +impl HasId for BetaAccessor where - I: VectorId, - Inner: PreprocessedDistanceFunction, + Inner: HasId, { - /// Check whether the ID satisfied the predicate computed by the label provider. - /// - /// If so, multiply the distance computed by `Inner` by `beta`. - #[inline(always)] - fn evaluate_similarity(&self, x: Pair) -> f32 { - // Inner distance computation. - let distance = self.inner.evaluate_similarity(x.element); - // Check beta. - if self.labels.is_match(x.id) { - distance * self.beta - } else { - distance - } - } + type Id = Inner::Id; } /////////// @@ -351,264 +230,120 @@ where #[cfg(test)] mod tests { - use diskann::{ - ANNError, ANNResult, always_escalate, - graph::AdjacencyList, - graph::glue::CopyIds, - provider::{DefaultContext, NeighborAccessor, NoopGuard}, - }; - use futures_util::future; - use thiserror::Error; - use super::*; - /// A very simple data provider. - struct SimpleProvider; - impl DataProvider for SimpleProvider { - type Context = DefaultContext; - type InternalId = u32; - type ExternalId = u64; - type Guard = NoopGuard; - - type Error = ANNError; - - fn to_internal_id(&self, _context: &DefaultContext, gid: &u64) -> ANNResult { - Ok((*gid).try_into()?) - } - - fn to_external_id(&self, _context: &DefaultContext, id: u32) -> ANNResult { - Ok(id.into()) - } - } - - /// An `Accessor` that doubles its input ID as its output element. - /// - /// This also tracks the number of calls made to `get_element` and - /// `on_elements_unordered` to ensure that `BetaFilter` correctly forwards these methods. - #[derive(Debug, Default, Clone, Copy)] - struct Doubler { - get_element: usize, - on_elements_unordered: usize, - } - - impl SearchExt for Doubler { - async fn starting_points(&self) -> ANNResult> { - Ok(vec![0]) - } - } - - impl Doubler { - fn reset(&mut self) { - *self = Self::default(); - } - } - - /// A simple error type to test error forwarding. - #[derive(Debug, Error)] - #[error("the value {0} is not allowed")] - pub struct NotAllowed(u32); - - impl From for ANNError { - #[inline(never)] - fn from(value: NotAllowed) -> Self { - ANNError::log_async_error(value) - } - } - - impl HasId for Doubler { - type Id = u32; - } - - impl NeighborAccessor for Doubler { - fn get_neighbors( - self, - _id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - neighbors.clear(); - future::ok(self) - } - } - - always_escalate!(NotAllowed); - - impl Accessor for Doubler { - type Element<'a> - = u64 - where - Self: 'a; - type ElementRef<'a> = u64; - - type GetError = NotAllowed; - - fn get_element( - &mut self, - id: u32, - ) -> impl std::future::Future, Self::GetError>> + Send - { - self.get_element += 1; - let is_err = (100..200).contains(&id); - - async move { - if is_err { - Err(NotAllowed(id)) - } else { - let id: u64 = id.into(); - Ok(2 * id) - } - } - } + use diskann::graph::{ + glue::{HybridPredicate, Predicate, PredicateMut}, + test::{provider as test_provider, synthetic::Grid}, + }; + use std::collections::HashSet; - async fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> Result<(), Self::GetError> - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), - { - self.on_elements_unordered += 1; - for i in itr { - f(self.get_element(i).await?, i); - } - Ok(()) - } - } + /// A simple `QueryLabelProvider` that matches multiples of 3. + #[derive(Debug)] + struct ThreeFilter; - struct AddingComputer(u64); - impl PreprocessedDistanceFunction for AddingComputer { - fn evaluate_similarity(&self, x: u64) -> f32 { - (self.0 + x) as f32 + impl QueryLabelProvider for ThreeFilter { + fn is_match(&self, id: u32) -> bool { + id.is_multiple_of(3) } } - impl BuildQueryComputer for Doubler { - type QueryComputer = AddingComputer; - type QueryComputerError = ANNError; + struct NotIn<'a>(&'a mut HashSet); - fn build_query_computer( - &self, - from: u64, - ) -> Result { - Ok(AddingComputer(from)) + impl Predicate for NotIn<'_> { + fn eval(&self, item: &u32) -> bool { + !self.0.contains(item) } } - impl ExpandBeam for Doubler {} - - #[derive(Debug)] - struct SimpleStrategy; - - impl SearchStrategy for SimpleStrategy { - type SearchAccessor<'a> = Doubler; - type QueryComputer = AddingComputer; - type SearchAccessorError = ANNError; - - fn search_accessor<'a>( - &'a self, - _provider: &'a SimpleProvider, - _context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - Ok(Doubler::default()) + impl PredicateMut for NotIn<'_> { + fn eval_mut(&mut self, item: &u32) -> bool { + self.0.insert(*item) } } - impl glue::DefaultPostProcessor for SimpleStrategy { - diskann::default_post_processor!(CopyIds); - } - - /// A simple `QueryLabelProvider` that matches multiples of 3. - #[derive(Debug)] - struct ThreeFilter; - - impl QueryLabelProvider for ThreeFilter { - fn is_match(&self, id: u32) -> bool { - id.is_multiple_of(3) - } - } + impl HybridPredicate for NotIn<'_> {} #[tokio::test] async fn test_beta_filter() { - let provider = SimpleProvider; - let context = &DefaultContext; + // The grid of 4x4 will look like this: + // + // | 16 + // | 3 7 11 15 + // | 2 6 10 14 + // | 1 5 9 13 + // | 0 4 8 12 + // +--------------- + // + let provider = test_provider::Provider::grid(Grid::Two, 4).unwrap(); + let context = test_provider::Context::new(); + let beta: f32 = 0.25; - let strategy = BetaFilter::new(SimpleStrategy, Arc::new(ThreeFilter), beta); - - let mut accessor: BetaAccessor<_> = strategy.search_accessor(&provider, context).unwrap(); - assert_eq!(accessor.inner.get_element, 0); - assert_eq!(accessor.inner.on_elements_unordered, 0); - - // Test non-erroring path. - let v = accessor.get_element(1).await.unwrap(); - assert_eq!(v, Pair::new(1, 2)); - - let v = accessor.get_element(2).await.unwrap(); - assert_eq!(v, Pair::new(2, 4)); - - // Test erroring path. - assert!(accessor.get_element(100).await.is_err()); - assert!(accessor.get_element(101).await.is_err()); - - assert_eq!(accessor.inner.get_element, 4); - assert_eq!(accessor.inner.on_elements_unordered, 0); - accessor.inner.reset(); - - // On elements unordered. - { - let mut v = Vec::new(); - accessor - .on_elements_unordered([1, 2, 3, 4, 5].into_iter(), |element, id| { - v.push((element, id)); - }) - .await - .unwrap(); - - assert_eq!(accessor.inner.get_element, 5); - assert_eq!(accessor.inner.on_elements_unordered, 1); - assert_eq!( - v, - &[ - (Pair::new(1, 2), 1), - (Pair::new(2, 4), 2), - (Pair::new(3, 6), 3), - (Pair::new(4, 8), 4), - (Pair::new(5, 10), 5) - ] - ); - accessor.inner.reset(); - } + let strategy = BetaFilter::new(test_provider::Strategy::new(), Arc::new(ThreeFilter), beta); - // On-elements-unordered propagates errors. - assert!( - accessor - .on_elements_unordered([1, 2, 3, 100, 4].into_iter(), |_, _| {}) - .await - .is_err() + let start_point_ids: Vec<_> = provider.start_point_ids().collect(); + assert_eq!( + start_point_ids.len(), + 1, + "grid should only have a single start point" + ); + let start_point = start_point_ids[0]; + assert_eq!( + start_point, + u32::MAX, + "`Provider::grid` is documented to use `u32::MAX` as its start point", ); - // Computation. - let query = 10; - let computer = accessor.build_query_computer(query).unwrap(); + let mut visited = HashSet::new(); + let mut buf = Vec::new(); + + let mut accessor = strategy + .search_accessor(&provider, &context, &[0.0, 0.0]) + .unwrap(); assert_eq!( - computer.evaluate_similarity(accessor.get_element(10).await.unwrap()), - (10 * 2 + query) as f32 + &*accessor.starting_points().await.unwrap(), + &*start_point_ids, + "the underlying start points should match", ); + + accessor + .expand_beam( + [0, 5, 10, 15].into_iter(), + NotIn(&mut visited), + |distance, id| buf.push((distance, id)), + ) + .await + .unwrap(); + + // The expansion order is unknown, but we know from the structure of the grid what + // the result should be. + // + // Since each entry in the beam is connected + buf.sort_by_key(|(_, id)| *id); assert_eq!( - computer.evaluate_similarity(accessor.get_element(11).await.unwrap()), - (11 * 2 + query) as f32 + &*buf, + [ + (1.0, 1), + (1.0, 4), + (5.0 * beta, 6), + (5.0 * beta, 9), + (13.0, 11), + (13.0, 14), + ] ); + + buf.clear(); + accessor + .start_point_distances(|id, distance| buf.push((distance, id))) + .await + .unwrap(); + assert_eq!( - computer.evaluate_similarity(accessor.get_element(12).await.unwrap()), - beta * ((12 * 2 + query) as f32) + &*buf, + [(32.0 * beta, start_point)], + "u32::MAX is a multiple of 3" ); - - // Test dummy implementation of `get_neighbors` for code coverage. - let mut neighbors = AdjacencyList::new(); - accessor.get_neighbors(0, &mut neighbors).await.unwrap(); - assert_eq!(neighbors.len(), 0); } } diff --git a/diskann-providers/src/storage/index_storage.rs b/diskann-providers/src/storage/index_storage.rs index 21a1d9333..8e076ea3a 100644 --- a/diskann-providers/src/storage/index_storage.rs +++ b/diskann-providers/src/storage/index_storage.rs @@ -219,10 +219,10 @@ mod tests { use crate::storage::VirtualStorageProvider; use diskann::{ graph::{AdjacencyList, config, glue::InsertStrategy}, - provider::{Accessor, SetElement}, + provider::SetElement, utils::{IntoUsize, ONE}, }; - use diskann_utils::{Reborrow, test_data_root, views::MatrixView}; + use diskann_utils::{test_data_root, views::MatrixView}; use diskann_vector::distance::Metric; use super::*; @@ -231,7 +231,6 @@ mod tests { model::graph::provider::async_::{ SimpleNeighborProviderAsync, common::{FullPrecision, NoDeletes, NoStore, TableBasedDeletes}, - inmem::{self}, }, utils::create_rnd_from_seed_in_tests, }; @@ -346,12 +345,12 @@ mod tests { .unwrap(); assert_eq!(id_iter, reloaded.data_provider.iter()); - check_accessor_equal( - inmem::FullAccessor::new(index.provider()), - inmem::FullAccessor::new(reloaded.provider()), - id_iter.clone(), - ) - .await; + // check_accessor_equal( + // inmem::FullAccessor::new(index.provider()), + // inmem::FullAccessor::new(reloaded.provider()), + // id_iter.clone(), + // ) + // .await; check_graphs_equal( &index.provider().neighbor_provider, @@ -367,12 +366,12 @@ mod tests { .unwrap(); assert_eq!(id_iter, reloaded.data_provider.iter()); - check_accessor_equal( - inmem::FullAccessor::new(index.provider()), - inmem::FullAccessor::new(reloaded.provider()), - id_iter.clone(), - ) - .await; + // check_accessor_equal( + // inmem::FullAccessor::new(index.provider()), + // inmem::FullAccessor::new(reloaded.provider()), + // id_iter.clone(), + // ) + // .await; check_graphs_equal( &index.provider().neighbor_provider, @@ -389,19 +388,19 @@ mod tests { .unwrap(); assert_eq!(id_iter, reloaded.data_provider.iter()); - check_accessor_equal( - inmem::FullAccessor::new(index.provider()), - inmem::FullAccessor::new(reloaded.provider()), - index.data_provider.iter(), - ) - .await; - - check_accessor_equal( - inmem::product::QuantAccessor::new(index.provider()), - inmem::product::QuantAccessor::new(reloaded.provider()), - index.data_provider.iter(), - ) - .await; + // check_accessor_equal( + // inmem::FullAccessor::new(index.provider()), + // inmem::FullAccessor::new(reloaded.provider()), + // index.data_provider.iter(), + // ) + // .await; + + // check_accessor_equal( + // inmem::product::QuantAccessor::new(index.provider()), + // inmem::product::QuantAccessor::new(reloaded.provider()), + // index.data_provider.iter(), + // ) + // .await; check_graphs_equal( &index.provider().neighbor_provider, @@ -417,19 +416,19 @@ mod tests { .unwrap(); assert_eq!(id_iter, reloaded.data_provider.iter()); - check_accessor_equal( - inmem::FullAccessor::new(index.provider()), - inmem::FullAccessor::new(reloaded.provider()), - index.data_provider.iter(), - ) - .await; - - check_accessor_equal( - inmem::product::QuantAccessor::new(index.provider()), - inmem::product::QuantAccessor::new(reloaded.provider()), - index.data_provider.iter(), - ) - .await; + // check_accessor_equal( + // inmem::FullAccessor::new(index.provider()), + // inmem::FullAccessor::new(reloaded.provider()), + // index.data_provider.iter(), + // ) + // .await; + + // check_accessor_equal( + // inmem::product::QuantAccessor::new(index.provider()), + // inmem::product::QuantAccessor::new(reloaded.provider()), + // index.data_provider.iter(), + // ) + // .await; check_graphs_equal( &index.provider().neighbor_provider, @@ -439,22 +438,22 @@ mod tests { } } - async fn check_accessor_equal(mut left: A, mut right: B, itr: Itr) - where - A: for<'a> Accessor = &'a T>, - B: for<'a> Accessor = &'a T>, - T: PartialEq + std::fmt::Debug + ?Sized, - Itr: Iterator, - { - for i in itr { - assert_eq!( - left.get_element(i).await.unwrap().reborrow(), - right.get_element(i).await.unwrap().reborrow(), - "failed for index {}", - i - ); - } - } + // async fn check_accessor_equal(mut left: A, mut right: B, itr: Itr) + // where + // A: for<'a> Accessor = &'a T>, + // B: for<'a> Accessor = &'a T>, + // T: PartialEq + std::fmt::Debug + ?Sized, + // Itr: Iterator, + // { + // for i in itr { + // assert_eq!( + // left.get_element(i).await.unwrap().reborrow(), + // right.get_element(i).await.unwrap().reborrow(), + // "failed for index {}", + // i + // ); + // } + // } fn check_graphs_equal( left: &SimpleNeighborProviderAsync, diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 5e691d515..778db1e39 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -79,27 +79,85 @@ use std::{future::Future, sync::Arc}; use diskann_utils::Reborrow; -use diskann_utils::future::AssertSend; -use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction}; +use diskann_vector::DistanceFunction; use crate::{ ANNError, ANNResult, - error::{ErrorExt, StandardError}, - graph::{AdjacencyList, SearchOutputBuffer, workingset}, + error::StandardError, + graph::{SearchOutputBuffer, workingset}, neighbor::Neighbor, - provider::{ - Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, - DataProvider, HasId, NeighborAccessor, - }, + provider::{AsNeighborMut, BuildDistanceComputer, DataProvider, HasElementRef, HasId}, }; /// A trait to override search constraints such as early termination based on constraints /// by implementer. -pub trait SearchExt: Accessor { +pub trait Explore: HasId + Send + Sync { /// Return a `Vec` containing the starting points. fn starting_points(&self) -> impl std::future::Future>> + Send; + fn start_point_distances( + &mut self, + f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send; + + /// A primitive routine used by graph search. This is purposely implemented as a + /// coarse-grained operation to enable optimization opportunities by the backend. + /// + /// # Description + /// + /// For each `i` in `itr`, fetch the adjacency list `v_i` for `i`. For each `v_i`, then + /// for each id `j` in `v_i`, compute the distance `d` using `computer` to the data + /// associated with `j` and invoke the closure `f` with `d` and `j`, provided + /// `pred.eval_mut(j)` evaluates to `true`. + /// + /// No specification is made on the traversal order of `ids`, the computation order of + /// the leaf elements, nor the order in which `pred` is evaluated. + /// + /// Restriction in the implementation are as follows: + /// + /// * If `pred.eval_mut()` returns `true` for an id `i`, then `on_neighbors` must be + /// invoked for that item. + /// + /// If an item `i` is already passed to `on_neighbors`, the implementation is not + /// obligated to provided it again, though it **may** do so provided `pred.eval_mut()` + /// continues to return `true`. + /// + /// * If `pred.eval_mut()` returns `false` for an item, then `on_neighbors` must not be + /// invoked for that item. + /// + /// * `pred.eval()` may be invoked an arbitrary number of times. Proper predicate + /// implementations will ensure that + /// + /// - `pred.eval() == true` implies `pred.eval_mut() == true` if `pred.eval_mut()` is + /// invoked immediately after `pred.eval()`. + /// + /// - `pred.eval() == false` implies `pred.eval_mut() == false` and vice-versa. + /// + /// * `pred.eval_mut()` must be invoked at most once for all transitive items in the beam. + /// + /// ## Predicate Requirements + /// + /// Well behaved predicates must never return `true` (allow an id to be forwarded to + /// `on_neighbors`) if it previously returned `false`. Implementations of `ExpandBeam` + /// are allowed to assume this holds. + /// + /// Additionally, the callback `on_neighbors` and the predicate have to cooperate. If the + /// callback requires unique items, the predicate must be structured such that `eval_mut` + /// correctly filters out duplicates. + fn expand_beam( + &mut self, + ids: Itr, + pred: P, + on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send; + /// Default is to never terminate early. fn terminate_early(&mut self) -> bool { false @@ -192,99 +250,6 @@ where /// The interfaces `contains` and `insert` agree with each other. impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash::Hash {} -/// A primitive routine used by graph search. This is purposely implemented as a -/// coarse-grained operation to enable optimization opportunities by the backend. -/// -/// # Description -/// -/// For each `i` in `itr`, fetch the adjacency list `v_i` for `i`. For each `v_i`, then for -/// each id `j` in `v_i`, compute the distance `d` using `computer` to the data associated with -/// `j` and invoke the closure `f` with `d` and `j`, provided `pred.eval_mut(j)` evaluates to -/// `true`. -/// -/// No specification is made on the traversal order of `ids`, the computation order of the -/// leaf elements, nor the order in which `pred` is evaluated. -/// -/// Restriction in the implementation are as follows: -/// -/// * If `pred.eval_mut()` returns `true` for an id `i`, then `on_neighbors` must be invoked -/// for that item. -/// -/// If an item `i` is already passed to `on_neighbors`, the implementation is not obligated -/// to provided it again, though it **may** do so provided `pred.eval_mut()` continues to -/// return `true`. -/// -/// * If `pred.eval_mut()` returns `false` for an item, then `on_neighbors` must not be -/// invoked for that item. -/// -/// * `pred.eval()` may be invoked an arbitrary number of times. Proper predicate -/// implementations will ensure that -/// -/// - `pred.eval() == true` implies `pred.eval_mut() == true` if `pred.eval_mut()` is -/// invoked immediately after `pred.eval()`. -/// -/// - `pred.eval() == false` implies `pred.eval_mut() == false` and vice-versa. -/// -/// * `pred.eval_mut()` must be invoked at most once for all transitive items in the beam. -/// -/// ## Predicate Requirements -/// -/// Well behaved predicates must never return `true` (allow an id to be forwarded to -/// `on_neighbors`) if it previously returned `false`. Implementations of `ExpandBeam` are -/// allowed to assume this holds. -/// -/// Additionally, the callback `on_neighbors` and the predicate have to cooperate. If the -/// callback requires unique items, the predicate must be structured such that `eval_mut` -/// correctly filters out duplicates. -/// -/// # Provided Implementation -/// -/// The provided implementation works on each element of `ids` sequentially, pre-filters -/// the resulting candidate list using `pred.eval()` before invoking -/// [`BuildQueryComputer::distances_unordered`]. -/// -/// The callback `on_neighbors` is decorated to the uses `pred.eval_mut()`. -/// -/// This ensures that if `distances_unordered` errors, the predicate is not erroneously -/// updated. -/// -/// ## Error Handling -/// -/// Transient errors yielded by `distances_unordered` are acknowledged and not escalated. -pub trait ExpandBeam: BuildQueryComputer + AsNeighbor + Sized { - fn expand_beam( - &mut self, - ids: Itr, - computer: &Self::QueryComputer, - mut pred: P, - mut on_neighbors: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - P: HybridPredicate + Send + Sync, - F: FnMut(f32, Self::Id) + Send, - { - async move { - let mut neighbors = AdjacencyList::new(); - for id in ids { - self.get_neighbors(id, &mut neighbors).send().await?; - neighbors.retain(|i| pred.eval(i)); - - self.distances_unordered(neighbors.iter().copied(), computer, |distance, id| { - if pred.eval_mut(&id) { - on_neighbors(distance, id); - } - }) - .send() - .await - .allow_transient("allowing transient error in beam expansion")?; - } - - Ok(()) - } - } -} - /// A search strategy for query objects of type `T`. /// /// This trait should be overloaded by data providers wishing to extend @@ -293,34 +258,26 @@ pub trait SearchStrategy: Send + Sync where Provider: DataProvider, { - /// The computer used by the associated accessor. - /// - /// We could grab this type from the `SearchAccessor` associated type, but it's - /// useful enough that we move it up here. - type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as Accessor>::ElementRef<'b>, - f32, - > + Send - + Sync - + 'static; - /// An error that can occur when getting a search_accessor. - type SearchAccessorError: StandardError; + type SearchAccessorError: StandardError + for<'a> DoesNotCaptureSelf>; /// The concrete type of the accessor that is used to access `Self` during the greedy /// graph search. The query will be provided to the accessor exactly once during search /// to construct the query computer. - type SearchAccessor<'a>: ExpandBeam - + SearchExt; + type SearchAccessor<'a>: Explore; /// Construct and return the search accessor. fn search_accessor<'a>( &'a self, provider: &'a Provider, context: &'a Provider::Context, + query: T, ) -> Result, Self::SearchAccessorError>; } +pub trait DoesNotCaptureSelf {} +impl DoesNotCaptureSelf for U {} + /// Opt-in trait for strategies that have a default post-processor. /// /// Strategies implementing this trait can be used with index-level search APIs such as @@ -385,7 +342,7 @@ macro_rules! default_post_processor { /// directly into the output buffer. pub trait SearchPostProcess::Id> where - A: BuildQueryComputer, + A: HasId, { type Error: StandardError; @@ -395,7 +352,6 @@ where &self, accessor: &mut A, query: T, - computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -411,14 +367,13 @@ pub struct CopyIds; impl SearchPostProcess for CopyIds where - A: BuildQueryComputer, + A: HasId, { type Error = std::convert::Infallible; fn post_process( &self, _accessor: &mut A, _query: T, - _computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -436,7 +391,7 @@ where /// using a [`Pipeline`]. pub trait SearchPostProcessStep::Id> where - A: BuildQueryComputer, + A: HasId, { /// A potentially modified version of the error yielded by the next state in the /// processing pipeline. @@ -445,7 +400,7 @@ where NextError: StandardError; /// The accessor that will be passed to the next processing stage. - type NextAccessor: BuildQueryComputer; + type NextAccessor: HasId; /// Perform any modification the `input`, `output`, `accessor`, or `computer` objects /// and invoke the [`SearchPostProcess`] routine `next` on stage. @@ -454,7 +409,6 @@ where next: &Next, accessor: &mut A, query: T, - computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future>> + Send @@ -470,7 +424,7 @@ pub struct FilterStartPoints; impl SearchPostProcessStep for FilterStartPoints where - A: BuildQueryComputer + SearchExt, + A: Explore, T: Copy + Send + Sync, { /// A this level, sub-errors are converted into [`ANNError`] to provide additional @@ -488,7 +442,6 @@ where next: &Next, accessor: &mut A, query: T, - computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> ANNResult @@ -498,18 +451,12 @@ where Next: SearchPostProcess + Sync, { let filter = accessor.is_not_start_point().await?; - next.post_process( - accessor, - query, - computer, - candidates.filter(|n| filter(n.id)), - output, - ) - .await - .map_err(|err| { - let err = err.into(); - err.context("after filtering start points") - }) + next.post_process(accessor, query, candidates.filter(|n| filter(n.id)), output) + .await + .map_err(|err| { + let err = err.into(); + err.context("after filtering start points") + }) } } @@ -539,7 +486,7 @@ impl Pipeline { impl SearchPostProcess for Pipeline where - A: BuildQueryComputer, + A: HasId, Head: SearchPostProcessStep, Tail: SearchPostProcess + Sync, { @@ -549,7 +496,6 @@ where &self, accessor: &mut A, query: T, - computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -558,7 +504,7 @@ where B: SearchOutputBuffer + Send + ?Sized, { self.head - .post_process_step(&self.tail, accessor, query, computer, candidates, output) + .post_process_step(&self.tail, accessor, query, candidates, output) } } @@ -588,8 +534,9 @@ where &'a self, provider: &'a Provider, context: &'a Provider::Context, + vector: T, ) -> Result, Self::SearchAccessorError> { - self.search_accessor(provider, context) + self.search_accessor(provider, context, vector) } } @@ -613,9 +560,8 @@ where /// We could grab this type from the `PruneAccessor` associated type, but it's /// useful enough that we move it up here. type DistanceComputer<'computer>: for<'a, 'b, 'c, 'd> DistanceFunction< - as Accessor>::ElementRef<'b>, - as Accessor>::ElementRef<'d>, - f32, + as HasElementRef>::ElementRef<'b>, + as HasElementRef>::ElementRef<'d>, > + Send + Sync; @@ -627,10 +573,11 @@ where /// /// Implementations are encouraged to have [`Accessor::get_element`] return the /// highest-precision applicable value for a given element type. - type PruneAccessor<'a>: Accessor - + BuildDistanceComputer> + type PruneAccessor<'a>: BuildDistanceComputer> + AsNeighborMut - + workingset::Fill; + + workingset::Fill + + Send + + Sync; /// An error that can occur when getting the prune accessor. type PruneAccessorError: StandardError; @@ -785,8 +732,7 @@ where /// of associated types. /// /// Lifting the accessor all the way to the trait level makes the caching provider possible. - type DeleteSearchAccessor<'a>: ExpandBeam, Id = Provider::InternalId> - + SearchExt; + type DeleteSearchAccessor<'a>: Explore; /// The processor used during the delete-search phase. type SearchPostProcessor: for<'a> SearchPostProcess, Self::DeleteElement<'a>> @@ -824,44 +770,18 @@ where #[cfg(test)] mod tests { - use std::sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, - }; - - use diskann_vector::PreprocessedDistanceFunction; - use futures_util::future; - use super::*; use crate::{ ANNResult, neighbor, - provider::{DelegateNeighbor, ExecutionContext, HasId, NeighborAccessor}, + provider::{DefaultContext, HasId}, }; // A really simple provider that just holds floats and uses the absolute value for its // distances. - struct SimpleProvider { - items: Vec, - } - - #[derive(Default, Clone)] - struct CountGetVector { - count: Arc, - } - impl ExecutionContext for CountGetVector {} - - impl CountGetVector { - fn count(&self) -> usize { - self.count.load(Ordering::Relaxed) - } - - fn clear(&self) { - self.count.store(0, Ordering::Relaxed) - } - } + struct SimpleProvider; impl DataProvider for SimpleProvider { - type Context = CountGetVector; + type Context = DefaultContext; type InternalId = u32; type ExternalId = u32; type Error = ANNError; @@ -870,7 +790,7 @@ mod tests { /// Translate an external id to its corresponding internal id. fn to_internal_id( &self, - _context: &CountGetVector, + _context: &DefaultContext, gid: &Self::ExternalId, ) -> Result { Ok(*gid) @@ -879,7 +799,7 @@ mod tests { /// Translate an internal id to its corresponding external id. fn to_external_id( &self, - _context: &CountGetVector, + _context: &DefaultContext, id: Self::InternalId, ) -> Result { Ok(id) @@ -887,94 +807,54 @@ mod tests { } #[derive(Clone, Copy)] - struct Retriever<'a> { - provider: &'a SimpleProvider, - count: &'a CountGetVector, - } + struct Accessor; - impl SearchExt for Retriever<'_> { + impl Explore for Accessor { async fn starting_points(&self) -> ANNResult> { - Ok(vec![0]) + unimplemented!(); } - } - - impl<'a> Retriever<'a> { - fn new(provider: &'a SimpleProvider, count: &'a CountGetVector) -> Self { - Self { provider, count } - } - } - - impl HasId for Retriever<'_> { - type Id = u32; - } - impl Accessor for Retriever<'_> { - type Element<'a> - = f32 + async fn start_point_distances(&mut self, _f: F) -> ANNResult<()> where - Self: 'a; - type ElementRef<'a> = f32; - - type GetError = ANNError; - fn get_element( - &mut self, - id: Self::Id, - ) -> impl std::future::Future, Self::GetError>> + Send + F: FnMut(Self::Id, f32) + Send, { - let result = match self.provider.items.get(id as usize) { - Some(v) => { - self.count.count.fetch_add(1, Ordering::Relaxed); - Ok(*v) - } - None => panic!("invalid id: {}", id), - }; - async move { result } + unimplemented!(); } - } - - impl NeighborAccessor for Retriever<'_> { - fn get_neighbors( - self, - _id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - neighbors.clear(); - future::ok(self) - } - } - struct QueryComputer; - impl PreprocessedDistanceFunction for QueryComputer { - fn evaluate_similarity(&self, _changing: f32) -> f32 { - panic!("this method should not be called") + async fn expand_beam( + &mut self, + _ids: Itr, + _pred: P, + _on_neighbors: F, + ) -> ANNResult<()> + where + Itr: Iterator + Send, + P: HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + unimplemented!(); } } - impl BuildQueryComputer for Retriever<'_> { - type QueryComputerError = ANNError; - type QueryComputer = QueryComputer; - fn build_query_computer(&self, _from: f32) -> Result { - Ok(QueryComputer) - } + impl HasId for Accessor { + type Id = u32; } - impl ExpandBeam for Retriever<'_> {} - // This strategy explicitly does not define `post_process` so we can test the provided // implementation. struct Strategy; impl SearchStrategy for Strategy { - type QueryComputer = QueryComputer; type SearchAccessorError = ANNError; - type SearchAccessor<'a> = Retriever<'a>; + type SearchAccessor<'a> = Accessor; fn search_accessor<'a>( &'a self, - provider: &'a SimpleProvider, - context: &'a CountGetVector, + _provider: &'a SimpleProvider, + _context: &'a DefaultContext, + _query: f32, ) -> Result, Self::SearchAccessorError> { - Ok(Retriever::new(provider, context)) + Ok(Accessor) } } @@ -984,38 +864,15 @@ mod tests { #[tokio::test(flavor = "current_thread")] async fn test_default_post_process() { - let ctx = CountGetVector::default(); + let ctx = DefaultContext; let strategy = Strategy; - - let num_points: usize = 100; - let provider = SimpleProvider { - items: (0..num_points).map(|i| i as f32).collect(), - }; + let provider = SimpleProvider; assert_eq!(provider.to_internal_id(&ctx, &10).unwrap(), 10); assert_eq!(provider.to_external_id(&ctx, 10).unwrap(), 10); - let mut accessor = strategy.search_accessor(&provider, &ctx).unwrap(); - assert_eq!(accessor.starting_points().await.unwrap().as_slice(), &[0]); - for i in 0..num_points { - assert_eq!(accessor.get_element(i as u32).await.unwrap(), i as f32); - } - - // Check dummy get_neighbors implmeentation for code coverage - let mut neighbors = AdjacencyList::new(); - accessor - .delegate_neighbor() - .get_neighbors(0, &mut neighbors) - .await - .unwrap(); - assert_eq!(neighbors.len(), 0); - - // Check that the correct number of reads were emitted. - assert_eq!(ctx.count(), num_points); - ctx.clear(); - let query = 11.5; - let computer = accessor.build_query_computer(query).unwrap(); + let mut accessor = strategy.search_accessor(&provider, &ctx, query).unwrap(); for input_len in 0..10 { let input: Vec<_> = (0..input_len) @@ -1029,7 +886,6 @@ mod tests { .post_process( &mut accessor, query, - &computer, input.iter().copied(), &mut neighbor::BackInserter::new(output.as_mut_slice()), ) @@ -1051,8 +907,5 @@ mod tests { } } } - - // Ensure that no reads were emitted. - assert_eq!(ctx.count(), 0); } } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 38e30dbc2..09496e74b 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -16,7 +16,7 @@ use diskann_utils::{ Reborrow, future::{AssertSend, SendFuture, boxit}, }; -use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction}; +use diskann_vector::DistanceFunction; use futures_util::FutureExt; use hashbrown::HashSet; use thiserror::Error; @@ -25,11 +25,12 @@ use tokio::task::JoinSet; use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, Search, glue::{ - self, Batch, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + self, Batch, Explore, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, + PruneStrategy, SearchPostProcess, SearchStrategy, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ + PagedSearch, record::{NoopSearchRecord, SearchRecord, VisitedSearchRecord}, scratch::{self, PriorityQueueConfiguration, SearchScratch, SearchScratchParams}, }, @@ -41,11 +42,10 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::{ErrorExt, IntoANNResult}, internal, - neighbor::{self, Neighbor, NeighborPriorityQueue, NeighborQueue}, + neighbor::{self, Neighbor, NeighborQueue}, provider::{ - Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, - DataProvider, Delete, ElementStatus, ExecutionContext, Guard, NeighborAccessor, - NeighborAccessorMut, SetElement, + AsNeighbor, AsNeighborMut, BuildDistanceComputer, DataProvider, Delete, ElementStatus, + ExecutionContext, Guard, NeighborAccessor, NeighborAccessorMut, SetElement, }, tracked_debug, tracked_error, tracked_trace, utils::{ @@ -136,34 +136,6 @@ pub struct PartitionedNeighbors { pub deleted: Vec, } -/// Placeholder for extra state. -/// -/// The contents of the search state are designed for the synchronous index. -/// However, use cases in the asynchronous index require some extra state. -/// -/// This placeholder is used in the synchronous code-paths. -pub struct NoExtraState; - -/// Represents the state of the pagged search. -/// It can be used to do paged search by doing multiple `nextSearchResults()` queries. -/// -/// Generic extra state can be included to facilitate extra use-cases. -/// However, this extra state **must** be '`static' as we do not know how long the search -/// state will live for. -#[derive(Debug)] -pub struct SearchState { - /// Scratch space for query processing. - pub scratch: SearchScratch, - /// The computed search results ready to be returned in `nextSearchResults()` query - pub computed_result: Vec>, - /// The index of the next result to be returned. - pub next_result_index: usize, - /// The search computes results in the multiple of `search_param_l`. - pub search_param_l: usize, - /// Any extra data needed by down-stream implementations. - pub extra: ExtraState, -} - /// Edge pending submission for multi-insert. #[derive(Debug)] struct PendingEdge { @@ -203,16 +175,6 @@ where /// and `Err` paths. type BatchResult = Result; -/// State used during by paged search to perform multiple, consecutive searches over the index. -/// -/// Type parameters: -/// -/// * `DP`: The type of the [`DataProvider`]. -/// * `S`: The type of the [`SearchStrategy`]. -/// * `C`: The type of `S`'s [`BuildQueryComputer`] computer. This exists as a separate -/// type parameter because the type of the query computer depends on the type of the query. -pub type PagedSearchState = SearchState<::InternalId, (S, C)>; - impl DiskANNIndex where DP: DataProvider, @@ -313,11 +275,9 @@ where // NOTE: Use the API `insert_search_accessor` to allow `Accessor` customization. let mut accessor = strategy - .insert_search_accessor(&self.data_provider, context) + .insert_search_accessor(&self.data_provider, context, vector) .into_ann_result()?; - let computer = accessor.build_query_computer(vector).into_ann_result()?; - // NOTE: We don't filter the start points out of `visited_nodes`, as those are // needed to generate out edges from the start points. let start_ids = accessor.starting_points().await?; @@ -345,9 +305,7 @@ where self.search_internal( None, // beam_width - &start_ids, &mut accessor, - &computer, &mut scratch, &mut search_record, ) @@ -449,14 +407,13 @@ where // Copy vectors to the vector provider, quantize them and set quant vec provider if necessary let internal_id = ids[position]; + let vector = batch.get(position); + // NOTE: Use the `insert_search_accessor` API to allow insert-specific customization. let mut accessor = strategy - .insert_search_accessor(&self.data_provider, context) + .insert_search_accessor(&self.data_provider, context, vector) .into_ann_result()?; - let computer = accessor - .build_query_computer(batch.get(position)) - .into_ann_result()?; let start_ids = accessor.starting_points().await?; let mut scratch = self.search_scratch(self.l_build(), start_ids.len()); @@ -474,9 +431,7 @@ where self.search_internal( None, // beam_width - &start_ids, &mut accessor, - &computer, &mut scratch, &mut search_record, ) @@ -1277,11 +1232,7 @@ where let search_strategy = strategy.search_strategy(); let mut search_accessor = search_strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - - let computer = search_accessor - .build_query_computer(v.reborrow()) + .search_accessor(&self.data_provider, context, v.reborrow()) .into_ann_result()?; let start_ids = search_accessor.starting_points().await?; @@ -1290,9 +1241,7 @@ where self.search_internal( None, // beam_width - &start_ids, &mut search_accessor, - &computer, &mut scratch, &mut NoopSearchRecord::new(), ) @@ -1307,7 +1256,6 @@ where .post_process( &mut search_accessor, v.reborrow(), - &computer, scratch.best.iter(), &mut neighbor::BackInserter::new(output.as_mut_slice()), ) @@ -1321,8 +1269,13 @@ where .collect(); // Collect IDs whose adjacency lists need to be updated. + let prune_strategy = strategy.prune_strategy(); + let mut prune_accessor = prune_strategy + .prune_accessor(self.provider(), context) + .into_ann_result()?; + let ids_to_modify = self - .return_refs_to_deleted_vertex(&mut search_accessor, id, &undeleted_ids) + .return_refs_to_deleted_vertex(&mut prune_accessor, id, &undeleted_ids) .await?; undeleted_ids.truncate(k_value); @@ -1701,9 +1654,9 @@ where .await .escalate("`inplace_delete` requires a successful delete")?; - let search_strategy = strategy.search_strategy(); - let accessor = &mut search_strategy - .search_accessor(&self.data_provider, context) + let prune_strategy = strategy.prune_strategy(); + let mut accessor = prune_strategy + .prune_accessor(&self.data_provider, context) .into_ann_result()?; let InplaceDeleteWorkList { @@ -1718,21 +1671,16 @@ where .await? } InplaceDeleteMethod::TwoHopAndOneHop => { - self.get_candidates_using_twohop_and_onehop(context, accessor, vector_id) + self.get_candidates_using_twohop_and_onehop(context, &mut accessor, vector_id) .await? } InplaceDeleteMethod::OneHop => { - self.get_candidates_using_onehop(context, accessor, vector_id) + self.get_candidates_using_onehop(context, &mut accessor, vector_id) .await? } }; - let prune_strategy = strategy.prune_strategy(); let mut working_set = prune_strategy.create_working_set(self.max_occlusion_size()); - let mut accessor = prune_strategy - .prune_accessor(&self.data_provider, context) - .into_ann_result()?; - let mut edges_to_add = HashMap::>::new(); // fetch the filtered adjacency list of `p`. @@ -2013,18 +1961,15 @@ where } } - // A is the accessor type, T is the query type used for BuildQueryComputer - pub(crate) fn search_internal( + pub(crate) fn search_internal( &self, beam_width: Option, - start_ids: &[DP::InternalId], accessor: &mut A, - computer: &A::QueryComputer, scratch: &mut SearchScratch, search_record: &mut SR, ) -> impl SendFuture> where - A: ExpandBeam + SearchExt, + A: Explore, SR: SearchRecord + ?Sized, Q: NeighborQueue, { @@ -2034,16 +1979,13 @@ where // paged search can call search_internal multiple times, we only need to initialize // state if not already initialized. if scratch.visited.is_empty() { - for id in start_ids { - scratch.visited.insert(*id); - let element = accessor - .get_element(*id) - .await - .escalate("start point retrieval must succeed")?; - let dist = computer.evaluate_similarity(element.reborrow()); - scratch.best.insert(Neighbor::new(*id, dist)); - scratch.cmps += 1; - } + accessor + .start_point_distances(|id, dist| { + scratch.visited.insert(id); + scratch.best.insert(Neighbor::new(id, dist)); + scratch.cmps += 1; + }) + .await?; } let mut neighbors = Vec::with_capacity(self.max_degree_with_slack()); @@ -2063,7 +2005,6 @@ where accessor .expand_beam( scratch.beam_nodes.iter().copied(), - computer, glue::NotInMut::new(&mut scratch.visited), |distance, id| neighbors.push(Neighbor::new(id, distance)), ) @@ -2089,35 +2030,6 @@ where } } - /// Filter out start nodes from the best candidates in the scratch. - fn filter_search_candidates( - &self, - start_points: &[DP::InternalId], - l_value: usize, - best: &mut NeighborPriorityQueue, - ) -> impl SendFuture>, usize)>> { - async move { - let mut total = 0usize; - let mut candidates = Vec::with_capacity(l_value); - for n in best.iter() { - total += 1; - if !start_points.contains(&n.id) { - candidates.push(n); - if candidates.len() >= l_value { - break; - } - } - } - - debug_assert!( - l_value.min(best.size().saturating_sub(start_points.len())) <= candidates.len(), - "Not enough candidates after filtering starting points", - ); - - Ok((candidates, total)) - } - } - /// Execute a search using the unified search interface. /// /// This method provides a single entry point for all search types. The `search_params` argument @@ -2186,43 +2098,49 @@ where // Paged Search // ////////////////// - pub fn start_paged_search( - &self, - strategy: S, - context: &DP::Context, + /// Begin a paged search over the index. + /// + /// Returns a [`PagedSearch`] handle whose [`next_page`](PagedSearch::next_page) method + /// yields successive pages of nearest-neighbor results. + pub fn paged_search<'a, S, T>( + &'a self, + strategy: &'a S, + context: &'a DP::Context, query: T, l_value: usize, - ) -> impl SendFuture>> + ) -> impl SendFuture>>> where - S: SearchStrategy + 'static, - T: Copy + Send, + S: SearchStrategy, + T: Copy + Send + 'a, { async move { - self.start_paged_search_with_init_ids(strategy, context, query, l_value, None) + self.paged_search_with_init_ids(strategy, context, query, l_value, None) .await } } - pub fn start_paged_search_with_init_ids( - &self, - strategy: S, - context: &DP::Context, + /// Begin a paged search with explicit initial seed IDs. + /// + /// This is the same as [`paged_search`](Self::paged_search) but allows the caller to + /// provide custom starting points for the graph traversal. + pub fn paged_search_with_init_ids<'a, S, T>( + &'a self, + strategy: &'a S, + context: &'a DP::Context, query: T, l_value: usize, - init_ids: Option<&[DP::InternalId]>, - ) -> impl SendFuture>> + init_ids: Option<&'a [DP::InternalId]>, + ) -> impl SendFuture>>> where - S: SearchStrategy + 'static, - T: Copy + Send, + S: SearchStrategy, + T: Copy + Send + 'a, { async move { - let (computer, scratch) = { + let (accessor, scratch) = { let mut accessor = strategy - .search_accessor(&self.data_provider, context) + .search_accessor(&self.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - let start_ids = accessor.starting_points().await?; let num_start_points = start_ids.len(); @@ -2242,7 +2160,6 @@ where accessor .expand_beam( init_ids.iter().copied(), - &computer, glue::NotInMut::new(&mut scratch.visited), |distance, id| neighbors.push(Neighbor::new(id, distance)), ) @@ -2253,124 +2170,20 @@ where .iter() .for_each(|neighbor| scratch.best.insert(*neighbor)); - (computer, scratch) + (accessor, scratch) }; - ANNResult::Ok(SearchState { + ANNResult::Ok(PagedSearch { + index: self, scratch, computed_result: vec![Neighbor::default(); l_value], next_result_index: l_value, search_param_l: l_value, - extra: (strategy, computer), + accessor, }) } } - pub fn next_search_results( - &self, - context: &DP::Context, - search_state: &mut SearchState, - k: usize, - result_output: &mut [Neighbor], - ) -> impl SendFuture> - where - S: SearchStrategy, - { - async move { - if k > search_state.search_param_l { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be less than or equal to search_param_l".to_string(), - )); - } - if k == 0 { - return ANNResult::Err(ANNError::log_paged_search_error( - "k should be greater than 0".to_string(), - )); - } - if result_output.len() < k { - return ANNResult::Err(ANNError::log_paged_search_error( - "The size of result_output should be greater than or equal to k".to_string(), - )); - } - - let copy_to_output = - |search_state: &mut SearchState, - count: usize, - result_output: &mut [Neighbor], - result_output_offset: usize| { - result_output[result_output_offset..result_output_offset + count] - .copy_from_slice( - &search_state.computed_result[search_state.next_result_index - ..search_state.next_result_index + count], - ); - search_state.next_result_index += count; - }; - - let used_computed_result_count: usize = cmp::min( - k, - search_state.computed_result.len() - search_state.next_result_index, - ); - if used_computed_result_count > 0 { - copy_to_output( - search_state, - used_computed_result_count, - result_output, - 0, // result_output_offset - ); - - if used_computed_result_count == k { - return ANNResult::Ok(k); - } - } - - let start_points = { - let mut accessor = search_state - .extra - .0 - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - - let start_ids = accessor.starting_points().await?; - self.search_internal( - None, // beam_width - &start_ids, - &mut accessor, - &search_state.extra.1, - &mut search_state.scratch, - &mut NoopSearchRecord::new(), - ) - .await?; - - start_ids - }; - - let (mut candidates, total_considered) = self - .filter_search_candidates(&start_points, k, &mut search_state.scratch.best) - .await?; - search_state.scratch.best.drain_best(total_considered); - - let computed_result_count = candidates.len(); - search_state.computed_result.clear(); - search_state.computed_result.append(&mut candidates); - - search_state.next_result_index = 0; - if computed_result_count != search_state.search_param_l { - search_state.computed_result.truncate(computed_result_count); - } - - let leftover_results = cmp::min(k - used_computed_result_count, computed_result_count); - - copy_to_output( - search_state, - leftover_results, // count of results to copy - result_output, - used_computed_result_count, // result_output_offset - ); - - ANNResult::Ok(used_computed_result_count + leftover_results) - } - } - /// Count the number of nodes in the graph reachable from the given `start_points`. /// /// This function has a large memory footprint for large graphs and should not be called @@ -2577,7 +2390,7 @@ where options: prune::Options, ) -> impl SendFuture> where - A: Accessor + BuildDistanceComputer + Fill, + A: BuildDistanceComputer + Fill + Send, Set: Send + Sync, { async move { @@ -2635,7 +2448,7 @@ where options: prune::Options, ) -> impl SendFuture>> where - A: Accessor + BuildDistanceComputer + Fill, + A: BuildDistanceComputer + Fill + Send, Set: Send + Sync, { async move { @@ -2726,7 +2539,7 @@ where options: prune::Options, ) -> impl SendFuture>> where - A: Accessor + BuildDistanceComputer + Fill, + A: BuildDistanceComputer + Fill + Send, Set: Send + Sync, Itr: ExactSizeIterator + Clone + Send + Sync, { diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index f09a63cb9..89544278a 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -14,12 +14,12 @@ use crate::{ error::IntoANNResult, graph::{ DiverseSearchParams, - glue::{SearchExt, SearchPostProcess, SearchStrategy}, + glue::{Explore, SearchPostProcess, SearchStrategy}, index::{DiskANNIndex, SearchStats}, search_output_buffer::SearchOutputBuffer, }, neighbor::{AttributeValueProvider, DiverseNeighborQueue, NeighborQueue}, - provider::{BuildQueryComputer, DataProvider}, + provider::DataProvider, }; /// Parameters for diversity-aware search. diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index c05594ce0..9434addde 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -15,12 +15,12 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{SearchExt, SearchPostProcess, SearchStrategy}, + glue::{Explore, SearchPostProcess, SearchStrategy}, index::{DiskANNIndex, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::SearchOutputBuffer, }, - provider::{BuildQueryComputer, DataProvider}, + provider::DataProvider, }; /// Error type for [`Knn`] parameter validation. @@ -191,27 +191,23 @@ where { async move { let mut accessor = strategy - .search_accessor(&index.data_provider, context) + .search_accessor(&index.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; - let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len()); let stats = index .search_internal( Some(self.beam_width.get()), - &start_ids, &mut accessor, - &computer, &mut scratch, &mut NoopSearchRecord::new(), ) .await?; let result_count = processor - .post_process(&mut accessor, query, &computer, scratch.best.iter(), output) + .post_process(&mut accessor, query, scratch.best.iter(), output) .await .into_ann_result()?; @@ -267,10 +263,9 @@ where { async move { let mut accessor = strategy - .search_accessor(&index.data_provider, context) + .search_accessor(&index.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; let mut scratch = index.search_scratch(self.inner.l_value.get(), start_ids.len()); @@ -278,16 +273,14 @@ where let stats = index .search_internal( Some(self.inner.beam_width.get()), - &start_ids, &mut accessor, - &computer, &mut scratch, self.recorder, ) .await?; let result_count = processor - .post_process(&mut accessor, query, &computer, scratch.best.iter(), output) + .post_process(&mut accessor, query, scratch.best.iter(), output) .await .into_ann_result()?; diff --git a/diskann/src/graph/search/mod.rs b/diskann/src/graph/search/mod.rs index fac279421..350930671 100644 --- a/diskann/src/graph/search/mod.rs +++ b/diskann/src/graph/search/mod.rs @@ -46,6 +46,9 @@ mod knn_search; mod multihop_search; mod range_search; +mod paged; +pub use paged::PagedSearch; + pub mod record; pub(crate) mod scratch; diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index aba0f44c5..fe4cf1a16 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -5,19 +5,17 @@ //! Label-filtered search using multi-hop expansion. -use diskann_utils::Reborrow; use diskann_utils::future::SendFuture; -use diskann_vector::PreprocessedDistanceFunction; use hashbrown::HashSet; use super::{Knn, Search, record::SearchRecord, scratch::SearchScratch}; use crate::{ ANNResult, - error::{ErrorExt, IntoANNResult}, + error::IntoANNResult, graph::{ glue::{ - self, ExpandBeam, HybridPredicate, Predicate, PredicateMut, SearchExt, - SearchPostProcess, SearchStrategy, + self, Explore, HybridPredicate, Predicate, PredicateMut, SearchPostProcess, + SearchStrategy, }, index::{ DiskANNIndex, InternalSearchStats, QueryLabelProvider, QueryVisitDecision, SearchStats, @@ -26,7 +24,7 @@ use crate::{ search_output_buffer::SearchOutputBuffer, }, neighbor::Neighbor, - provider::{BuildQueryComputer, DataProvider}, + provider::DataProvider, utils::VectorId, }; @@ -77,9 +75,8 @@ where { async move { let mut accessor = strategy - .search_accessor(&index.data_provider, context) + .search_accessor(&index.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; @@ -89,7 +86,6 @@ where index.max_degree_with_slack(), &self.inner, &mut accessor, - &computer, &mut scratch, &mut NoopSearchRecord::new(), self.label_evaluator, @@ -100,7 +96,6 @@ where .post_process( &mut accessor, query, - &computer, scratch.best.iter().take(self.inner.l_value().get()), output, ) @@ -171,18 +166,17 @@ impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId { /// /// Performs label-filtered search by expanding through non-matching nodes /// to find matching neighbors within two hops. -pub(crate) async fn multihop_search_internal( +pub(crate) async fn multihop_search_internal( max_degree_with_slack: usize, search_params: &Knn, accessor: &mut A, - computer: &A::QueryComputer, scratch: &mut SearchScratch, search_record: &mut SR, query_label_evaluator: &dyn QueryLabelProvider, ) -> ANNResult where I: VectorId, - A: ExpandBeam + SearchExt, + A: Explore, SR: SearchRecord + ?Sized, { let beam_width = search_params.beam_width().get(); @@ -194,21 +188,13 @@ where range_search_second_round: false, }; - // Initialize search state if not already initialized. - // This allows paged search to call multihop_search_internal multiple times - if scratch.visited.is_empty() { - let start_ids = accessor.starting_points().await?; - - for id in start_ids { + accessor + .start_point_distances(|id, dist| { scratch.visited.insert(id); - let element = accessor - .get_element(id) - .await - .escalate("start point retrieval must succeed")?; - let dist = computer.evaluate_similarity(element.reborrow()); scratch.best.insert(Neighbor::new(id, dist)); - } - } + scratch.cmps += 1; + }) + .await?; // Pre-allocate with good capacity to avoid repeated allocations let mut one_hop_neighbors = Vec::with_capacity(max_degree_with_slack); @@ -234,7 +220,6 @@ where accessor .expand_beam( scratch.beam_nodes.iter().copied(), - computer, glue::NotInMut::new(&mut scratch.visited), |distance, id| one_hop_neighbors.push(Neighbor::new(id, distance)), ) @@ -280,7 +265,6 @@ where accessor .expand_beam( two_hop_expansion_candidate_ids.iter().copied(), - computer, NotInMutWithLabelCheck::new(&mut scratch.visited, query_label_evaluator), |distance, id| { two_hop_neighbors.push(Neighbor::new(id, distance)); diff --git a/diskann/src/graph/search/paged.rs b/diskann/src/graph/search/paged.rs new file mode 100644 index 000000000..f3670f494 --- /dev/null +++ b/diskann/src/graph/search/paged.rs @@ -0,0 +1,160 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_utils::future::SendFuture; + +use crate::{ + ANNError, ANNResult, + graph::{ + DiskANNIndex, + glue::Explore, + search::{record::NoopSearchRecord, scratch::SearchScratch}, + }, + neighbor::{Neighbor, NeighborPriorityQueue}, + provider::DataProvider, + utils::VectorId, +}; + +/// Intermediate state for paged search. +/// +/// Each call to [`next_page`](Self::next_page) resumes the graph search and returns the +/// next page of nearest-neighbor results. Returns an empty `Vec` when the search is exhausted. +/// +/// See also: [`DiskANNIndex::paged_search`], [`DiskANNIndex::paged_search_with_init_ids`]. +#[derive(Debug)] +pub struct PagedSearch<'a, DP, T> +where + DP: DataProvider, + T: Explore, +{ + pub(in crate::graph) index: &'a DiskANNIndex, + pub(in crate::graph) scratch: SearchScratch, + pub(in crate::graph) computed_result: Vec>, + pub(in crate::graph) next_result_index: usize, + pub(in crate::graph) search_param_l: usize, + pub(in crate::graph) accessor: T, +} + +impl<'a, DP, T> PagedSearch<'a, DP, T> +where + DP: DataProvider, + T: Explore, +{ + /// Returns the next page of at most `k` nearest-neighbor results. + /// + /// Results across pages are non-overlapping but not guaranteed to be monotonic with + /// respect to distance. + /// + /// Within a page, results ordered by non-decreasing distance. + /// + /// When the search is exhausted, returns an empty `Vec`. + pub fn next_page( + &mut self, + k: usize, + ) -> impl SendFuture>>> { + async move { + if k > self.search_param_l { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be less than or equal to search_param_l".to_string(), + )); + } + if k == 0 { + return ANNResult::Err(ANNError::log_paged_search_error( + "k should be greater than 0".to_string(), + )); + } + + let mut result = Vec::with_capacity(k); + + // Drain any already-computed results first. + let available = self + .computed_result + .len() + .saturating_sub(self.next_result_index); + let from_cache = std::cmp::min(k, available); + if from_cache > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + from_cache], + ); + self.next_result_index += from_cache; + + if result.len() == k { + return ANNResult::Ok(result); + } + } + + // Resume graph search to fill the next batch. + let start_points = { + // let mut accessor = self + // .strategy + // .search_accessor(&self.index.data_provider, self.context, self.query) + // .into_ann_result()?; + + let start_ids = self.accessor.starting_points().await?; + self.index + .search_internal( + None, // beam_width + &mut self.accessor, + &mut self.scratch, + &mut NoopSearchRecord::new(), + ) + .await?; + + start_ids + }; + + let (mut candidates, total_considered) = + filter_search_candidates(&start_points, k, &mut self.scratch.best); + self.scratch.best.drain_best(total_considered); + + let computed_result_count = candidates.len(); + self.computed_result.clear(); + self.computed_result.append(&mut candidates); + self.next_result_index = 0; + + let remaining_need = k - result.len(); + let leftover = std::cmp::min(remaining_need, computed_result_count); + if leftover > 0 { + result.extend_from_slice( + &self.computed_result + [self.next_result_index..self.next_result_index + leftover], + ); + self.next_result_index += leftover; + } + + ANNResult::Ok(result) + } + } +} + +// FIXME: Wire proper post-processing support into paged search. +fn filter_search_candidates( + start_points: &[I], + l_value: usize, + best: &mut NeighborPriorityQueue, +) -> (Vec>, usize) +where + I: VectorId, +{ + let mut total = 0usize; + let mut candidates = Vec::with_capacity(l_value); + for n in best.iter() { + total += 1; + if !start_points.contains(&n.id) { + candidates.push(n); + if candidates.len() >= l_value { + break; + } + } + } + + debug_assert!( + l_value.min(best.size().saturating_sub(start_points.len())) <= candidates.len(), + "Not enough candidates after filtering starting points", + ); + + (candidates, total) +} diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index b92fa1384..6cbea685a 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -13,13 +13,13 @@ use crate::{ ANNError, ANNErrorKind, ANNResult, error::IntoANNResult, graph::{ - glue::{self, ExpandBeam, SearchExt, SearchStrategy}, + glue::{self, Explore, SearchStrategy}, index::{DiskANNIndex, InternalSearchStats, SearchStats}, search::record::NoopSearchRecord, search_output_buffer::{self, SearchOutputBuffer}, }, neighbor::Neighbor, - provider::{BuildQueryComputer, DataProvider}, + provider::DataProvider, utils::IntoUsize, }; @@ -181,9 +181,8 @@ where { async move { let mut accessor = strategy - .search_accessor(&index.data_provider, context) + .search_accessor(&index.data_provider, context, query) .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; let start_ids = accessor.starting_points().await?; let mut scratch = index.search_scratch(self.starting_l(), start_ids.len()); @@ -191,9 +190,7 @@ where let initial_stats = index .search_internal( self.beam_width(), - &start_ids, &mut accessor, - &computer, &mut scratch, &mut NoopSearchRecord::new(), ) @@ -222,7 +219,6 @@ where index.max_degree_with_slack(), &self, &mut accessor, - &computer, &mut scratch, ) .await?; @@ -252,7 +248,6 @@ where .post_process( &mut accessor, query, - &computer, scratch.in_range.iter().copied(), &mut filtered, ) @@ -323,16 +318,15 @@ where /// /// Expands the search frontier to find all points within the specified radius. /// Called after the initial graph search has identified starting candidates. -pub(crate) async fn range_search_internal( +pub(crate) async fn range_search_internal( max_degree_with_slack: usize, search_params: &Range, accessor: &mut A, - computer: &A::QueryComputer, scratch: &mut SearchScratch, ) -> ANNResult where I: crate::utils::VectorId, - A: ExpandBeam + SearchExt, + A: Explore, { let beam_width = search_params.beam_width().unwrap_or(1); @@ -360,7 +354,6 @@ where accessor .expand_beam( scratch.beam_nodes.iter().copied(), - computer, glue::NotInMut::new(&mut scratch.visited), |distance, id| neighbors.push(Neighbor::new(id, distance)), ) diff --git a/diskann/src/graph/search/scratch.rs b/diskann/src/graph/search/scratch.rs index 98a5b3127..2670f118c 100644 --- a/diskann/src/graph/search/scratch.rs +++ b/diskann/src/graph/search/scratch.rs @@ -158,6 +158,7 @@ where } /// Return the currently configured `search_l`: the number of best candidates to track. + #[cfg(test)] pub fn search_l(&self) -> usize { self.best.search_l() } diff --git a/diskann/src/graph/test/cases/paged_search.rs b/diskann/src/graph/test/cases/paged_search.rs index 0b8fa4706..511779201 100644 --- a/diskann/src/graph/test/cases/paged_search.rs +++ b/diskann/src/graph/test/cases/paged_search.rs @@ -6,8 +6,8 @@ //! Tests for paged (iterative) search. //! //! Paged search returns results in pages of k neighbors via a stateful -//! `SearchState`. Tests cover basic pagination, single-page retrieval, -//! and small page sizes that stress the iteration machinery. +//! [`PagedSearch`](crate::graph::search::PagedSearch) handle. Tests cover basic pagination, +//! single-page retrieval, and small page sizes that stress the iteration machinery. use std::sync::Arc; @@ -155,34 +155,19 @@ fn basic_paged_search() { let page_size = 4; let ctx = test_provider::Context::new(); - let mut state = rt - .block_on(index.start_paged_search( - test_provider::Strategy::new(), - &ctx, - query.as_slice(), - search_l, - )) + let strategy = test_provider::Strategy::new(); + let mut search = rt + .block_on(index.paged_search(&strategy, &ctx, query.as_slice(), search_l)) .unwrap(); let mut pages: Vec>> = Vec::new(); - let mut buffer = vec![Neighbor::::default(); page_size]; loop { - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - if count == 0 { + let page = rt.block_on(search.next_page(page_size)).unwrap(); + if page.is_empty() { break; } - pages.push(buffer[..count].to_vec()); + pages.push(page); } let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); @@ -208,31 +193,14 @@ fn single_page() { let search_l = 200; let page_size = 200; // larger than total points (125) let ctx = test_provider::Context::new(); + let strategy = test_provider::Strategy::new(); - let mut state = rt - .block_on(index.start_paged_search( - test_provider::Strategy::new(), - &ctx, - query.as_slice(), - search_l, - )) + let mut search = rt + .block_on(index.paged_search(&strategy, &ctx, query.as_slice(), search_l)) .unwrap(); - let mut buffer = vec![Neighbor::::default(); page_size]; - - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - let results: Vec> = buffer[..count].to_vec(); - let pages = vec![results.clone()]; + let results = rt.block_on(search.next_page(page_size)).unwrap(); + let pages = vec![results]; let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); @@ -242,18 +210,13 @@ fn single_page() { assert_no_duplicates_across_pages(&pages); assert_non_decreasing_distances(&pages); - // Verify second call returns 0 (nothing left) - let count2 = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - assert_eq!(count2, 0, "second page should be empty"); + // Verify second call returns empty (nothing left) + let page2 = rt.block_on(search.next_page(page_size)).unwrap(); + assert!(page2.is_empty(), "second page should be empty"); + + // Verify repeated calls after exhaustion remain empty (idempotent, no panic). + let page3 = rt.block_on(search.next_page(page_size)).unwrap(); + assert!(page3.is_empty(), "third page should still be empty"); } #[test] @@ -269,35 +232,20 @@ fn small_page_size() { let search_l = 32; let page_size = 1; // one result per page, maximum iterations let ctx = test_provider::Context::new(); + let strategy = test_provider::Strategy::new(); - let mut state = rt - .block_on(index.start_paged_search( - test_provider::Strategy::new(), - &ctx, - query.as_slice(), - search_l, - )) + let mut search = rt + .block_on(index.paged_search(&strategy, &ctx, query.as_slice(), search_l)) .unwrap(); let mut pages: Vec>> = Vec::new(); - let mut buffer = vec![Neighbor::::default(); page_size]; loop { - let count = rt - .block_on( - index.next_search_results::( - &ctx, - &mut state, - page_size, - &mut buffer, - ), - ) - .unwrap(); - - if count == 0 { + let page = rt.block_on(search.next_page(page_size)).unwrap(); + if page.is_empty() { break; } - pages.push(buffer[..count].to_vec()); + pages.push(page); } let baseline = build_baseline(grid_size, &dims, &query, search_l, page_size, &pages); diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index afa80b38c..fa446958b 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -13,13 +13,13 @@ use std::{ }; use dashmap::{DashMap, mapref::entry::Entry}; -use diskann_utils::views::Matrix; -use diskann_vector::distance::Metric; +use diskann_utils::{future::SendFuture, views::Matrix}; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; use thiserror::Error; use crate::{ ANNError, ANNResult, default_post_processor, - error::{Infallible, RankedError, StandardError, ToRanked, TransientError, message}, + error::{ErrorExt, Infallible, RankedError, StandardError, ToRanked, TransientError, message}, graph::{AdjacencyList, SearchOutputBuffer, glue, test::synthetic, workingset}, internal::counter::{Counter, LocalCounter}, neighbor::Neighbor, @@ -472,6 +472,11 @@ impl Provider { .map(|ref_multi| *ref_multi.key()) .filter(|id| !self.is_start_point(*id)) } + + // Return all start point Ids. + pub fn start_point_ids(&self) -> impl Iterator + '_ { + self.config.start_points.keys().copied() + } } /// Provider level metrics. @@ -994,98 +999,74 @@ impl provider::NeighborAccessorMut for NeighborAccessor<'_> { } } -////////////// -// Accessor // -////////////// +///////////// +// Builder // +///////////// #[derive(Debug)] -pub struct Accessor<'a> { +pub struct Builder<'a> { provider: &'a Provider, - buffer: Box<[f32]>, get_vector: LocalCounter<'a>, - /// IDs that will produce transient errors when accessed. transient_ids: Option>>, } -impl<'a> Accessor<'a> { +impl<'a> Builder<'a> { /// Return the underlying [`Provider`] reference. pub fn provider(&self) -> &'a Provider { self.provider } - /// Creates an accessor with no flaky behavior (backward-compatible). - pub fn new(provider: &'a Provider) -> Self { - Self::new_inner(provider, None) - } - - /// Creates an accessor where `get_element` returns a transient error for - /// any ID in `transient_ids`. The ID must still exist in the provider — - /// accessing a truly missing ID remains a critical `InvalidId` error. - pub fn flaky(provider: &'a Provider, transient_ids: Cow<'a, HashSet>) -> Self { - Self::new_inner(provider, Some(transient_ids)) - } - - fn new_inner(provider: &'a Provider, transient_ids: Option>>) -> Self { - let buffer = (0..provider.dim()).map(|_| 0.0).collect(); + fn new(provider: &'a Provider, transient_ids: Option>>) -> Self { Self { provider, - buffer, get_vector: provider.get_vector.local(), transient_ids, } } -} - -impl provider::HasId for Accessor<'_> { - type Id = u32; -} -impl provider::Accessor for Accessor<'_> { - type Element<'a> - = &'a [f32] - where - Self: 'a; - type ElementRef<'a> = &'a [f32]; - type GetError = AccessError; + fn check_flaky(&self, id: u32) -> Result<(), AccessError> { + if let Some(transient) = &self.transient_ids + && transient.contains(&id) + { + Err(AccessError::Transient(TransientAccessError::new(id))) + } else { + Ok(()) + } + } - async fn get_element(&mut self, id: u32) -> Result<&[f32], AccessError> { + fn get(&mut self, id: u32) -> Result, AccessError> { match self.provider.terms.get(&id) { Some(term) => { - if let Some(transient) = &self.transient_ids - && transient.contains(&id) - { - return Err(AccessError::Transient(TransientAccessError::new(id))); - } + self.check_flaky(id)?; self.get_vector.increment(); - self.buffer.copy_from_slice(&term.data); - Ok(&*self.buffer) + Ok((*term.data).into()) } None => Err(AccessError::InvalidId(AccessedInvalidId(id))), } } -} -impl<'a> provider::DelegateNeighbor<'a> for Accessor<'_> { - type Delegate = NeighborAccessor<'a>; - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - NeighborAccessor::new(self.provider) + fn record_get(&mut self) { + self.get_vector.increment() } } -impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> { - type QueryComputerError = Infallible; - type QueryComputer = ::QueryDistance; +impl provider::HasId for Builder<'_> { + type Id = u32; +} + +impl provider::HasElementRef for Builder<'_> { + type ElementRef<'a> = &'a [f32]; +} - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(f32::query_distance(from, self.provider.config.metric)) +impl<'a> provider::DelegateNeighbor<'a> for Builder<'_> { + type Delegate = NeighborAccessor<'a>; + fn delegate_neighbor(&'a mut self) -> Self::Delegate { + NeighborAccessor::new(self.provider) } } -impl provider::BuildDistanceComputer for Accessor<'_> { +impl provider::BuildDistanceComputer for Builder<'_> { type DistanceComputerError = Infallible; type DistanceComputer = ::Distance; @@ -1099,17 +1080,176 @@ impl provider::BuildDistanceComputer for Accessor<'_> { } } +//------// +// glue // +//------// + +type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; +type View<'a> = workingset::map::View<'a, u32, Box<[f32]>, workingset::map::Ref<[f32]>>; + +impl workingset::Fill for Builder<'_> { + type Error = ANNError; + type View<'a> + = View<'a> + where + Self: 'a, + WorkingSet: 'a; + + fn fill<'a, Itr>( + &'a mut self, + set: &'a mut WorkingSet, + itr: Itr, + ) -> impl SendFuture, Self::Error>> + where + Itr: ExactSizeIterator + Clone + Send + Sync, + Self: 'a, + { + async { + use workingset::map::Entry; + + set.prepare(itr.clone()); + for i in itr { + match set.entry(i) { + Entry::Seeded(_) | Entry::Occupied(_) => { /* nothing to do */ } + Entry::Vacant(vacant) => { + if let Some(buf) = + self.get(i).allow_transient("transient failures allowed")? + { + vacant.insert(buf); + } + } + } + } + Ok(set.view()) + } + } +} + +////////////// +// Accessor // +////////////// + +#[derive(Debug)] +pub struct Accessor<'a> { + builder: Builder<'a>, + distance: ::QueryDistance, +} + +impl<'a> Accessor<'a> { + /// Return the underlying [`Provider`] reference. + pub fn provider(&self) -> &'a Provider { + self.builder.provider() + } + + /// Creates an accessor with no flaky behavior (backward-compatible). + pub fn new(provider: &'a Provider, query: &[f32]) -> Self { + Self::new_inner(provider, query, None) + } + + /// Creates an accessor where `get_element` returns a transient error for + /// any ID in `transient_ids`. The ID must still exist in the provider — + /// accessing a truly missing ID remains a critical `InvalidId` error. + pub fn flaky( + provider: &'a Provider, + query: &[f32], + transient_ids: Cow<'a, HashSet>, + ) -> Self { + Self::new_inner(provider, query, Some(transient_ids)) + } + + fn new_inner( + provider: &'a Provider, + query: &[f32], + transient_ids: Option>>, + ) -> Self { + // FIXME: Return error. + assert_eq!(query.len(), provider.dim()); + let distance = f32::query_distance(query, provider.distance_metric()); + let builder = Builder::new(provider, transient_ids); + + Self { builder, distance } + } + + pub fn get_distance(&mut self, id: u32) -> Result { + match self.provider().terms.get(&id) { + Some(term) => { + self.builder.check_flaky(id)?; + self.builder.record_get(); + + Ok(self.distance.evaluate_similarity(&*term.data)) + } + None => Err(AccessError::InvalidId(AccessedInvalidId(id))), + } + } +} + +impl provider::HasId for Accessor<'_> { + type Id = u32; +} + //------// // Glue // //------// -impl glue::SearchExt for Accessor<'_> { +impl glue::Explore for Accessor<'_> { fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(self.provider.config.start_points.keys().copied().collect()) + futures_util::future::ok( + self.provider() + .config + .start_points + .keys() + .copied() + .collect(), + ) + } + + fn start_point_distances( + &mut self, + mut f: F, + ) -> impl std::future::Future> + Send + where + F: FnMut(Self::Id, f32) + Send, + { + async move { + for &i in self.provider().config.start_points.keys() { + f(i, self.get_distance(i).escalate("start points must exist")?) + } + Ok(()) + } + } + + fn expand_beam( + &mut self, + ids: Itr, + mut pred: P, + mut on_neighbors: F, + ) -> impl std::future::Future> + Send + where + Itr: Iterator + Send, + P: glue::HybridPredicate + Send + Sync, + F: FnMut(f32, Self::Id) + Send, + { + async move { + let mut neighbors = AdjacencyList::new(); + for id in ids { + self.provider().get_neighbors(id, &mut neighbors)?; + for &n in neighbors.iter().filter(|i| pred.eval_mut(i)) { + if let Some(distance) = self + .get_distance(n) + .allow_transient("transient failures allowed")? + { + on_neighbors(distance, n) + } + } + } + Ok(()) + } } } -impl glue::ExpandBeam<&[f32]> for Accessor<'_> {} +////////////// +// Strategy // +////////////// #[derive(Debug, Clone)] pub struct Strategy { @@ -1152,7 +1292,6 @@ impl Default for Strategy { } impl glue::SearchStrategy for Strategy { - type QueryComputer = ::QueryDistance; type SearchAccessorError = Infallible; type SearchAccessor<'a> = Accessor<'a>; @@ -1160,8 +1299,9 @@ impl glue::SearchStrategy for Strategy { &'a self, provider: &'a Provider, _context: &'a Context, + query: &[f32], ) -> Result, Infallible> { - Ok(Accessor::new(provider)) + Ok(Accessor::new(provider, query)) } } @@ -1172,7 +1312,7 @@ impl glue::DefaultPostProcessor for Strategy { impl glue::PruneStrategy for Strategy { type WorkingSet = workingset::Map, workingset::map::Ref<[f32]>>; type DistanceComputer<'a> = ::Distance; - type PruneAccessor<'a> = Accessor<'a>; + type PruneAccessor<'a> = Builder<'a>; type PruneAccessorError = Infallible; fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { @@ -1191,8 +1331,8 @@ impl glue::PruneStrategy for Strategy { _context: &'a Context, ) -> Result, Self::PruneAccessorError> { match &self.transient_ids { - Some(ids) => Ok(Accessor::flaky(provider, Cow::Borrowed(ids))), - None => Ok(Accessor::new(provider)), + Some(ids) => Ok(Builder::new(provider, Some(Cow::Borrowed(ids)))), + None => Ok(Builder::new(provider, None)), } } } @@ -1208,8 +1348,9 @@ impl glue::InsertStrategy for Strategy { &'a self, provider: &'a Provider, _context: &'a Context, + vector: &[f32], ) -> Result, Self::SearchAccessorError> { - Ok(Accessor::new(provider)) + Ok(Accessor::new(provider, vector)) } } @@ -1265,7 +1406,6 @@ impl<'a, 'b, O> glue::SearchPostProcessStep, &'b [f32], O> for Filt next: &Next, accessor: &mut Accessor<'a>, query: &'b [f32], - computer: &::QueryDistance, candidates: I, output: &mut B, ) -> impl std::future::Future>> + Send @@ -1274,11 +1414,10 @@ impl<'a, 'b, O> glue::SearchPostProcessStep, &'b [f32], O> for Filt B: SearchOutputBuffer + Send + ?Sized, Next: glue::SearchPostProcess + Sync, { - let provider = accessor.provider; + let provider = accessor.provider(); next.post_process( accessor, query, - computer, candidates.filter(|n| !provider.is_deleted(n.id).unwrap_or(true)), output, ) @@ -1630,52 +1769,52 @@ mod tests { ); } - #[test] - fn test_set_element() { - use provider::{Accessor, Guard, SetElement}; - - let provider = create_test_provider(); - let rt = current_thread_runtime(); - - let context = Context::new(); - let mut accessor = super::Accessor::new(&provider); - let id = 5; - - assert!(rt.block_on(accessor.get_element(5)).is_err()); - - // Setting with the wrong dimension is an error. - { - let v = vec![1.0f32; provider.dim() + 1]; - let err = rt - .block_on(provider.set_element(&context, &id, &v)) - .unwrap_err(); - let msg = err.to_string(); - assert_message_contains!(msg, "wrong dim"); - assert!(rt.block_on(accessor.get_element(id)).is_err()); - } - - // Setting with the correct dimension is successful. - { - let v = vec![1.0f32; provider.dim()]; - let guard = rt - .block_on(provider.set_element(&context, &id, &v)) - .unwrap(); - rt.block_on(guard.complete()); - - let element = rt.block_on(accessor.get_element(id)).unwrap(); - assert_eq!(v, element); - } - - // Setting again is an error. - { - let v = vec![1.0f32; provider.dim()]; - let err = rt - .block_on(provider.set_element(&context, &id, &v)) - .unwrap_err(); - let msg = err.to_string(); - assert_message_contains!(msg, "vector id 5 is already assigned"); - } - } + // #[test] + // fn test_set_element() { + // use provider::{Guard, SetElement}; + + // let provider = create_test_provider(); + // let rt = current_thread_runtime(); + + // let context = Context::new(); + // let mut accessor = super::Accessor::new(&provider); + // let id = 5; + + // assert!(accessor.get(5).is_err()); + + // // Setting with the wrong dimension is an error. + // { + // let v = vec![1.0f32; provider.dim() + 1]; + // let err = rt + // .block_on(provider.set_element(&context, &id, &v)) + // .unwrap_err(); + // let msg = err.to_string(); + // assert_message_contains!(msg, "wrong dim"); + // assert!(accessor.get(id).is_err()); + // } + + // // Setting with the correct dimension is successful. + // { + // let v = vec![1.0f32; provider.dim()]; + // let guard = rt + // .block_on(provider.set_element(&context, &id, &v)) + // .unwrap(); + // rt.block_on(guard.complete()); + + // let element = accessor.get(id).unwrap(); + // assert_eq!(v, element); + // } + + // // Setting again is an error. + // { + // let v = vec![1.0f32; provider.dim()]; + // let err = rt + // .block_on(provider.set_element(&context, &id, &v)) + // .unwrap_err(); + // let msg = err.to_string(); + // assert_message_contains!(msg, "vector id 5 is already assigned"); + // } + // } #[test] fn test_neighbor_accessor() { diff --git a/diskann/src/graph/workingset/map.rs b/diskann/src/graph/workingset/map.rs index eaa11c089..8fffeb910 100644 --- a/diskann/src/graph/workingset/map.rs +++ b/diskann/src/graph/workingset/map.rs @@ -129,16 +129,12 @@ use std::{fmt::Debug, hash::Hash, sync::Arc}; -use diskann_utils::{Reborrow, future::SendFuture}; +use diskann_utils::Reborrow; use hashbrown::hash_map; -use crate::{ - error::{RankedError, ToRanked, TransientError}, - graph::glue, - provider::Accessor, -}; +use crate::graph::glue; -use super::{AsWorkingSet, Fill}; +use super::AsWorkingSet; ///////// // Map // @@ -284,72 +280,6 @@ where } } -impl Map -where - K: Copy + Hash + Eq + Send + Sync, - V: Send + Sync, - P: Projection, -{ - /// Fill using `accessor.get_element` for each missing key. - /// - /// Calls [`prepare`](Self::prepare) to pin and evict entries, then fetches any - /// keys still missing. For incremental fills that skip preparation, use - /// [`fill_with`](Self::fill_with) directly. - /// - /// Transient errors are acknowledged and skipped. Only critical errors are propagated. - pub fn fill<'a, A, Itr>( - &'a mut self, - accessor: &'a mut A, - itr: Itr, - ) -> impl SendFuture, ::Error>> - where - A: for<'b> Accessor: Into>, - Itr: ExactSizeIterator + Clone + Send + Sync, - { - self.prepare(itr.clone()); - self.fill_with(accessor, itr, |element| element.into()) - } - - /// Fill using a projection from `Accessor::Element` to `V`. - /// - /// Transient errors are acknowledged and skipped. Only critical errors are propagated. - pub fn fill_with<'a, A, Itr, F>( - &'a mut self, - accessor: &'a mut A, - itr: Itr, - f: F, - ) -> impl SendFuture, ::Error>> - where - A: Accessor, - Itr: ExactSizeIterator + Send + Sync, - F: Fn(A::ElementRef<'_>) -> V + Send, - { - async move { - for i in itr { - match self.entry(i) { - Entry::Seeded(_) | Entry::Occupied(_) => { /* nothing to do */ } - Entry::Vacant(vacant) => match accessor.get_element(i).await { - Ok(element) => { - vacant.insert(f(element.reborrow())); - } - Err(local_error) => match local_error.to_ranked() { - RankedError::Transient(transient) => { - transient.acknowledge( - "transient error during fill; element will be absent from working set", - ); - } - RankedError::Error(critical) => { - return Err(critical); - } - }, - }, - } - } - Ok(self.view()) - } - } -} - impl Map where K: Hash + Eq, @@ -526,37 +456,6 @@ impl VacantEntry<'_, K, V> { } } -/// Blanket implementation of [`Fill`] for [`Map`]-backed working sets. -/// -/// This covers the common case where the accessor's `Extended` type is stored directly -/// in the map. Accessors that need custom fill logic (e.g. hybrid full-precision/quantized) -/// should use a different `State` type and provide their own `Fill` impl. -impl Fill> for A -where - P: Projection, - A: for<'a> Accessor = P::ElementRef<'a>>, - V: Project

+ Send + Sync + 'static, - for<'a> A::ElementRef<'a>: Into, -{ - type Error = ::Error; - type View<'a> - = View<'a, A::Id, V, P> - where - Self: 'a; - - fn fill<'a, Itr>( - &'a mut self, - map: &'a mut Map, - itr: Itr, - ) -> impl SendFuture, Self::Error>> - where - Itr: ExactSizeIterator + Clone + Send + Sync, - Self: 'a, - { - map.fill(self, itr) - } -} - ///////////////// // Projections // ///////////////// @@ -980,13 +879,11 @@ where mod tests { use super::*; - use std::{borrow::Cow, sync::Arc}; + use std::sync::Arc; use diskann_utils::views::Matrix; - use crate::graph::{ - test::provider::Accessor as TestAccessor, workingset::View as WorkingSetView, - }; + use crate::graph::workingset::View as WorkingSetView; /// Convenience alias matching the test provider's working set type. type TestMap = Map, Ref<[f32]>>; @@ -1932,202 +1829,4 @@ mod tests { let ar = AsReborrowed(&*value); let _ = ar.clone(); } - - //----------------------------------// - // Fill / fill_with (async, Tier 1) // - //----------------------------------// - - /// Create a grid-backed provider (1-D, size 5). - /// - /// IDs 0–4 have vectors `[0.0]` .. `[4.0]`, start point is `u32::MAX`. - fn fill_provider() -> crate::graph::test::provider::Provider { - use crate::graph::test::synthetic::Grid; - crate::graph::test::provider::Provider::grid(Grid::One, 5).unwrap() - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_happy_path() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - let current = map.generation; - let view = map - .fill(&mut accessor, [0u32, 1, 2].into_iter()) - .await - .unwrap(); - - assert_eq!(view.get(0).unwrap(), &[0.0]); - assert_eq!(view.get(1).unwrap(), &[1.0]); - assert_eq!(view.get(2).unwrap(), &[2.0]); - assert!(view.get(99).is_none()); - - assert_eq!( - map.generation, - current + 1, - "`fill` should bump the generation" - ); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_clears_previous_entries() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: Map, Ref<[f32]>> = Builder::new(Capacity::None).build(0); - - // First fill with IDs 0 and 1. - let view = map - .fill(&mut accessor, [0u32, 1].into_iter()) - .await - .unwrap(); - - assert!(view.get(0).is_some()); - assert!(view.get(1).is_some()); - assert!(view.get(2).is_none()); - - // Second fill with only ID 2 — previous entries should be cleared. - let view = map.fill(&mut accessor, [2u32].into_iter()).await.unwrap(); - assert!(view.get(0).is_none(), "fill should have cleared id 0"); - assert!(view.get(1).is_none(), "fill should have cleared id 1"); - assert_eq!(view.get(2).unwrap(), &[2.0]); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_with_preserves_entries() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // Populate with ID 0. - let view = map - .fill_with(&mut accessor, [0u32].into_iter(), |e| e.into()) - .await - .unwrap(); - assert!(view.get(0).is_some()); - - // fill_with with ID 1 — should NOT clear ID 0. - let view = map - .fill_with(&mut accessor, [1u32].into_iter(), |e| e.into()) - .await - .unwrap(); - assert_eq!( - view.get(0).unwrap(), - &[0.0], - "fill_with should preserve id 0" - ); - assert_eq!(view.get(1).unwrap(), &[1.0]); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_skips_transient_errors() { - let provider = fill_provider(); - - let mut accessor = - TestAccessor::flaky(&provider, Cow::Owned(std::collections::HashSet::from([1]))); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // ID 1 is transient — should be skipped, not propagated. - let view = map - .fill(&mut accessor, [0u32, 1, 2].into_iter()) - .await - .unwrap(); - assert_eq!(view.get(0).unwrap(), &[0.0]); - assert!(view.get(1).is_none(), "transient ID should be absent"); - assert_eq!(view.get(2).unwrap(), &[2.0]); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_propagates_critical_errors() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // ID 99 doesn't exist — critical InvalidId error. - let err = map - .fill(&mut accessor, [0u32, 99].into_iter()) - .await - .unwrap_err(); - let msg = err.to_string(); - assert!( - msg.contains("99"), - "error should mention the invalid id: {msg}" - ); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_with_skips_occupied_entries() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // Pre-insert a sentinel value for ID 0. - map.insert(0, Box::new([99.0])); - - // fill_with should skip the occupied entry. - let view = map - .fill_with(&mut accessor, [0u32, 1].into_iter(), |e| e.into()) - .await - .unwrap(); - - // ID 0 retains its pre-inserted sentinel. - assert_eq!( - view.get(0).unwrap(), - &[99.0], - "occupied entry should be preserved" - ); - assert_eq!(view.get(1).unwrap(), &[1.0]); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_with_skips_seeded_entries() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - - // Seed with a batch containing a different value for ID 0. - let batch = Arc::new(Matrix::try_from(Box::new([99.0, 88.0]), 2, 1).unwrap()); - let overlay = Overlay::>::from_batch(batch, [0u32, 1].into_iter()); - let mut map = seeded_map(overlay, Capacity::Unbounded); - - // fill_with requests IDs 0 and 2. ID 0 is seeded → skip, ID 2 is filled. - let view = map - .fill_with(&mut accessor, [0u32, 2].into_iter(), |e| e.into()) - .await - .unwrap(); - - // ID 0 comes from the seed (batch row 0 = [99.0]), NOT the accessor. - assert_eq!(view.get(0).unwrap(), &[99.0]); - // ID 2 was filled from the accessor. - assert_eq!(view.get(2).unwrap(), &[2.0]); - // Verify ID 0 is NOT in the fill layer. - assert!( - map.get(&0).is_none(), - "ID 0 should only be in the seed, not the fill layer" - ); - } - - #[tokio::test(flavor = "current_thread")] - async fn fill_empty_iterator() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - let view = map.fill(&mut accessor, std::iter::empty()).await.unwrap(); - assert!(view.get(0).is_none()); - } - - #[tokio::test(flavor = "current_thread")] - async fn blanket_fill_trait() { - let provider = fill_provider(); - let mut accessor = TestAccessor::new(&provider); - let mut map: TestMap = Builder::new(Capacity::Unbounded).build(0); - - // Exercise the blanket Fill impl. - let view = <_ as Fill>::fill(&mut accessor, &mut map, [0u32, 1, 2].into_iter()) - .await - .unwrap(); - - assert_eq!(view.get(0).unwrap(), &[0.0]); - assert_eq!(view.get(1).unwrap(), &[1.0]); - assert_eq!(view.get(2).unwrap(), &[2.0]); - } } diff --git a/diskann/src/graph/workingset/mod.rs b/diskann/src/graph/workingset/mod.rs index d879c5d5a..d0a31cc37 100644 --- a/diskann/src/graph/workingset/mod.rs +++ b/diskann/src/graph/workingset/mod.rs @@ -67,7 +67,10 @@ use diskann_utils::{Reborrow, future::SendFuture}; -use crate::{ANNError, provider::Accessor}; +use crate::{ + ANNError, + provider::{HasElementRef, HasId}, +}; ///////////// // Exports // @@ -96,7 +99,7 @@ pub use map::Map; /// directly accessible by the `WorkingSet`/[`View`] types. /// /// See Also: [`View`], [`AsWorkingSet`], [`Map`]. -pub trait Fill: Accessor { +pub trait Fill: HasId + HasElementRef { /// Any critical error that occurs during [`fill`](Self::fill). /// /// Implementations of `fill` are expected to swallow any non-critical errors. diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index f55af356c..3c00dc0a6 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -71,10 +71,6 @@ //! * [`BuildDistanceComputer`]: A sub-trait of [`Accessor`] that allows for random-access //! distance computations on the retrieved elements. //! -//! * [`BuildQueryComputer`]: A sub-trait of [`Accessor`] that allows for specialized query -//! based computations. This allows a query to be pre-processed in a way that allows -//! faster computations. -//! //! # Neighbor Delegation //! //! Index search requires that accessor types implement both the data-centric [`Accessor`] @@ -97,8 +93,7 @@ use std::ops::Deref; -use diskann_utils::Reborrow; -use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction}; +use diskann_vector::DistanceFunction; use sealed::{BoundTo, Sealed}; use crate::{ANNError, ANNResult, error::ToRanked, graph::AdjacencyList, utils::VectorId}; @@ -306,6 +301,14 @@ where type Id = T::Id; } +/////////////////// +// HasElementRef // +/////////////////// + +pub trait HasElementRef { + type ElementRef<'a>; +} + //////////////// // SetElement // //////////////// @@ -381,101 +384,8 @@ where } } -////////////// -// Accessor // -////////////// - -/// A lens through which [`DataProvider`]s contextually viewed. -/// -/// Accessors are **not** required to be `'static` and almost always contain a scoped -/// reference to their parent provider. -/// -/// # Element Relationship -/// -/// Accessors are expected to define two associated element types: -/// -/// * `Element<'_>`: The type returned by `get_element`. This is scoped to the borrow -/// of the accessor at the `get_element` call site. As a consequence, there may only -/// be one such `Element` active at a time. -/// -/// * `ElementRef<'_>`: A generalized borrowed form of `Element` obtainable via -/// `Reborrow`. This is the type on which distance computations are defined and is the -/// element type provided to the `on_element_unordered` bulk operation. -/// -/// The below diagram summarizes the relationship. -/// -/// ```text -/// Element<'_> ------ Reborrow ----> ElementRef<'_> -/// ~~~~ ~~~~ -/// ^ ^ -/// | | -/// Lifetime tied Arbitrarily short -/// to the Accessor lifetime decoupled -/// from the Accessor -/// ``` -/// -/// ## Technical Details -/// -/// The need for `ElementRef` arises to allow HRTB bounds to distance computers without -/// inducing a `'static` bound on `Self`. In traits like [`BuildQueryComputer`], attempting -/// to use `Element` directly will result in such a requirement on the implementing Accessor. -pub trait Accessor: HasId + Send + Sync { - /// A generalized reference type used for distance computations. - /// - /// Note that the lifetime of `ElementRef` is unconstrained and thus using it in a - /// [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) will not induce a `'static` - /// requirement on `Self`. - type ElementRef<'a>; - - /// The concrete type of the data element associated with this accessor. - /// - /// For distance computations, this should be cheaply convertible via [`Reborrow`] to - /// `Self::ElementRef`. - type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync - where - Self: 'a; - - /// The error (if any) returned by [`Self::get_element`]. - type GetError: ToRanked + std::fmt::Debug + Send + Sync + 'static; - - /// Return the value associated with the key `id`. - /// - /// It is expected that index algorithms will only invoke `get_element` on valid IDs, - /// that can be derived from [`SetElement::set_element`] or by some other means. - /// - /// Implementations are suggested to return an error if this invariant is broken, but - /// may also panic if that is an acceptable error mode. - fn get_element( - &mut self, - id: Self::Id, - ) -> impl std::future::Future, Self::GetError>> + Send; - - /// A bulk interface for invoking [`Self::get_element`] on each item in an iterator and - /// invoking the closure with the reborrowed element. - /// - /// Algorithms are encouraged to use this interface if appropriate as accessor - /// implementations may specialize the implementation for better performance. - fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> impl std::future::Future> + Send - where - Self: Sync, - Itr: Iterator + Send, - F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id), - { - async move { - for i in itr { - f(self.get_element(i).await?.reborrow(), i); - } - Ok(()) - } - } -} - /// A specialized [`Accessor`] that provides random-access distance computations. -pub trait BuildDistanceComputer: Accessor { +pub trait BuildDistanceComputer: HasElementRef { /// The error type (if any) associated with distance computer construction. /// /// Implementations are encouraged to make distance computer construction infallible. @@ -496,52 +406,6 @@ pub trait BuildDistanceComputer: Accessor { ) -> Result; } -/// A specialized [`Accessor`] that provides query computations for a query type `T`. -/// -/// Query computers are allowed to preprocess the query to enable more efficient distance -/// computations. -pub trait BuildQueryComputer: Accessor { - /// The error type (if any) associated with distance computer construction. - type QueryComputerError: std::error::Error + Into + Send + Sync + 'static; - - /// The concrete type of the distance computer, which must be applicable for all - /// elements yielded by the [`Accessor`]. - type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> - + Send - + Sync - + 'static; - - /// Build the query computer for this accessor. - /// - /// This method is encouraged to be as fast as possible, but will generally only be - /// invoked once per search or graph insert. - fn build_query_computer( - &self, - from: T, - ) -> Result; - - /// Compute the distances for the elements in the iterator `itr` using the - /// `computer` and apply the closure `f` to each distance and ID. The default - /// implementation uses on_elements_unordered to iterate over the elements - /// and compute the distances using `computer` parameter. - fn distances_unordered( - &mut self, - vec_id_itr: Itr, - computer: &Self::QueryComputer, - mut f: F, - ) -> impl std::future::Future> + Send - where - Itr: Iterator + Send, - F: Send + FnMut(f32, Self::Id), - { - self.on_elements_unordered(vec_id_itr, move |element, i| { - // Default is to use the computer to evaluate the similarity. - let distance = computer.evaluate_similarity(element); - f(distance, i); - }) - } -} - ///////////////////////// // Neighbor Delegation // ///////////////////////// @@ -774,11 +638,10 @@ mod sealed { #[cfg(test)] mod tests { use std::{ - collections::HashMap, future::Future, pin::Pin, sync::{ - Arc, Mutex, + Arc, atomic::{AtomicUsize, Ordering}, }, task, @@ -787,7 +650,7 @@ mod tests { use pin_project::{pin_project, pinned_drop}; use super::*; - use crate::{always_escalate, error::Infallible}; + use crate::always_escalate; //////////////////// // DefaultContext // @@ -988,40 +851,6 @@ mod tests { assert!(deleted.is_deleted()); } - /// A simple data provider that contains values consisting of floats and strings. - /// - /// The start point for this provider is as `u32::MAX`. - struct SimpleProvider { - data: Mutex>, - } - - impl SimpleProvider { - fn new(v: f32, st: String) -> Self { - let mut data = HashMap::new(); - data.insert(u32::MAX, (v, st)); - Self { - data: Mutex::new(data), - } - } - } - - impl DataProvider for SimpleProvider { - type Context = DefaultContext; - // Use the identity mapping for IDs. - type InternalId = u32; - type ExternalId = u32; - type Error = ANNError; - type Guard = NoopGuard; - - fn to_internal_id(&self, _context: &DefaultContext, gid: &u32) -> Result { - Ok(*gid) - } - - fn to_external_id(&self, _context: &DefaultContext, id: u32) -> Result { - Ok(id) - } - } - #[derive(Debug, Clone, Copy, PartialEq)] pub struct Missing; @@ -1040,390 +869,4 @@ mod tests { } always_escalate!(Missing); - - // An accessor for the `f32` portion of the data stored in the SimpleProvider. - struct FloatAccessor<'a>(&'a SimpleProvider); - impl HasId for FloatAccessor<'_> { - type Id = u32; - } - impl Accessor for FloatAccessor<'_> { - type Element<'a> - = f32 - where - Self: 'a; - type ElementRef<'a> = f32; - - type GetError = Missing; - - fn get_element( - &mut self, - id: u32, - ) -> impl Future, Self::GetError>> + Send { - let guard = self.0.data.lock().unwrap(); - let v = match guard.get(&id) { - None => Err(Missing), - Some(v) => Ok(v.0), - }; - std::future::ready(v) - } - - // Implement `on_elements_unordered` by only acquiring the lock once. - // - // Real implementations will need to take care to avoid deadlocks. - async fn on_elements_unordered( - &mut self, - itr: Itr, - mut f: F, - ) -> Result<(), Self::GetError> - where - Self: Sync, - Itr: Iterator, - F: Send + FnMut(f32, u32), - { - let guard = self.0.data.lock().unwrap(); - for i in itr { - match guard.get(&i) { - None => return Err(Missing), - Some(v) => f(v.0, i), - } - } - Ok(()) - } - } - - // An accessor for the `String` portion of the data stored in the SimpleProvider. - // - // We keep a local buffer `buf` into which the contents of the string are copied. - // This allows us to elide allocating on `get_element` calls. - struct StringAccessor<'a> { - provider: &'a SimpleProvider, - buf: String, - } - - impl<'a> StringAccessor<'a> { - fn new(provider: &'a SimpleProvider) -> Self { - Self { - provider, - buf: String::new(), - } - } - } - - impl HasId for StringAccessor<'_> { - type Id = u32; - } - impl Accessor for StringAccessor<'_> { - type Element<'a> - = &'a str - where - Self: 'a; - type ElementRef<'a> = &'a str; - - type GetError = Missing; - - fn get_element( - &mut self, - id: u32, - ) -> impl Future, Self::GetError>> + Send { - let guard = self.provider.data.lock().unwrap(); - let v = match guard.get(&id) { - None => Err(Missing), - Some(v) => { - self.buf.clone_from(&v.1); - Ok(&*self.buf) - } - }; - std::future::ready(v) - } - } - - #[tokio::test] - async fn test_default_implementations() { - let provider = SimpleProvider::new(-1.0, "hello".to_string()); - { - let mut data = provider.data.lock().unwrap(); - data.insert(0, (0.0, "world".to_string())); - data.insert(1, (1.0, "foo".to_string())); - data.insert(2, (2.0, "bar".to_string())); - } - - // Float accessor - { - let mut accessor = FloatAccessor(&provider); - assert_eq!(accessor.get_element(0).await.unwrap(), 0.0); - assert_eq!(accessor.get_element(1).await.unwrap(), 1.0); - assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), -1.0); - - let mut v = Vec::new(); - accessor - .on_elements_unordered([2, 1, 0].into_iter(), |element, id| v.push((element, id))) - .await - .unwrap(); - - assert_eq!(&v, &[(2.0, 2), (1.0, 1), (0.0, 0)]); - - // Test error propagation. - // Trying to access element 3 will result in an error, which should be propagated - // up. - let err = accessor - .on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| { - v.push((element, id)) - }) - .await - .unwrap_err(); - assert_eq!(err, Missing); - } - - // String accessor - { - let mut accessor = StringAccessor::new(&provider); - assert_eq!(accessor.get_element(0).await.unwrap(), "world"); - assert_eq!(accessor.get_element(1).await.unwrap(), "foo"); - assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), "hello"); - - // This method tests the provided implementation of `on_elements_unordered`. - let expected = [("bar", 2), ("foo", 1), ("world", 0)]; - - let mut expected_iter = expected.into_iter(); - accessor - .on_elements_unordered([2, 1, 0].into_iter(), |element, id| { - assert_eq!((element, id), expected_iter.next().unwrap()); - }) - .await - .unwrap(); - assert!(expected_iter.next().is_none()); - - // Test error propagation. - // Trying to access element 3 will result in an error, which should be propagated - // up. - let mut expected_iter = expected.into_iter(); - let err = accessor - .on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| { - assert_eq!((element, id), expected_iter.next().unwrap()); - }) - .await - .unwrap_err(); - assert_eq!(err, Missing); - assert!(expected_iter.next().is_none()); - } - } - - ///////////////////////////////// - // Supported Accessor Patterns // - ///////////////////////////////// - - // This suite of tests ensure that patterns we want out of the `Accessor` associated - // trait hierarchy are all supported. - // - // These include: - // - // * Accessors that always allocate. - // * Accessors that simply reference the underlying store directly. - // * Accessors that use a local buffer. - - #[derive(Debug)] - struct Store { - data: Box<[u8]>, - } - - impl Store { - fn new() -> Self { - Self { - data: Box::from([1, 2, 3, 4]), - } - } - - fn dim(&self) -> usize { - self.data.len() - } - } - - macro_rules! common_test_accessor { - ($T:ty) => { - impl HasId for $T { - type Id = u32; - } - - impl BuildDistanceComputer for $T { - type DistanceComputerError = Infallible; - type DistanceComputer = ::Distance; - - fn build_distance_computer(&self) -> Result { - Ok(::distance( - diskann_vector::distance::Metric::L2, - None, - )) - } - } - }; - } - - // An accessor that always allocates. - struct Allocating<'a> { - store: &'a Store, - } - - impl<'a> Allocating<'a> { - fn new(store: &'a Store) -> Self { - Self { store } - } - } - - common_test_accessor!(Allocating<'_>); - - impl Accessor for Allocating<'_> { - type Element<'a> - = Box<[u8]> - where - Self: 'a; - type ElementRef<'a> = &'a [u8]; - type GetError = Infallible; - - async fn get_element(&mut self, _: u32) -> Result, Infallible> { - Ok(self.store.data.clone()) - } - } - - // An accessor that forwards - returning references directly into the underlying - // store without reallocation or copying. - struct Forwarding<'a> { - store: &'a Store, - } - - impl<'a> Forwarding<'a> { - fn new(store: &'a Store) -> Self { - Self { store } - } - } - - common_test_accessor!(Forwarding<'_>); - - impl<'provider> Accessor for Forwarding<'provider> { - // NOTE: The lifetime of `Element` is `'provider` - not `'a`. This is what makes - // it a forwarding accessor. - type Element<'a> - = &'provider [u8] - where - Self: 'a; - type ElementRef<'a> = &'a [u8]; - type GetError = Infallible; - - async fn get_element(&mut self, _: u32) -> Result<&'provider [u8], Infallible> { - Ok(&*self.store.data) - } - } - - // An accessor that returns a non-reference type with a lifetime. - struct Wrapping<'a> { - store: &'a Store, - } - - impl<'a> Wrapping<'a> { - fn new(store: &'a Store) -> Self { - Self { store } - } - } - - #[derive(Debug)] - struct Wrapped<'a>(&'a [u8]); - - impl<'a> Reborrow<'a> for Wrapped<'_> { - type Target = &'a [u8]; - fn reborrow(&'a self) -> Self::Target { - self.0 - } - } - - impl From> for Box<[u8]> { - fn from(wrapped: Wrapped<'_>) -> Self { - wrapped.0.into() - } - } - - common_test_accessor!(Wrapping<'_>); - - impl Accessor for Wrapping<'_> { - type Element<'a> - = Wrapped<'a> - where - Self: 'a; - type ElementRef<'a> = &'a [u8]; - type GetError = Infallible; - - async fn get_element(&mut self, _: u32) -> Result, Infallible> { - Ok(Wrapped(&self.store.data)) - } - } - - // An accessor that shares local state. - #[derive(Debug)] - struct Sharing<'a> { - store: &'a Store, - local: Box<[u8]>, - } - - impl<'a> Sharing<'a> { - fn new(store: &'a Store) -> Self { - Self { - store, - local: (0..store.dim()).map(|_| 0).collect(), - } - } - } - - common_test_accessor!(Sharing<'_>); - - impl Accessor for Sharing<'_> { - type Element<'a> - = &'a [u8] - where - Self: 'a; - type ElementRef<'a> = &'a [u8]; - type GetError = Infallible; - - async fn get_element(&mut self, _: u32) -> Result<&[u8], Infallible> { - self.local.copy_from_slice(&self.store.data); - Ok(&self.local) - } - } - - #[tokio::test] - async fn test_accessor_patterns() { - let store = Store::new(); - - // A slice against which we compute distances. - let base: &[u8] = &[2, 3, 4, 5]; - - { - let mut accessor = Allocating::new(&store); - let computer = accessor.build_distance_computer().unwrap(); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0); - } - - { - let mut accessor = Forwarding::new(&store); - let computer = accessor.build_distance_computer().unwrap(); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0); - } - - { - let mut accessor = Wrapping::new(&store); - let computer = accessor.build_distance_computer().unwrap(); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0); - } - - { - let mut accessor = Sharing::new(&store); - let computer = accessor.build_distance_computer().unwrap(); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0); - } - } } diff --git a/rfcs/01078-paged-search.md b/rfcs/01078-paged-search.md new file mode 100644 index 000000000..50ac58d2c --- /dev/null +++ b/rfcs/01078-paged-search.md @@ -0,0 +1,193 @@ +# Overhaul Paged Search + +| | | +|---|---| +| **Authors** | Mark Hildebrand | +| **Contributors** | | +| **Created** | 2026-05-18 | +| **Updated** | 2026-05-18 | + +## Summary + +Replace the `SearchState<..., ExtraState: 'static>` pattern for paged search with a lifetime-bound `PagedSearch<'a, ...>` in `diskann`, and document a channel-based spawned-task pattern for downstream consumers that need to cross `tokio::spawn` or FFI boundaries. +This removes the `'static` requirement on query computers and search strategies, enabling future trait simplification. + +## Motivation + +### Background + +Paged search allows callers to retrieve nearest-neighbor results incrementally (one "page" at a time) without restarting the graph traversal. +The search state (scratch buffers, the priority queue, the query computer) must persist across page boundaries. + +Earlier synchronous version of DiskANN did this by persisting the search state manually and passing the search state explicitly to the next-page requests. +The async rewrite stuck with this pattern where callers were required to manage a `SearchState` struct whose `ExtraState` type parameter carried a `'static` bound: +```rust +// OLD: ExtraState must be 'static +pub struct SearchState { ... } +pub type PagedSearchState = SearchState<::InternalId, (S, C)>; + +// Note +// * S: SearchStrategy for some T +// * C: S::QueryComputer +``` +In downstream crates that expose paged search across FFI or task boundaries, the state was traditionally type-erased behind `Box` (which also captured the `DiskANNIndex`) and sent as an opaque pointer. +This required both the strategy and the query computer (parameters `S` and `C`) to be `'static`. + +### Problem Statement + +The `'static` bound on `BuildQueryComputer::QueryComputer` propagates throughout the trait hierarchy: + +```rust +// OLD +type QueryComputer: ... + Send + Sync + 'static; +``` + +This prevents: + +1. Query computers that borrow from the index or context (common with quantization tables). +2. Fusing the query-computer into the accessor. + Due to the lifetime needed by accessors, they can't be persisted in a `'static` struct this way. + However, query computers may contain non-trivial pre-processed state, meaning recreating them on each new page retrieval is a performance footgun. +3. Simplifying the `SearchStrategy` / `Accessor` trait tower by removing unnecessary indirection introduced solely to satisfy `'static`. + +### Goals + +1. Remove the `'static` bound from `BuildQueryComputer::QueryComputer`. +2. Remove `SearchState`, `NoExtraState`, and `PagedSearchState` from the public API. +3. Provide a `PagedSearch<'a, ...>` handle that is lifetime-bound to the index and context, encapsulating all search state. +4. Document the channel-based pattern for downstream consumers that need to cross task/FFI boundaries (where `'static` is inherently required by the runtime, not by `diskann`). + +## Proposal + +The key idea is this: DiskANN for better or worse is already fully async, and async Rust despite its flaws already provides a clean way of doing this without requiring our traits to bend over backwards. +So let's embrace async to actually help us for a change. + +### Core library (`diskann`) + +Replace the old split API: + +```rust +// OLD +index.start_paged_search(strategy, ctx, query, l) -> SearchState<...> +index.next_search_results(ctx, &mut state, k, &mut buf) -> usize +``` + +With a self-contained handle: + +```rust +// NEW +impl DiskANNIndex { + pub fn paged_search<'a, S, T>( + &'a self, + strategy: S, + context: &'a DP::Context, + query: T, + l_value: usize, + ) -> impl SendFuture>> + where + S: SearchStrategy, + T: Copy + Send + 'a; +} + +pub struct PagedSearch<'a, DP: DataProvider, S: SearchStrategy, T> { + index: &'a DiskANNIndex, + context: &'a DP::Context, + scratch: SearchScratch, + computed_result: Vec>, + next_result_index: usize, + search_param_l: usize, + strategy: S, + computer: S::QueryComputer, + _query: PhantomData, // covariant, always Send+Sync +} + +impl PagedSearch<'a, DP, S, T> { + pub fn next_page(&mut self, k: usize) -> impl SendFuture>>>; +} +``` + +The key change is that `PagedSearch` borrows all the necessary components. +There is no need to deconstruct the search components into `'static` pieces after each paged search and "reassemble" them on subsequent searches. + +### Crossing spawn boundaries: the channel pattern + +`PagedSearch<'a, ...>` borrows the index, so it cannot be sent to a `tokio::spawn`'d task directly. +When a long-lived session or an FFI boundary requires `'static` ownership, the recommended pattern is to **spawn a task that owns the search state as a local variable** and communicate with it via channels: + +```rust +// Types are illustrative — adapt names to your crate. + +type PageResult = ANNResult>>; + +/// Spawn a paged search session. The index is held by Arc so the task is 'static. +/// +/// Returns a request channel and a result channel. The caller sends the desired +/// page size (`k`) and awaits the corresponding result on the other end. +fn spawn_paged_session( + index: Arc>, + context: Arc, + query: T, + l: usize, +) -> (mpsc::Sender, mpsc::Receiver) { + let (req_tx, mut req_rx) = mpsc::channel::(1); + let (res_tx, res_rx) = mpsc::channel::(1); + + tokio::spawn(async move { + // Borrow from the Arc — these references are scoped to the task. + let mut search = index.paged_search(strategy, &*context, query, l).await.unwrap(); + + while let Some(k) = req_rx.recv().await { + let page = search.next_page(k).await; + if res_tx.send(page).await.is_err() { + break; // caller dropped the result receiver + } + } + // Request channel closed -> caller dropped sender -> clean shutdown. + }); + + (req_tx, res_rx) +} +``` + +Key properties of this pattern: + +1. **`'static` is confined to the spawn boundary**: the `Arc` satisfies the runtime's requirement, while the borrow from it lives entirely inside the task's local scope. + Importantly, even though `PagedSearch` borrows, it can be embedded inside a `'static` future. +2. **State is fully encapsulated**: callers never see `SearchScratch`, `QueryComputer`, or any internal types. +3. **Clean shutdown**: dropping the request sender closes the channel; the task exits gracefully. +4. **Per-request context**: the request channel can carry additional metadata (profiling tokens, cancellation flags, etc.) without polluting the core API. + +### Migration guide + +| Old pattern | New pattern | +|---|---| +| `index.start_paged_search(s, ctx, q, l)` | `index.paged_search(s, ctx, q, l).await` | +| `index.next_search_results(ctx, &mut state, k, &mut buf)` | `search.next_page(k).await` | +| `SearchState` | `PagedSearch<'a, DP, S, T>` | +| `PagedSearchState` | `PagedSearch<'a, DP, S, T>` | +| Check return count for exhaustion | Check `page.is_empty()` | +| Type-erased `Box` across task/FFI boundaries | Channel + spawned task (see above) | + +## Feasibility via FFI + +This is a pretty big change in API, but it enables some significant future simplifications to our trait hierarchy by removing the `'static` special case introduced by paged search. +An internal user of paged search was ported to this new approach to check the feasibility. +While it was a bit of work to overcome the impedance mismatch of the quirks of that integration, the end result is cleaner, has fewer overall task spawns, and fewer FFI related race conditions. +And really, this integration was already basically doing this same thing behind the scenes. + +## Alternatives + +The main alternative I see is to keep the status quo with explicit state management. +While some of planned trait simplifications are still on the table, I think the opportunity to align paged search with the rest of the trait hierarchy is well worth it. + +## Benchmark Results + +No performance change expected (nor observed in the simulator for the aforementioned internal FFI user) since the search algorithm is identical. +Existing Rust code will have a similar pattern of future usage as before, just packaged slightly differently. + +## References + +1. [RFC 3498 — Lifetime Capture Rules 2024](https://rust-lang.github.io/rfcs/3498-lifetime-capture-rules-2024.html) — + Rust edition 2024 changes that make `impl Trait + 'a` returns more ergonomic. +2. [tokio::sync::mpsc](https://docs.rs/tokio/latest/tokio/sync/mpsc/index.html) — the channel + primitive used in the spawned-task pattern.