From c5e2b831e0dcf83ede483ebee0ef865ee7bec25a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 19 Mar 2026 22:15:52 +0000 Subject: [PATCH 1/4] Initial plan From 21ca91018f5c69637d28d7b44ca3790efdf4d8f4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 19 Mar 2026 22:35:09 +0000 Subject: [PATCH 2/4] Remove DebugProvider from diskann-providers --- .../graph/provider/async_/caching/example.rs | 968 ----------- .../graph/provider/async_/caching/mod.rs | 3 - .../graph/provider/async_/debug_provider.rs | 1453 ----------------- .../src/model/graph/provider/async_/mod.rs | 4 - 4 files changed, 2428 deletions(-) delete mode 100644 diskann-providers/src/model/graph/provider/async_/caching/example.rs delete mode 100644 diskann-providers/src/model/graph/provider/async_/debug_provider.rs diff --git a/diskann-providers/src/model/graph/provider/async_/caching/example.rs b/diskann-providers/src/model/graph/provider/async_/caching/example.rs deleted file mode 100644 index ab3edf0aa..000000000 --- a/diskann-providers/src/model/graph/provider/async_/caching/example.rs +++ /dev/null @@ -1,968 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use diskann::{ - graph::AdjacencyList, - provider::{self as core_provider, DefaultContext}, -}; -use diskann_utils::future::AsyncFriendly; -use diskann_vector::distance::Metric; - -use crate::model::graph::provider::async_::{ - common::FullPrecision, - debug_provider::{self, DebugProvider}, -}; - -use super::{ - bf_cache::{self, Cache}, - error::CacheAccessError, - provider::{self as cache_provider, CachingError, NeighborStatus}, - utils::{CacheKey, Graph, HitStats, KeyGen, LocalStats}, -}; - -/////////////////// -// Example Cache // -/////////////////// - -/// A representation for the adjacency list term. -/// -/// The `repr(u8)` allows this to be converted to an integer and thus be used in a -/// `bytemuck::Pod` struct. -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -enum Tag { - AdjacencyList = 0, - Vector = 1, -} - -const TAGS: [Tag; 2] = [Tag::AdjacencyList, Tag::Vector]; - -impl KeyGen for Tag { - type Key = CacheKey; - fn generate(&self, id: u32) -> Self::Key { - CacheKey::new(id, (*self as u8).into()) - } -} - -/// An example cache for compatibility with the full-precision provider. -/// -/// This is meant largely for demonstration purposes, showing how to build a cache tailored -/// for a certain provider. -#[derive(Debug)] -pub struct ExampleCache { - cache: Cache, - // This field records a collection of uncacheable IDs to ensure the cache skipping - // logic in the provider works as expected. - uncacheable: Option>, - neighbor_stats: HitStats, - vector_stats: HitStats, -} - -impl ExampleCache { - /// Construct a new cache with the given capacity. - pub fn new( - bytes: diskann_quantization::num::PowerOfTwo, - uncacheable: Option>, - ) -> Self { - Self { - cache: Cache::new(bytes).unwrap(), - uncacheable, - neighbor_stats: HitStats::new(), - vector_stats: HitStats::new(), - } - } - - /// Invalidate all cached items associated with `id`. - fn invalidate(&self, id: u32) { - // Since we use a unified cache under the scenes type the list types disambiguated - // by a tag - we need to delete all possible tags. - for tag in TAGS { - self.cache.delete(tag.generate(id)) - } - } - - /// Return a cache accessor for adjacency list terms. - fn neighbors(&self, max_degree: usize) -> Graph<'_, u32, Tag> { - Graph::new( - &self.cache, - max_degree, - Tag::AdjacencyList, - &self.neighbor_stats, - ) - } -} - -impl cache_provider::Evict for ExampleCache { - fn evict(&self, id: u32) { - self.invalidate(id) - } -} - -/// The `CacheAccessor` is the `C` in `cache_provider::CachingAccessor` and interfaces the -/// accessor with the underlying cache. -#[derive(Debug)] -pub struct CacheAccessor<'a, T> { - graph: Graph<'a, u32, Tag>, - cacher: T, - stats: LocalStats<'a>, - uncacheable: Option<&'a [u32]>, - /// Key generation for the element being accessed. - keygen: Tag, -} - -impl cache_provider::ElementCache> - for CacheAccessor<'_, bf_cache::VecCacher> -where - T: AsyncFriendly + bytemuck::Pod + Default + std::fmt::Debug, -{ - type Error = CacheAccessError; - - fn get_cached(&mut self, k: u32) -> Result, CacheAccessError> { - match self - .graph - .cache() - .get(self.keygen.generate(k), &mut self.cacher) - { - Ok(Some(value)) => { - self.stats.hit(); - Ok(Some(value)) - } - Ok(None) => { - self.stats.miss(); - Ok(None) - } - Err(err) => Err(CacheAccessError::read(k, err)), - } - } - - fn set_cached(&mut self, k: u32, element: &&[T]) -> Result<(), CacheAccessError> { - self.graph - .cache() - .set(self.keygen.generate(k), &mut self.cacher, element) - .map_err(|err| CacheAccessError::write(k, err)) - } -} - -impl cache_provider::NeighborCache for CacheAccessor<'_, T> -where - T: AsyncFriendly, -{ - type Error = CacheAccessError; - - fn try_get_neighbors( - &mut self, - id: u32, - neighbors: &mut AdjacencyList, - ) -> Result { - if let Some(uncacheable) = self.uncacheable - && uncacheable.contains(&id) - { - self.graph.stats_mut().miss(); - Ok(NeighborStatus::Uncacheable) - } else { - self.graph.try_get_neighbors(id, neighbors) - } - } - - fn set_neighbors(&mut self, id: u32, neighbors: &[u32]) -> Result<(), CacheAccessError> { - self.graph.set_neighbors(id, neighbors) - } - - fn invalidate_neighbors(&mut self, id: u32) { - self.graph.invalidate_neighbors(id) - } -} - -///////////////////// -// Provider Bridge // -///////////////////// - -impl<'a> cache_provider::AsCacheAccessorFor<'a, debug_provider::FullAccessor<'a>> for ExampleCache { - type Accessor = CacheAccessor<'a, bf_cache::VecCacher>; - type Error = diskann::error::Infallible; - fn as_cache_accessor_for( - &'a self, - inner: debug_provider::FullAccessor<'a>, - ) -> Result< - cache_provider::CachingAccessor, Self::Accessor>, - Self::Error, - > { - let provider = inner.provider(); - let cache_accessor = CacheAccessor { - graph: self.neighbors(provider.max_degree()), - cacher: bf_cache::VecCacher::::new(provider.dim()), - uncacheable: self.uncacheable.as_deref(), - stats: LocalStats::new(&self.vector_stats), - keygen: Tag::Vector, - }; - Ok(cache_provider::CachingAccessor::new(inner, cache_accessor)) - } -} - -impl<'a> cache_provider::CachedFillSet>> - for debug_provider::FullAccessor<'a> -{ -} - -impl<'a> cache_provider::CachedAsElement<&'a [f32], CacheAccessor<'a, bf_cache::VecCacher>> - for debug_provider::FullAccessor<'a> -{ - type Error = CachingError; - async fn cached_as_element<'b>( - &'b mut self, - cache: &'b mut CacheAccessor<'a, bf_cache::VecCacher>, - _vector: &'a [f32], - id: u32, - ) -> Result, Self::Error> { - cache_provider::get_or_insert(self, cache, id).await - } -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - use std::sync::Arc; - - use diskann::{ - graph::{DiskANNIndex, glue::SearchStrategy}, - provider::{ - Accessor, DataProvider, Delete, NeighborAccessor, NeighborAccessorMut, SetElement, - }, - utils::async_tools, - }; - use diskann_quantization::num::PowerOfTwo; - use diskann_utils::views::Matrix; - use diskann_vector::{PureDistanceFunction, distance::SquaredL2}; - use rstest::rstest; - - use crate::{ - index::diskann_async::{self, tests as async_tests}, - model::graph::provider::async_::caching::provider::{AsCacheAccessorFor, CachingProvider}, - utils as crate_utils, - }; - - const CTX: &DefaultContext = &DefaultContext; - - fn test_provider( - uncacheable: Option>, - ) -> CachingProvider { - let dim = 2; - - let config = debug_provider::DebugConfig { - start_id: u32::MAX, - start_point: vec![0.0; dim], - max_degree: 10, - metric: Metric::L2, - }; - - let table = diskann_async::train_pq( - Matrix::new(0.0, 1, dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, - ) - .unwrap(); - - CachingProvider::new( - DebugProvider::new(config, Arc::new(table)).unwrap(), - ExampleCache::new(PowerOfTwo::new(1024 * 16).unwrap(), uncacheable), - ) - } - - #[tokio::test] - async fn basic_operations_happy_path() { - let provider = test_provider(None); - let ctx = &DefaultContext; - - // Translations do not yet exist. - assert!(provider.to_external_id(ctx, 0).is_err()); - assert!(provider.to_internal_id(ctx, &0).is_err()); - - assert_eq!(provider.inner().data_writes.get(), 0); - provider.set_element(CTX, &0, &[1.0, 2.0]).await.unwrap(); - assert_eq!(provider.inner().data_writes.get(), 1 /* increased */); - - assert_eq!(provider.to_external_id(ctx, 0).unwrap(), 0); - assert_eq!(provider.to_internal_id(ctx, &0).unwrap(), 0); - - // Retrieval of a valid element. - let mut accessor = provider - .cache() - .as_cache_accessor_for(debug_provider::FullAccessor::new(provider.inner())) - .unwrap(); - - // Hit served from the underlying provider. - assert_eq!(provider.inner().full_reads.get(), 0); - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 1); - assert_eq!( - accessor.cache().stats.get_local_misses(), - 1, /* increased */ - ); - assert_eq!(accessor.cache().stats.get_local_hits(), 0); - - // This time, the hit is served from the underlying cache. - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 1); - assert_eq!(accessor.cache().stats.get_local_misses(), 1); - assert_eq!( - accessor.cache().stats.get_local_hits(), - 1, /* increased */ - ); - - // Adjacency List from Underlying - assert_eq!(provider.inner().neighbor_writes.get(), 0); - accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap(); - assert_eq!( - provider.inner().neighbor_writes.get(), - 1, /* increased */ - ); - - let mut list = AdjacencyList::new(); - assert_eq!(provider.inner().neighbor_reads.get(), 0); - accessor.get_neighbors(0, &mut list).await.unwrap(); - assert_eq!( - provider.inner().neighbor_reads.get(), - 1, /* increased */ - ); - assert_eq!( - accessor.cache().graph.stats().get_local_misses(), - 1, /* increased */ - ); - assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0); - assert_eq!(&*list, &[1, 2, 3]); - - // Adjacency List From Cache - list.clear(); - accessor.get_neighbors(0, &mut list).await.unwrap(); - assert_eq!(&*list, &[1, 2, 3]); - assert_eq!(provider.inner().neighbor_reads.get(), 1); - assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1); - assert_eq!( - accessor.cache().graph.stats().get_local_hits(), - 1, /* increased */ - ); - - // If we invalidate the key - these elements should be retrieved from the backing - // provider instead. - provider.cache().invalidate(0); - - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2 /* increased */,); - assert_eq!( - accessor.cache().stats.get_local_misses(), - 2, /* increased */ - ); - assert_eq!(accessor.cache().stats.get_local_hits(), 1); - - // Once more from the cache. - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2); - assert_eq!(accessor.cache().stats.get_local_misses(), 2); - assert_eq!( - accessor.cache().stats.get_local_hits(), - 2, /* increased */ - ); - - list.clear(); - accessor.get_neighbors(0, &mut list).await.unwrap(); - assert_eq!(&*list, &[1, 2, 3]); - assert_eq!( - provider.inner().neighbor_reads.get(), - 2, /* increased */ - ); - assert_eq!( - accessor.cache().graph.stats().get_local_misses(), - 2, /* increased */ - ); - assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); - - // Setting adjacency lists invalidates the cache. - accessor.set_neighbors(0, &[2, 3, 4]).await.unwrap(); - accessor.get_neighbors(0, &mut list).await.unwrap(); - assert_eq!(&*list, &[2, 3, 4]); - assert_eq!( - provider.inner().neighbor_writes.get(), - 2, /* increased */ - ); - assert_eq!( - provider.inner().neighbor_reads.get(), - 3, /* increased */ - ); - assert_eq!( - accessor.cache().graph.stats().get_local_misses(), - 3, /* increased */ - ); - assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); - - accessor.append_vector(0, &[1]).await.unwrap(); - accessor.get_neighbors(0, &mut list).await.unwrap(); - assert_eq!(&*list, &[2, 3, 4, 1]); - - assert_eq!( - provider.inner().neighbor_writes.get(), - 3, /* increased */ - ); - assert_eq!( - provider.inner().neighbor_reads.get(), - 4, /* increased */ - ); - assert_eq!( - accessor.cache().graph.stats().get_local_misses(), - 4, /* increased */ - ); - assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); - - // Deletion. - assert_eq!( - provider.status_by_internal_id(CTX, 0).await.unwrap(), - core_provider::ElementStatus::Valid - ); - assert_eq!( - provider.status_by_external_id(CTX, &0).await.unwrap(), - core_provider::ElementStatus::Valid - ); - assert!(provider.status_by_internal_id(CTX, 1).await.is_err()); - assert!(provider.status_by_external_id(CTX, &1).await.is_err()); - - provider.delete(CTX, &0).await.unwrap(); - - assert_eq!( - provider.status_by_internal_id(CTX, 0).await.unwrap(), - core_provider::ElementStatus::Deleted - ); - assert_eq!( - provider.status_by_external_id(CTX, &0).await.unwrap(), - core_provider::ElementStatus::Deleted - ); - assert!(provider.status_by_internal_id(CTX, 1).await.is_err()); - assert!(provider.status_by_external_id(CTX, &1).await.is_err()); - - // Access the deleted element is still valid. - let element = accessor.get_element(0).await.unwrap(); - assert_eq!(element, &[1.0, 2.0]); - assert_eq!(provider.inner().full_reads.get(), 2); - assert_eq!(accessor.cache().stats.get_local_misses(), 2); - assert_eq!( - accessor.cache().stats.get_local_hits(), - 3, /* increased */ - ); - - accessor.get_neighbors(0, &mut list).await.unwrap(); - assert_eq!(&*list, &[2, 3, 4, 1]); - assert_eq!(provider.inner().neighbor_writes.get(), 3); - assert_eq!(provider.inner().neighbor_reads.get(), 4); - assert_eq!(accessor.cache().graph.stats().get_local_misses(), 4); - assert_eq!( - accessor.cache().graph.stats().get_local_hits(), - 2, /* increased */ - ); - - provider.release(CTX, 0).await.unwrap(); - assert!(provider.status_by_internal_id(CTX, 0).await.is_err()); - assert!(provider.status_by_external_id(CTX, &0).await.is_err()); - - assert!(accessor.get_element(0).await.is_err()); - assert_eq!(provider.inner().full_reads.get(), 2); - assert_eq!( - accessor.cache().stats.get_local_misses(), - 3 /* increased */ - ); - assert_eq!(accessor.cache().stats.get_local_hits(), 3); - - assert!(accessor.get_neighbors(0, &mut list).await.is_err()); - assert_eq!(provider.inner().neighbor_writes.get(), 3); - assert_eq!(provider.inner().neighbor_reads.get(), 4); - assert_eq!( - accessor.cache().graph.stats().get_local_misses(), - 5 /* increased */ - ); - assert_eq!(accessor.cache().graph.stats().get_local_hits(), 2); - - // Ensure that the stats get properly recorded when the accessor is dropped. - let vector_hits = accessor.cache().stats.get_local_hits(); - let vector_misses = accessor.cache().stats.get_local_misses(); - let neighbor_hits = accessor.cache().graph.stats().get_local_hits(); - let neighbor_misses = accessor.cache().graph.stats().get_local_misses(); - - std::mem::drop(accessor); - - assert_eq!(provider.cache().vector_stats.get_hits(), vector_hits); - assert_eq!(provider.cache().vector_stats.get_misses(), vector_misses); - - assert_eq!(provider.cache().neighbor_stats.get_hits(), neighbor_hits); - assert_eq!( - provider.cache().neighbor_stats.get_misses(), - neighbor_misses - ); - } - - #[tokio::test] - async fn test_uncacheable() { - // Test that returning `Uncacheable` for an adjacency list is handled correctly by - // the provider and a call to `set_neighbors` is not made. - let uncacheable = u32::MAX; - let provider = test_provider(Some(vec![uncacheable])); - - let mut accessor = provider - .cache() - .as_cache_accessor_for(debug_provider::FullAccessor::new(provider.inner())) - .unwrap(); - - provider.set_element(CTX, &0, &[1.0, 2.0]).await.unwrap(); - - //---------------// - // Cacheable IDs // - //---------------// - - // Adjacency List from Underlying - assert_eq!(provider.inner().neighbor_writes.get(), 0); - accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap(); - assert_eq!( - provider.inner().neighbor_writes.get(), - 1, /* increased */ - ); - - let mut list = AdjacencyList::new(); - assert_eq!(provider.inner().neighbor_reads.get(), 0); - accessor.get_neighbors(0, &mut list).await.unwrap(); - assert_eq!( - provider.inner().neighbor_reads.get(), - 1, /* increased */ - ); - assert_eq!( - accessor.cache().graph.stats().get_local_misses(), - 1, /* increased */ - ); - assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0); - assert_eq!(&*list, &[1, 2, 3]); - - // Adjacency List From Cache - list.clear(); - accessor.get_neighbors(0, &mut list).await.unwrap(); - assert_eq!(&*list, &[1, 2, 3]); - assert_eq!(provider.inner().neighbor_reads.get(), 1); - assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1); - assert_eq!( - accessor.cache().graph.stats().get_local_hits(), - 1, /* increased */ - ); - - //-----------------// - // Uncacheable IDs // - //-----------------// - - assert_eq!(provider.inner().neighbor_writes.get(), 1); - accessor.set_neighbors(uncacheable, &[4, 5]).await.unwrap(); - assert_eq!( - provider.inner().neighbor_writes.get(), - 2, /* increased */ - ); - - // The retrieval is served by the inner provider. - assert_eq!(provider.inner().neighbor_reads.get(), 1); - accessor - .get_neighbors(uncacheable, &mut list) - .await - .unwrap(); - assert_eq!( - provider.inner().neighbor_reads.get(), - 2, /* increased */ - ); - assert_eq!( - accessor.cache().graph.stats().get_local_misses(), - 2, /* increased */ - ); - assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); - assert_eq!(&*list, &[4, 5]); - - // Again, retrieval is served by the inner provider. - assert_eq!(provider.inner().neighbor_reads.get(), 2); - accessor - .get_neighbors(uncacheable, &mut list) - .await - .unwrap(); - assert_eq!( - provider.inner().neighbor_reads.get(), - 3, /* increased */ - ); - assert_eq!( - accessor.cache().graph.stats().get_local_misses(), - 3, /* increased */ - ); - assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); - assert_eq!(&*list, &[4, 5]); - } - - //----------------// - // Standard Tests // - //----------------// - - #[rstest] - #[case(1, 100)] - #[case(3, 7)] - #[case(4, 5)] - #[tokio::test] - async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) { - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - let start_id = u32::MAX; - let start_point = vec![grid_size as f32; dim]; - let metric = Metric::L2; - let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); - - let index_config = diskann::graph::config::Builder::new( - max_degree, - diskann::graph::config::MaxDegree::default_slack(), - l, - metric.into(), - ) - .build() - .unwrap(); - - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), - metric, - }; - - let mut vectors = ::generate_grid(dim, grid_size); - let table = diskann_async::train_pq( - async_tests::squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, - ) - .unwrap(); - - let provider = CachingProvider::new( - DebugProvider::new(test_config, Arc::new(table)).unwrap(), - ExampleCache::new(cache_size, None), - ); - let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); - - let adjacency_lists = match dim { - 1 => crate_utils::generate_1d_grid_adj_list(grid_size as u32), - 3 => crate_utils::genererate_3d_grid_adj_list(grid_size as u32), - 4 => crate_utils::generate_4d_grid_adj_list(grid_size as u32), - _ => panic!("Unsupported number of dimensions"), - }; - assert_eq!(adjacency_lists.len(), num_points); - assert_eq!(vectors.len(), num_points); - - let strategy = cache_provider::Cached::new(FullPrecision); - async_tests::populate_data(index.provider(), CTX, &vectors).await; - { - // Note: Without the fully qualified syntax - this fails to compile. - let mut accessor = as SearchStrategy< - cache_provider::CachingProvider, - [f32], - >>::search_accessor(&strategy, index.provider(), CTX) - .unwrap(); - async_tests::populate_graph(&mut accessor, &adjacency_lists).await; - - accessor - .set_neighbors(start_id, &[num_points as u32 - 1]) - .await - .unwrap(); - } - - let corpus: diskann_utils::views::Matrix = async_tests::squish(vectors.iter(), dim); - let mut paged_tests = Vec::new(); - - // Test with the zero query. - let query = vec![0.0; dim]; - let gt = crate::test_utils::groundtruth(corpus.as_view(), &query, |a, b| { - SquaredL2::evaluate(a, b) - }); - paged_tests.push(async_tests::PagedSearch::new(query, gt)); - - // Test with the start point to ensure it is filtered out. - let gt = crate::test_utils::groundtruth(corpus.as_view(), &start_point, |a, b| { - SquaredL2::evaluate(a, b) - }); - paged_tests.push(async_tests::PagedSearch::new(start_point.clone(), gt)); - - // Unfortunately - this is needed for the `check_grid_search` test. - vectors.push(start_point.clone()); - async_tests::check_grid_search(&index, &vectors, &paged_tests, strategy, strategy).await; - } - - fn check_stats(caching: &CachingProvider) { - let provider = caching.inner(); - let cache = caching.cache(); - - println!("neighbor reads: {}", provider.neighbor_reads.get()); - println!("neighbor writes: {}", provider.neighbor_writes.get()); - println!("vector reads: {}", provider.full_reads.get()); - println!("vector writes: {}", provider.data_writes.get()); - - println!("neighbor hits: {}", cache.neighbor_stats.get_hits()); - println!("neighbor misses: {}", cache.neighbor_stats.get_misses()); - println!("vector hits: {}", cache.vector_stats.get_hits()); - println!("vector misses: {}", cache.vector_stats.get_misses()); - - // Neighbors - assert_eq!( - provider.neighbor_reads.get(), - cache.neighbor_stats.get_misses() - ); - - // Vectors - assert_eq!(provider.full_reads.get(), cache.vector_stats.get_misses()); - } - - #[rstest] - #[tokio::test] - async fn grid_search_with_build( - #[values((1, 100), (3, 7), (4, 5))] dim_and_size: (usize, usize), - ) { - let dim = dim_and_size.0; - let grid_size = dim_and_size.1; - let l = 10; - let start_id = u32::MAX; - let start_point = vec![grid_size as f32; dim]; - let metric = Metric::L2; - let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); - - // NOTE: Be careful changing `max_degree`. It needs to be high enough that the - // graph is navigable, but low enough that the batch parallel handling inside - // `multi_insert` is needed for the multi-insert graph to be navigable. - // - // With the current configured values, removing the other elements in the batch - // from the visited set during `multi_insert` results in a graph failure. - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - - let mut vectors = ::generate_grid(dim, grid_size); - let table = Arc::new( - diskann_async::train_pq( - async_tests::squish(vectors.iter(), dim).as_view(), - 2.min(dim), - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, - ) - .unwrap(), - ); - - let index_config = diskann::graph::config::Builder::new_with( - max_degree, - diskann::graph::config::MaxDegree::default_slack(), - l, - metric.into(), - |b| { - b.max_minibatch_par(10); - }, - ) - .build() - .unwrap(); - - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), - metric, - }; - assert_eq!(vectors.len(), num_points); - - // This is a little subtle, but we need `vectors` to contain the start point as - // its last element, but we **don't** want to include it in the index build. - // - // This basically means that we need to be careful with out index initialization. - vectors.push(vec![grid_size as f32; dim]); - - // Initialize an index for a new round of building. - let init_index = || { - let provider = CachingProvider::new( - DebugProvider::new(test_config.clone(), table.clone()).unwrap(), - ExampleCache::new(cache_size, None), - ); - Arc::new(DiskANNIndex::new(index_config.clone(), provider, None)) - }; - - let strategy = cache_provider::Cached::new(FullPrecision); - - // Build with full-precision single insert - { - let index = init_index(); - for (i, v) in vectors.iter().take(num_points).enumerate() { - index - .insert(strategy, CTX, &(i as u32), v.as_slice()) - .await - .unwrap(); - } - - check_stats(index.provider()); - - async_tests::check_grid_search(&index, &vectors, &[], strategy, strategy).await; - check_stats(index.provider()); - } - - // Build with full-precision multi-insert - { - let index = init_index(); - let batch: Box<[_]> = vectors - .iter() - .take(num_points) - .enumerate() - .map(|(id, v)| async_tools::VectorIdBoxSlice::new(id as u32, v.as_slice().into())) - .collect(); - - index.multi_insert(strategy, CTX, batch).await.unwrap(); - - async_tests::check_grid_search(&index, &vectors, &[], strategy, strategy).await; - check_stats(index.provider()); - } - } - - #[tokio::test] - async fn test_inplace_delete_2d() { - // create small index instance - let metric = Metric::L2; - let num_points = 4; - let strategy = cache_provider::Cached::new(FullPrecision); - let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); - let start_id = num_points as u32; - let start_point = vec![0.5, 0.5]; - let dim = start_point.len(); - - let index_config = diskann::graph::config::Builder::new( - 4, // target_degree - diskann::graph::config::MaxDegree::default_slack(), - 10, // l_build - metric.into(), - ) - .build() - .unwrap(); - - let test_config = debug_provider::DebugConfig { - start_id, - start_point: start_point.clone(), - max_degree: index_config.max_degree().get(), - metric, - }; - - // The contents of the table don't matter for this test because we use full - // precision only. - let table = diskann_async::train_pq( - Matrix::new(0.5, 1, dim).as_view(), - dim, - &mut crate::utils::create_rnd_from_seed_in_tests(0), - 1usize, - ) - .unwrap(); - - let index = DiskANNIndex::new( - index_config, - CachingProvider::new( - DebugProvider::new(test_config, Arc::new(table)).unwrap(), - ExampleCache::new(cache_size, None), - ), - None, - ); - - // vectors are the four corners of a square, with the start point in the middle - // the middle point forms an edge to each corner, while corners form an edge - // to their opposite vertex vertically as well as the middle - let vectors = [ - vec![0.0, 0.0], - vec![0.0, 1.0], - vec![1.0, 0.0], - vec![1.0, 1.0], - ]; - let adjacency_lists = [ - AdjacencyList::from_iter_untrusted([4, 1]), - AdjacencyList::from_iter_untrusted([4, 0]), - AdjacencyList::from_iter_untrusted([4, 3]), - AdjacencyList::from_iter_untrusted([4, 2]), - AdjacencyList::from_iter_untrusted([0, 1, 2, 3]), - ]; - - // Note: Without the fully qualified syntax - this fails to compile. - let mut accessor = as SearchStrategy< - cache_provider::CachingProvider, - [f32], - >>::search_accessor(&strategy, index.provider(), CTX) - .unwrap(); - - async_tests::populate_data(index.provider(), CTX, &vectors).await; - async_tests::populate_graph(&mut accessor, &adjacency_lists).await; - - index - .inplace_delete( - strategy, - CTX, - &3, // id to delete - 3, // num_to_replace - diskann::graph::InplaceDeleteMethod::VisitedAndTopK { - k_value: 4, - l_value: 10, - }, - ) - .await - .unwrap(); - - // Check that the vertex was marked as deleted. - assert!( - index - .data_provider - .status_by_internal_id(CTX, 3) - .await - .unwrap() - .is_deleted() - ); - - // expected outcome: - // vertex 4 (the start point) has its edge to 3 deleted - // vertex 2 (the other point with edge pointing to 3) should have its edge to point 3 deleted, - // and replaced with edges to points 0 and 1 - // vertices 0 and 1 should add an edge pointing to 2. - // vertex 3 should be dropped - { - let mut list = AdjacencyList::new(); - accessor.get_neighbors(4, &mut list).await.unwrap(); - list.sort(); - assert_eq!(&*list, &[0, 1, 2]); - } - - { - let mut list = AdjacencyList::new(); - accessor.get_neighbors(2, &mut list).await.unwrap(); - list.sort(); - assert_eq!(&*list, &[0, 1, 4]); - } - - { - let mut list = AdjacencyList::new(); - accessor.get_neighbors(0, &mut list).await.unwrap(); - list.sort(); - assert_eq!(&*list, &[1, 2, 4]); - } - - { - let mut list = AdjacencyList::new(); - accessor.get_neighbors(1, &mut list).await.unwrap(); - list.sort(); - assert_eq!(&*list, &[0, 2, 4]); - } - - { - let mut list = AdjacencyList::new(); - accessor.get_neighbors(3, &mut list).await.unwrap(); - assert!(list.is_empty()); - } - } -} diff --git a/diskann-providers/src/model/graph/provider/async_/caching/mod.rs b/diskann-providers/src/model/graph/provider/async_/caching/mod.rs index dc93ca52c..aeac402ef 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/mod.rs @@ -7,6 +7,3 @@ pub mod bf_cache; pub mod error; pub mod provider; pub mod utils; - -#[cfg(test)] -pub mod example; diff --git a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs b/diskann-providers/src/model/graph/provider/async_/debug_provider.rs deleted file mode 100644 index 0c0ba780e..000000000 --- a/diskann-providers/src/model/graph/provider/async_/debug_provider.rs +++ /dev/null @@ -1,1453 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -use std::{ - collections::{HashMap, hash_map}, - sync::{ - Arc, RwLock, RwLockReadGuard, RwLockWriteGuard, - atomic::{AtomicUsize, Ordering}, - }, -}; - -use diskann::{ - ANNError, ANNErrorKind, ANNResult, - graph::{ - AdjacencyList, - glue::{ - AsElement, ExpandBeam, FillSet, FilterStartPoints, InplaceDeleteStrategy, - InsertStrategy, Pipeline, PruneStrategy, SearchExt, SearchStrategy, - }, - }, - provider::{ - self, Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultAccessor, - DefaultContext, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, - NeighborAccessorMut, - }, - tracked_warn, - utils::VectorRepr, -}; -use diskann_quantization::CompressInto; -use diskann_vector::distance::Metric; -use thiserror::Error; - -use crate::{ - model::{ - FixedChunkPQTable, - distance::{DistanceComputer, QueryComputer}, - graph::provider::async_::{ - common::{FullPrecision, Internal, Panics, Quantized}, - distances::{self, pq::Hybrid}, - postprocess, - }, - pq, - }, - utils::BridgeErr, -}; - -#[derive(Debug, Clone)] -pub struct DebugConfig { - pub start_id: u32, - pub start_point: Vec, - pub max_degree: usize, - pub metric: Metric, -} - -/// A version of `DebugConfig` that has the compressed representation of `start_point` in -/// addition to the full-precision representation. -#[derive(Debug, Clone)] -struct InternalConfig { - start_id: u32, - start_point: Datum, - max_degree: usize, - metric: Metric, -} - -/// A combined full-precision and PQ quantized vector. -#[derive(Debug, Default, Clone)] -pub struct Datum { - full: Vec, - quant: Vec, -} - -impl Datum { - /// Create a new `Datum`. - fn new(full: Vec, quant: Vec) -> Self { - Self { full, quant } - } - - /// Return a reference to the full-precision vector. - fn full(&self) -> &[f32] { - &self.full - } - - /// Return a reference to the quantized vector. - fn quant(&self) -> &[u8] { - &self.quant - } -} - -/// A container for `Datum`s within the `Debug` provider. -/// -/// This tracks whether items are valid or have been marked as deleted inline. -#[derive(Debug, Clone)] -pub enum Vector { - Valid(Datum), - Deleted(Datum), -} - -impl Vector { - /// Change `self` to be `Self::Deleted`, leaving the internal `Datum` unchanged. - fn mark_deleted(&mut self) { - *self = match self.take() { - Self::Valid(v) => Self::Deleted(v), - Self::Deleted(v) => Self::Deleted(v), - } - } - - /// Take the internal `Datum` and construct a new instance of `Self`. - /// - /// Leave the caller with an empty `Datum`. - fn take(&mut self) -> Self { - match self { - Self::Valid(v) => Self::Valid(std::mem::take(v)), - Self::Deleted(v) => Self::Deleted(std::mem::take(v)), - } - } - - /// Return `true` if `self` has been marked as deleted. Otherwise, return `false`. - fn is_deleted(&self) -> bool { - matches!(self, Self::Deleted(_)) - } -} - -impl std::ops::Deref for Vector { - type Target = Datum; - fn deref(&self) -> &Datum { - match self { - Self::Valid(v) => v, - Self::Deleted(v) => v, - } - } -} - -/// A simple increment-only thread-safe counter. -#[derive(Debug)] -pub struct Counter(AtomicUsize); - -impl Counter { - /// Construct a new counter with a count of 0. - fn new() -> Self { - Self(AtomicUsize::new(0)) - } - - /// Increment the counter by 1. - pub(crate) fn increment(&self) { - self.0.fetch_add(1, Ordering::Relaxed); - } - - /// Return the current value of the counter. - pub(crate) fn get(&self) -> usize { - self.0.load(Ordering::Relaxed) - } -} - -pub struct DebugProvider { - config: InternalConfig, - - pub pq_table: Arc, - pub data: RwLock>, - pub neighbors: RwLock>>, - - // Counters - pub full_reads: Counter, - pub quant_reads: Counter, - pub neighbor_reads: Counter, - pub data_writes: Counter, - pub neighbor_writes: Counter, - - // Track whether the `insert_search_accessor` is invoked. - pub insert_search_accessor_calls: Counter, -} - -impl DebugProvider { - pub fn new(config: DebugConfig, pq_table: Arc) -> ANNResult { - // Compress the start point. - let mut pq = vec![0u8; pq_table.get_num_chunks()]; - pq_table - .compress_into(config.start_point.as_slice(), pq.as_mut_slice()) - .bridge_err()?; - - let config = InternalConfig { - start_id: config.start_id, - start_point: Datum::new(config.start_point, pq), - max_degree: config.max_degree, - metric: config.metric, - }; - - let mut data = HashMap::new(); - data.insert(config.start_id, Vector::Valid(config.start_point.clone())); - - let mut neighbors = HashMap::new(); - neighbors.insert(config.start_id, Vec::new()); - - Ok(Self { - config, - pq_table: pq_table.clone(), - data: RwLock::new(data), - neighbors: RwLock::new(neighbors), - full_reads: Counter::new(), - quant_reads: Counter::new(), - neighbor_reads: Counter::new(), - data_writes: Counter::new(), - neighbor_writes: Counter::new(), - insert_search_accessor_calls: Counter::new(), - }) - } - - /// Return the dimension of the full-precision data. - pub fn dim(&self) -> usize { - self.config.start_point.full().len() - } - - /// Return the maximum degree that can be held by this graph. - pub fn max_degree(&self) -> usize { - self.config.max_degree - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn data(&self) -> RwLockReadGuard<'_, HashMap> { - self.data.read().expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn data_mut(&self) -> RwLockWriteGuard<'_, HashMap> { - self.data.write().expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn neighbors(&self) -> RwLockReadGuard<'_, HashMap>> { - self.neighbors - .read() - .expect("cannot recover from lock poison") - } - - #[expect( - clippy::expect_used, - reason = "DebugProvider is not a production data structure" - )] - fn neighbors_mut(&self) -> RwLockWriteGuard<'_, HashMap>> { - self.neighbors - .write() - .expect("cannot recover from lock poison") - } - - fn is_deleted(&self, id: u32) -> Result { - match self.data().get(&id) { - Some(element) => Ok(element.is_deleted()), - None => Err(InvalidId::Internal(id)), - } - } -} - -/// Light-weight error type for reporting access to an invalid ID. -#[derive(Debug, Clone, Copy, Error)] -pub enum InvalidId { - #[error("internal id {0} not initialized")] - Internal(u32), - #[error("external id {0} not initialized")] - External(u32), - #[error("is start point {0}")] - IsStartPoint(u32), -} - -diskann::always_escalate!(InvalidId); - -impl From for ANNError { - #[track_caller] - fn from(err: InvalidId) -> ANNError { - ANNError::opaque(err) - } -} - -////////////////// -// DataProvider // -////////////////// - -impl DataProvider for DebugProvider { - type Context = DefaultContext; - type InternalId = u32; - type ExternalId = u32; - type Error = InvalidId; - - fn to_internal_id( - &self, - _context: &DefaultContext, - gid: &Self::ExternalId, - ) -> Result { - // Check that the ID actually exists - let valid = self.data().contains_key(gid); - if valid { - Ok(*gid) - } else { - Err(InvalidId::External(*gid)) - } - } - - fn to_external_id( - &self, - _context: &DefaultContext, - id: Self::InternalId, - ) -> Result { - // Check that the ID actually exists - let valid = self.data().contains_key(&id); - if valid { - Ok(id) - } else { - Err(InvalidId::External(id)) - } - } -} - -impl Delete for DebugProvider { - async fn delete( - &self, - _context: &Self::Context, - gid: &Self::ExternalId, - ) -> Result<(), Self::Error> { - if *gid == self.config.start_id { - return Err(InvalidId::IsStartPoint(*gid)); - } - - let mut guard = self.data_mut(); - match guard.entry(*gid) { - hash_map::Entry::Occupied(mut occupied) => { - occupied.get_mut().mark_deleted(); - Ok(()) - } - hash_map::Entry::Vacant(_) => Err(InvalidId::External(*gid)), - } - } - - async fn release( - &self, - _context: &Self::Context, - id: Self::InternalId, - ) -> Result<(), Self::Error> { - if id == self.config.start_id { - return Err(InvalidId::IsStartPoint(id)); - } - - // NOTE: acquire the locks in the order `data` then `neighbors`. - let mut data = self.data_mut(); - let mut neighbors = self.neighbors_mut(); - - let v = data.remove(&id); - let u = neighbors.remove(&id); - - if v.is_none() || u.is_none() { - Err(InvalidId::Internal(id)) - } else { - Ok(()) - } - } - - async fn status_by_internal_id( - &self, - _context: &Self::Context, - id: Self::InternalId, - ) -> Result { - if self.is_deleted(id)? { - Ok(provider::ElementStatus::Deleted) - } else { - Ok(provider::ElementStatus::Valid) - } - } - - fn status_by_external_id( - &self, - context: &Self::Context, - gid: &Self::ExternalId, - ) -> impl Future> + Send { - self.status_by_internal_id(context, *gid) - } -} - -impl provider::SetElement<[f32]> for DebugProvider { - type SetError = ANNError; - - type Guard = provider::NoopGuard; - - fn set_element( - &self, - _context: &Self::Context, - id: &Self::ExternalId, - element: &[f32], - ) -> impl Future> + Send { - #[derive(Debug, Clone, Copy, Error)] - #[error("vector id {0} is already assigned")] - pub struct AlreadyAssigned(u32); - - diskann::always_escalate!(AlreadyAssigned); - - impl From for ANNError { - #[track_caller] - fn from(err: AlreadyAssigned) -> Self { - Self::new(ANNErrorKind::IndexError, err) - } - } - - // NOTE: acquire the locks in the order `vectors` then `neighbors`. - let result = match self.data_mut().entry(*id) { - hash_map::Entry::Occupied(_) => Err(AlreadyAssigned(*id).into()), - hash_map::Entry::Vacant(data) => match self.neighbors_mut().entry(*id) { - hash_map::Entry::Occupied(_) => Err(AlreadyAssigned(*id).into()), - hash_map::Entry::Vacant(neighbors) => { - self.data_writes.increment(); - - let mut pq = vec![0u8; self.pq_table.get_num_chunks()]; - match self - .pq_table - .compress_into(element, pq.as_mut_slice()) - .bridge_err() - { - Ok(()) => { - data.insert(Vector::Valid(Datum::new(element.into(), pq))); - neighbors.insert(Vec::new()); - Ok(provider::NoopGuard::new(*id)) - } - Err(err) => Err(ANNError::from(err)), - } - } - }, - }; - - std::future::ready(result) - } -} - -impl postprocess::DeletionCheck for DebugProvider { - fn deletion_check(&self, id: u32) -> bool { - match self.is_deleted(id) { - Ok(is_deleted) => is_deleted, - Err(err) => { - tracked_warn!("Deletion post-process failed with error {err} - continuing"); - true - } - } - } -} - -/////////////// -// Accessors // -/////////////// - -#[derive(Debug, Clone, Copy, Error)] -#[error("Attempt to access an invalid id: {0}")] -pub struct AccessedInvalidId(u32); - -diskann::always_escalate!(AccessedInvalidId); - -impl From for ANNError { - #[track_caller] - fn from(err: AccessedInvalidId) -> Self { - Self::opaque(err) - } -} - -impl DefaultAccessor for DebugProvider { - type Accessor<'a> = DebugNeighborAccessor<'a>; - - fn default_accessor(&self) -> Self::Accessor<'_> { - DebugNeighborAccessor::new(self) - } -} - -#[derive(Clone, Copy)] -pub struct DebugNeighborAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> DebugNeighborAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for DebugNeighborAccessor<'_> { - type Id = u32; -} - -impl NeighborAccessor for DebugNeighborAccessor<'_> { - fn get_neighbors( - self, - id: Self::Id, - neighbors: &mut AdjacencyList, - ) -> impl Future> + Send { - let result = match self.provider.neighbors().get(&id) { - Some(v) => { - self.provider.neighbor_reads.increment(); - neighbors.overwrite_trusted(v); - Ok(self) - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } -} - -impl NeighborAccessorMut for DebugNeighborAccessor<'_> { - fn set_neighbors( - self, - id: Self::Id, - neighbors: &[Self::Id], - ) -> impl Future> + Send { - assert!(neighbors.len() <= self.provider.config.max_degree); - let result = match self.provider.neighbors_mut().get_mut(&id) { - Some(v) => { - self.provider.neighbor_writes.increment(); - v.clear(); - v.extend_from_slice(neighbors); - Ok(self) - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } - - fn append_vector( - self, - id: Self::Id, - neighbors: &[Self::Id], - ) -> impl Future> + Send { - let result = match self.provider.neighbors_mut().get_mut(&id) { - Some(v) => { - assert!( - v.len().checked_add(neighbors.len()).unwrap() - <= self.provider.config.max_degree, - "current = {:?}, new = {:?}, id = {}", - v, - neighbors, - id - ); - - let check = neighbors.iter().try_for_each(|n| { - if v.contains(n) { - Err(ANNError::message( - ANNErrorKind::Opaque, - format!("id {} is duplicated", n), - )) - } else { - Ok(()) - } - }); - - match check { - Ok(()) => { - self.provider.neighbor_writes.increment(); - v.extend_from_slice(neighbors); - Ok(self) - } - Err(err) => Err(err), - } - } - None => Err(ANNError::opaque(AccessedInvalidId(id))), - }; - - std::future::ready(result) - } -} - -//---------------// -// Full Accessor // -//---------------// - -pub struct FullAccessor<'a> { - provider: &'a DebugProvider, - buffer: Box<[f32]>, -} - -impl<'a> FullAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - let buffer = (0..provider.dim()).map(|_| 0.0).collect(); - Self { provider, buffer } - } - - pub fn provider(&self) -> &DebugProvider { - self.provider - } -} - -impl HasId for FullAccessor<'_> { - type Id = u32; -} - -impl Accessor for FullAccessor<'_> { - type Extended = Box<[f32]>; - type Element<'a> - = &'a [f32] - where - Self: 'a; - type ElementRef<'a> = &'a [f32]; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.full_reads.increment(); - self.buffer.copy_from_slice(v.full()); - Ok(&*self.buffer) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl diskann::provider::CacheableAccessor for FullAccessor<'_> { - type Map = diskann_utils::lifetime::Slice; - - fn as_cached<'a, 'b>(element: &'a &'b [f32]) -> &'a &'b [f32] - where - Self: 'a + 'b, - { - element - } - - fn from_cached<'a>(element: &'a [f32]) -> &'a [f32] - where - Self: 'a, - { - element - } -} - -impl SearchExt for FullAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for FullAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildDistanceComputer for FullAccessor<'_> { - type DistanceComputerError = Panics; - type DistanceComputer = ::Distance; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(f32::distance( - self.provider.config.metric, - Some(self.provider.dim()), - )) - } -} - -impl BuildQueryComputer<[f32]> for FullAccessor<'_> { - type QueryComputerError = Panics; - type QueryComputer = ::QueryDistance; - - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(f32::query_distance(from, self.provider.config.metric)) - } -} - -impl ExpandBeam<[f32]> for FullAccessor<'_> {} -impl FillSet for FullAccessor<'_> {} - -impl postprocess::AsDeletionCheck for FullAccessor<'_> { - type Checker = DebugProvider; - fn as_deletion_check(&self) -> &Self::Checker { - self.provider - } -} - -//----------------// -// Quant Accessor // -//----------------// - -pub struct QuantAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> QuantAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for QuantAccessor<'_> { - type Id = u32; -} - -impl Accessor for QuantAccessor<'_> { - type Extended = Vec; - type Element<'a> - = Vec - where - Self: 'a; - type ElementRef<'a> = &'a [u8]; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.quant_reads.increment(); - Ok(v.quant().to_owned()) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl SearchExt for QuantAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for QuantAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildQueryComputer<[f32]> for QuantAccessor<'_> { - type QueryComputerError = Panics; - type QueryComputer = pq::distance::QueryComputer>; - - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(QueryComputer::new( - self.provider.pq_table.clone(), - self.provider.config.metric, - from, - None, - ) - .unwrap()) - } -} - -impl ExpandBeam<[f32]> for QuantAccessor<'_> {} - -impl postprocess::AsDeletionCheck for QuantAccessor<'_> { - type Checker = DebugProvider; - fn as_deletion_check(&self) -> &Self::Checker { - self.provider - } -} - -//-----------------// -// Hybrid Accessor // -//-----------------// - -pub struct HybridAccessor<'a> { - provider: &'a DebugProvider, -} - -impl<'a> HybridAccessor<'a> { - pub fn new(provider: &'a DebugProvider) -> Self { - Self { provider } - } -} - -impl HasId for HybridAccessor<'_> { - type Id = u32; -} - -impl Accessor for HybridAccessor<'_> { - type Extended = Hybrid, Vec>; - type Element<'a> - = Hybrid, Vec> - where - Self: 'a; - type ElementRef<'a> = Hybrid<&'a [f32], &'a [u8]>; - - type GetError = AccessedInvalidId; - - fn get_element( - &mut self, - id: Self::Id, - ) -> impl Future, Self::GetError>> + Send { - let result = match self.provider.data().get(&id) { - Some(v) => { - self.provider.full_reads.increment(); - Ok(Hybrid::Full(v.full().to_owned())) - } - None => Err(AccessedInvalidId(id)), - }; - - std::future::ready(result) - } -} - -impl SearchExt for HybridAccessor<'_> { - fn starting_points(&self) -> impl Future>> + Send { - futures_util::future::ok(vec![self.provider.config.start_id]) - } -} - -impl<'a> DelegateNeighbor<'a> for HybridAccessor<'_> { - type Delegate = DebugNeighborAccessor<'a>; - - fn delegate_neighbor(&'a mut self) -> Self::Delegate { - DebugNeighborAccessor::new(self.provider) - } -} - -impl BuildDistanceComputer for HybridAccessor<'_> { - type DistanceComputerError = Panics; - type DistanceComputer = distances::pq::HybridComputer; - - fn build_distance_computer( - &self, - ) -> Result { - Ok(distances::pq::HybridComputer::new( - DistanceComputer::new(self.provider.pq_table.clone(), self.provider.config.metric) - .unwrap(), - f32::distance(self.provider.config.metric, Some(self.provider.dim())), - )) - } -} - -impl FillSet for HybridAccessor<'_> { - async fn fill_set( - &mut self, - set: &mut HashMap, - itr: Itr, - ) -> Result<(), Self::GetError> - where - Itr: Iterator + Send + Sync, - { - let threshold = 1; // one full vec per fill - let data = self.provider.data(); - itr.enumerate().for_each(|(i, id)| { - let e = set.entry(id); - if i < threshold { - e.and_modify(|v| { - if !v.is_full() { - let element = data.get(&id).unwrap(); - *v = Hybrid::Full(element.full().to_owned()); - } - }) - .or_insert_with(|| { - let element = data.get(&id).unwrap(); - Hybrid::Full(element.full().to_owned()) - }); - } else { - e.or_insert_with(|| { - let element = data.get(&id).unwrap(); - Hybrid::Quant(element.quant().to_owned()) - }); - } - }); - Ok(()) - } -} - -//////////////// -// Strategies // -//////////////// - -impl SearchStrategy for Internal { - type QueryComputer = ::QueryDistance; - type PostProcessor = postprocess::RemoveDeletedIdsAndCopy; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = FullAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -impl SearchStrategy for FullPrecision { - type QueryComputer = ::QueryDistance; - type PostProcessor = Pipeline; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = FullAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -impl SearchStrategy for Internal { - type QueryComputer = pq::distance::QueryComputer>; - type PostProcessor = postprocess::RemoveDeletedIdsAndCopy; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = QuantAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -impl SearchStrategy for Quantized { - type QueryComputer = pq::distance::QueryComputer>; - type PostProcessor = Pipeline; - type SearchAccessorError = Panics; - type SearchAccessor<'a> = QuantAccessor<'a>; - - fn search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::SearchAccessorError> { - Ok(QuantAccessor::new(provider)) - } - - fn post_processor(&self) -> Self::PostProcessor { - Default::default() - } -} - -impl PruneStrategy for FullPrecision { - type DistanceComputer = ::Distance; - type PruneAccessor<'a> = FullAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider)) - } -} - -impl<'a> AsElement<&'a [f32]> for FullAccessor<'a> { - type Error = Panics; - fn as_element( - &mut self, - vector: &'a [f32], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(vector)) - } -} - -impl PruneStrategy for Quantized { - type DistanceComputer = distances::pq::HybridComputer; - type PruneAccessor<'a> = HybridAccessor<'a>; - type PruneAccessorError = diskann::error::Infallible; - - fn prune_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - ) -> Result, Self::PruneAccessorError> { - Ok(HybridAccessor::new(provider)) - } -} - -impl<'a> AsElement<&'a [f32]> for HybridAccessor<'a> { - type Error = Panics; - fn as_element( - &mut self, - vector: &'a [f32], - _id: Self::Id, - ) -> impl Future, Self::Error>> + Send { - std::future::ready(Ok(Hybrid::Full(vector.to_vec()))) - } -} - -impl InsertStrategy for FullPrecision { - type PruneStrategy = Self; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn insert_search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - provider.insert_search_accessor_calls.increment(); - self.search_accessor(provider, context) - } -} - -impl InsertStrategy for Quantized { - type PruneStrategy = Self; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn insert_search_accessor<'a>( - &'a self, - provider: &'a DebugProvider, - context: &'a DefaultContext, - ) -> Result, Self::SearchAccessorError> { - provider.insert_search_accessor_calls.increment(); - self.search_accessor(provider, context) - } -} - -impl InplaceDeleteStrategy for FullPrecision { - type DeleteElement<'a> = [f32]; - type DeleteElementGuard = Vec; - type DeleteElementError = Panics; - type PruneStrategy = Self; - type SearchStrategy = Internal; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(*self) - } - - fn get_delete_element<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - id: ::InternalId, - ) -> impl Future> + Send - { - futures_util::future::ok(provider.data().get(&id).unwrap().full().to_vec()) - } -} - -impl InplaceDeleteStrategy for Quantized { - type DeleteElement<'a> = [f32]; - type DeleteElementGuard = Vec; - type DeleteElementError = Panics; - type PruneStrategy = Self; - type SearchStrategy = Internal; - - fn prune_strategy(&self) -> Self::PruneStrategy { - *self - } - - fn search_strategy(&self) -> Self::SearchStrategy { - Internal(*self) - } - - fn get_delete_element<'a>( - &'a self, - provider: &'a DebugProvider, - _context: &'a ::Context, - id: ::InternalId, - ) -> impl Future> + Send - { - futures_util::future::ok(provider.data().get(&id).unwrap().full().to_vec()) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use diskann::{ - graph::{self, DiskANNIndex}, - provider::{Guard, SetElement}, - utils::async_tools::VectorIdBoxSlice, - }; - use diskann_vector::{PureDistanceFunction, distance::SquaredL2}; - use rstest::rstest; - - use super::*; - use crate::{ - index::diskann_async::{ - tests::{ - GenerateGrid, PagedSearch, check_grid_search, populate_data, populate_graph, squish, - }, - train_pq, - }, - test_utils::groundtruth, - utils, - }; - - #[tokio::test] - async fn basic_operations() { - let dim = 2; - let ctx = &DefaultContext; - - let debug_config = DebugConfig { - start_id: u32::MAX, - start_point: vec![0.0; dim], - max_degree: 10, - metric: Metric::L2, - }; - - let vectors = [vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]]; - let pq_table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - let provider = DebugProvider::new(debug_config, Arc::new(pq_table)).unwrap(); - - provider - .set_element(ctx, &0, &[1.0, 1.0]) - .await - .unwrap() - .complete() - .await; - - // internal id = external id - assert_eq!(provider.to_internal_id(ctx, &0).unwrap(), 0); - assert_eq!(provider.to_external_id(ctx, 0).unwrap(), 0); - - let mut accessor = FullAccessor::new(&provider); - - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 1); - - let mut neighbors = AdjacencyList::new(); - - let accessor = provider.default_accessor(); - let res = accessor.get_neighbors(0, &mut neighbors).await; - assert!(res.is_ok()); - assert_eq!(provider.neighbor_reads.get(), 1); - - let accessor = provider.default_accessor(); - let res = accessor.set_neighbors(0, &[1, 2, 3]).await; - assert!(res.is_ok()); - assert_eq!(provider.neighbor_writes.get(), 1); - - // delete and release vector 0 - let res = provider.delete(&DefaultContext, &0).await; - assert!(res.is_ok()); - assert_eq!( - ElementStatus::Deleted, - provider - .status_by_external_id(&DefaultContext, &0) - .await - .unwrap() - ); - - let mut accessor = FullAccessor::new(&provider); - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 2); - - let mut accessor = HybridAccessor::new(&provider); - let res = accessor.get_element(0).await; - assert!(res.is_ok()); - assert_eq!(provider.full_reads.get(), 3); - - // Releasing should make the element unreachable. - let res = provider.release(&DefaultContext, 0).await; - assert!(res.is_ok()); - assert!( - provider - .status_by_external_id(&DefaultContext, &0) - .await - .is_err() - ); - } - - pub fn new_quant_index( - index_config: graph::Config, - debug_config: DebugConfig, - pq_table: FixedChunkPQTable, - ) -> Arc> { - let data = DebugProvider::new(debug_config, Arc::new(pq_table)).unwrap(); - Arc::new(DiskANNIndex::new(index_config, data, None)) - } - - #[rstest] - #[case(1, 100)] - #[case(3, 7)] - #[case(4, 5)] - #[tokio::test] - async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) { - let l = 10; - let max_degree = 2 * dim; - let num_points = (grid_size).pow(dim as u32); - let start_id = u32::MAX; - - let index_config = graph::config::Builder::new( - max_degree, - graph::config::MaxDegree::default_slack(), - l, - (Metric::L2).into(), - ) - .build() - .unwrap(); - - let debug_config = DebugConfig { - start_id, - start_point: vec![grid_size as f32; dim], - max_degree, - metric: Metric::L2, - }; - - let adjacency_lists = match dim { - 1 => utils::generate_1d_grid_adj_list(grid_size as u32), - 3 => utils::genererate_3d_grid_adj_list(grid_size as u32), - 4 => utils::generate_4d_grid_adj_list(grid_size as u32), - _ => panic!("Unsupported number of dimensions"), - }; - let mut vectors = f32::generate_grid(dim, grid_size); - - assert_eq!(adjacency_lists.len(), num_points); - assert_eq!(vectors.len(), num_points); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - - let index = new_quant_index(index_config, debug_config, table); - { - let mut neighbor_accessor = index.provider().default_accessor(); - populate_data(index.provider(), &DefaultContext, &vectors).await; - populate_graph(&mut neighbor_accessor, &adjacency_lists).await; - - // Set the adjacency list for the start point. - neighbor_accessor - .set_neighbors(start_id, &[num_points as u32 - 1]) - .await - .unwrap(); - } - - // The corpus of actual vectors consists of all but the last point, which we use - // as the start point. - // - // So, when we compute the corpus used during groundtruth generation, we take all - // but this last point. - let corpus: diskann_utils::views::Matrix = - squish(vectors.iter().take(num_points), dim); - - let mut paged_tests = Vec::new(); - - // Test with the zero query. - let query = vec![0.0; dim]; - let gt = groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b)); - paged_tests.push(PagedSearch::new(query, gt)); - - // Test with the start point to ensure it is filtered out. - let query = vectors.last().unwrap(); - let gt = groundtruth(corpus.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); - paged_tests.push(PagedSearch::new(query.clone(), gt)); - - // Unfortunately - this is needed for the `check_grid_search` test. - vectors.push(index.provider().config.start_point.full().to_owned()); - check_grid_search(&index, &vectors, &paged_tests, FullPrecision, Quantized).await; - } - - #[rstest] - #[tokio::test] - async fn grid_search_with_build( - #[values((1, 100), (3, 7), (4, 5))] dim_and_size: (usize, usize), - ) { - let dim = dim_and_size.0; - let grid_size = dim_and_size.1; - let start_id = u32::MAX; - - let l = 10; - - // NOTE: Be careful changing `max_degree`. It needs to be high enough that the - // graph is navigable, but low enough that the batch parallel handling inside - // `multi_insert` is needed for the multi-insert graph to be navigable. - // - // With the current configured values, removing the other elements in the batch - // from the visited set during `multi_insert` results in a graph failure. - let max_degree = 2 * dim; - - let num_points = (grid_size).pow(dim as u32); - - let index_config = graph::config::Builder::new_with( - max_degree, - graph::config::MaxDegree::default_slack(), - l, - (Metric::L2).into(), - |b| { - b.max_minibatch_par(10); - }, - ) - .build() - .unwrap(); - - let debug_config = DebugConfig { - start_id, - start_point: vec![grid_size as f32; dim], - max_degree: index_config.max_degree().into(), - metric: Metric::L2, - }; - - let mut vectors = f32::generate_grid(dim, grid_size); - assert_eq!(vectors.len(), num_points); - - // This is a little subtle, but we need `vectors` to contain the start point as - // its last element, but we **don't** want to include it in the index build. - // - // This basically means that we need to be careful with index initialization. - vectors.push(vec![grid_size as f32; dim]); - - let table = train_pq( - squish(vectors.iter(), dim).as_view(), - 2.min(dim), // Number of PQ chunks is bounded by the dimension. - &mut crate::utils::create_rnd_from_seed_in_tests(0x04a8832604476965), - 1usize, - ) - .unwrap(); - - // Initialize an index for a new round of building. - let init_index = - || new_quant_index(index_config.clone(), debug_config.clone(), table.clone()); - - // Build with full-precision single insert - { - let index = init_index(); - let ctx = DefaultContext; - for (i, v) in vectors.iter().take(num_points).enumerate() { - index - .insert(FullPrecision, &ctx, &(i as u32), v.as_slice()) - .await - .unwrap(); - } - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with quantized single insert - { - let index = init_index(); - let ctx = DefaultContext; - for (i, v) in vectors.iter().take(num_points).enumerate() { - index - .insert(Quantized, &ctx, &(i as u32), v.as_slice()) - .await - .unwrap(); - } - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with full-precision multi-insert - { - let index = init_index(); - let ctx = DefaultContext; - let batch: Box<[_]> = vectors - .iter() - .take(num_points) - .enumerate() - .map(|(id, v)| VectorIdBoxSlice::new(id as u32, v.as_slice().into())) - .collect(); - - index - .multi_insert(FullPrecision, &ctx, batch) - .await - .unwrap(); - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "multi-insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - - // Build with quantized multi-insert - { - let index = init_index(); - let ctx = DefaultContext; - let batch: Box<[_]> = vectors - .iter() - .take(num_points) - .enumerate() - .map(|(id, v)| VectorIdBoxSlice::new(id as u32, v.as_slice().into())) - .collect(); - - index.multi_insert(Quantized, &ctx, batch).await.unwrap(); - - // Ensure the `insert_search_accessor` API is invoked. - assert_eq!( - index.provider().insert_search_accessor_calls.get(), - num_points, - "multi-insert should invoke `insert_search_accessor`", - ); - - check_grid_search(&index, &vectors, &[], FullPrecision, Quantized).await; - } - } -} diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 3d89359e2..cf719e730 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -39,7 +39,3 @@ pub mod bf_tree; // Caching proxy provider to accelerate slow providers. #[cfg(feature = "bf_tree")] pub mod caching; - -// Debug provider for testing. -#[cfg(test)] -pub mod debug_provider; From 3df382bdf9e9c3ca76cae2443393c5da07208244 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 19 Mar 2026 23:42:59 +0000 Subject: [PATCH 3/4] Rewire caching/example.rs to use diskann::graph::test::provider instead of DebugProvider Co-authored-by: harsha-simhadri <5590673+harsha-simhadri@users.noreply.github.com> --- diskann-providers/Cargo.toml | 1 + .../graph/provider/async_/caching/example.rs | 704 ++++++++++++++++++ .../graph/provider/async_/caching/mod.rs | 3 + diskann/src/graph/test/provider.rs | 23 + 4 files changed, 731 insertions(+) create mode 100644 diskann-providers/src/model/graph/provider/async_/caching/example.rs diff --git a/diskann-providers/Cargo.toml b/diskann-providers/Cargo.toml index 65a7b0e0b..c917b0a57 100644 --- a/diskann-providers/Cargo.toml +++ b/diskann-providers/Cargo.toml @@ -47,6 +47,7 @@ vfs = { workspace = true, optional = true } [dev-dependencies] approx.workspace = true criterion.workspace = true +diskann = { workspace = true, features = ["testing"] } diskann-utils = { workspace = true, features = ["testing"] } iai-callgrind.workspace = true itertools.workspace = true diff --git a/diskann-providers/src/model/graph/provider/async_/caching/example.rs b/diskann-providers/src/model/graph/provider/async_/caching/example.rs new file mode 100644 index 000000000..a80b0bf26 --- /dev/null +++ b/diskann-providers/src/model/graph/provider/async_/caching/example.rs @@ -0,0 +1,704 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! An example cache for demonstrating how to set up a [`CachingProvider`] with an inner +//! data provider. The inner provider used in this module is the test provider from +//! [`diskann::graph::test::provider`]. + +use diskann::{ + graph::{AdjacencyList, test::provider as test_provider}, + provider::{self as core_provider}, +}; +use diskann_utils::future::AsyncFriendly; +use diskann_vector::distance::Metric; + +use super::{ + bf_cache::{self, Cache}, + error::CacheAccessError, + provider::{self as cache_provider, CachingError, NeighborStatus}, + utils::{CacheKey, Graph, HitStats, KeyGen, LocalStats}, +}; + +/////////////////// +// Example Cache // +/////////////////// + +/// A representation for the adjacency list term. +/// +/// The `repr(u8)` allows this to be converted to an integer and thus be used in a +/// `bytemuck::Pod` struct. +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +enum Tag { + AdjacencyList = 0, + Vector = 1, +} + +const TAGS: [Tag; 2] = [Tag::AdjacencyList, Tag::Vector]; + +impl KeyGen for Tag { + type Key = CacheKey; + fn generate(&self, id: u32) -> Self::Key { + CacheKey::new(id, (*self as u8).into()) + } +} + +/// An example cache for compatibility with the full-precision provider. +/// +/// This is meant largely for demonstration purposes, showing how to build a cache tailored +/// for a certain provider. +#[derive(Debug)] +pub struct ExampleCache { + cache: Cache, + // This field records a collection of uncacheable IDs to ensure the cache skipping + // logic in the provider works as expected. + uncacheable: Option>, + neighbor_stats: HitStats, + vector_stats: HitStats, +} + +impl ExampleCache { + /// Construct a new cache with the given capacity. + pub fn new( + bytes: diskann_quantization::num::PowerOfTwo, + uncacheable: Option>, + ) -> Self { + Self { + cache: Cache::new(bytes).unwrap(), + uncacheable, + neighbor_stats: HitStats::new(), + vector_stats: HitStats::new(), + } + } + + /// Invalidate all cached items associated with `id`. + fn invalidate(&self, id: u32) { + // Since we use a unified cache under the scenes type the list types disambiguated + // by a tag - we need to delete all possible tags. + for tag in TAGS { + self.cache.delete(tag.generate(id)) + } + } + + /// Return a cache accessor for adjacency list terms. + fn neighbors(&self, max_degree: usize) -> Graph<'_, u32, Tag> { + Graph::new( + &self.cache, + max_degree, + Tag::AdjacencyList, + &self.neighbor_stats, + ) + } +} + +impl cache_provider::Evict for ExampleCache { + fn evict(&self, id: u32) { + self.invalidate(id) + } +} + +/// The `CacheAccessor` is the `C` in `cache_provider::CachingAccessor` and interfaces the +/// accessor with the underlying cache. +#[derive(Debug)] +pub struct CacheAccessor<'a, T> { + graph: Graph<'a, u32, Tag>, + cacher: T, + stats: LocalStats<'a>, + uncacheable: Option<&'a [u32]>, + /// Key generation for the element being accessed. + keygen: Tag, +} + +impl cache_provider::ElementCache> + for CacheAccessor<'_, bf_cache::VecCacher> +where + T: AsyncFriendly + bytemuck::Pod + Default + std::fmt::Debug, +{ + type Error = CacheAccessError; + + fn get_cached(&mut self, k: u32) -> Result, CacheAccessError> { + match self + .graph + .cache() + .get(self.keygen.generate(k), &mut self.cacher) + { + Ok(Some(value)) => { + self.stats.hit(); + Ok(Some(value)) + } + Ok(None) => { + self.stats.miss(); + Ok(None) + } + Err(err) => Err(CacheAccessError::read(k, err)), + } + } + + fn set_cached(&mut self, k: u32, element: &&[T]) -> Result<(), CacheAccessError> { + self.graph + .cache() + .set(self.keygen.generate(k), &mut self.cacher, element) + .map_err(|err| CacheAccessError::write(k, err)) + } +} + +impl cache_provider::NeighborCache for CacheAccessor<'_, T> +where + T: AsyncFriendly, +{ + type Error = CacheAccessError; + + fn try_get_neighbors( + &mut self, + id: u32, + neighbors: &mut AdjacencyList, + ) -> Result { + if let Some(uncacheable) = self.uncacheable + && uncacheable.contains(&id) + { + self.graph.stats_mut().miss(); + Ok(NeighborStatus::Uncacheable) + } else { + self.graph.try_get_neighbors(id, neighbors) + } + } + + fn set_neighbors(&mut self, id: u32, neighbors: &[u32]) -> Result<(), CacheAccessError> { + self.graph.set_neighbors(id, neighbors) + } + + fn invalidate_neighbors(&mut self, id: u32) { + self.graph.invalidate_neighbors(id) + } +} + +///////////////////// +// Provider Bridge // +///////////////////// + +impl<'a> cache_provider::AsCacheAccessorFor<'a, test_provider::Accessor<'a>> for ExampleCache { + type Accessor = CacheAccessor<'a, bf_cache::VecCacher>; + type Error = diskann::error::Infallible; + fn as_cache_accessor_for( + &'a self, + inner: test_provider::Accessor<'a>, + ) -> Result< + cache_provider::CachingAccessor, Self::Accessor>, + Self::Error, + > { + let provider = inner.provider(); + let cache_accessor = CacheAccessor { + graph: self.neighbors(provider.max_degree()), + cacher: bf_cache::VecCacher::::new(provider.dim()), + uncacheable: self.uncacheable.as_deref(), + stats: LocalStats::new(&self.vector_stats), + keygen: Tag::Vector, + }; + Ok(cache_provider::CachingAccessor::new(inner, cache_accessor)) + } +} + +impl<'a> cache_provider::CachedFillSet>> + for test_provider::Accessor<'a> +{ +} + +impl<'a> cache_provider::CachedAsElement<&'a [f32], CacheAccessor<'a, bf_cache::VecCacher>> + for test_provider::Accessor<'a> +{ + type Error = CachingError; + async fn cached_as_element<'b>( + &'b mut self, + cache: &'b mut CacheAccessor<'a, bf_cache::VecCacher>, + _vector: &'a [f32], + id: u32, + ) -> Result, Self::Error> { + cache_provider::get_or_insert(self, cache, id).await + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use diskann::{ + graph::DiskANNIndex, + provider::{ + Accessor, DataProvider, Delete, NeighborAccessor, NeighborAccessorMut, SetElement, + }, + utils::async_tools, + }; + use diskann_quantization::num::PowerOfTwo; + use rstest::rstest; + + use crate::{ + index::diskann_async::tests as async_tests, + model::graph::provider::async_::caching::provider::{AsCacheAccessorFor, CachingProvider}, + }; + + fn test_provider( + uncacheable: Option>, + ) -> CachingProvider { + let dim = 2; + let max_degree = 10; + let start_id = u32::MAX; + + let config = test_provider::Config::new( + Metric::L2, + max_degree, + test_provider::StartPoint::new(start_id, vec![0.0; dim]), + ) + .unwrap(); + + CachingProvider::new( + test_provider::Provider::new(config), + ExampleCache::new(PowerOfTwo::new(1024 * 16).unwrap(), uncacheable), + ) + } + + fn ctx() -> test_provider::Context { + test_provider::Context::new() + } + + #[tokio::test] + async fn basic_operations_happy_path() { + let provider = test_provider(None); + let ctx = ctx(); + + // Translations do not yet exist. + assert!(provider.to_external_id(&ctx, 0).is_err()); + assert!(provider.to_internal_id(&ctx, &0).is_err()); + + assert_eq!(provider.inner().metrics().set_vector, 0); + provider.set_element(&ctx, &0, &[1.0, 2.0]).await.unwrap(); + assert_eq!( + provider.inner().metrics().set_vector, + 1 /* increased */ + ); + + assert_eq!(provider.to_external_id(&ctx, 0).unwrap(), 0); + assert_eq!(provider.to_internal_id(&ctx, &0).unwrap(), 0); + + // Retrieval of a valid element. + let mut accessor = provider + .cache() + .as_cache_accessor_for(test_provider::Accessor::new(provider.inner())) + .unwrap(); + + // Hit served from the underlying provider (cache miss). + let element = accessor.get_element(0).await.unwrap(); + assert_eq!(element, &[1.0, 2.0]); + assert_eq!( + accessor.cache().stats.get_local_misses(), + 1, /* increased */ + ); + assert_eq!(accessor.cache().stats.get_local_hits(), 0); + + // This time, the hit is served from the underlying cache. + let element = accessor.get_element(0).await.unwrap(); + assert_eq!(element, &[1.0, 2.0]); + assert_eq!(accessor.cache().stats.get_local_misses(), 1); + assert_eq!( + accessor.cache().stats.get_local_hits(), + 1, /* increased */ + ); + + // Adjacency List from Underlying + assert_eq!(provider.inner().metrics().set_neighbors, 0); + accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap(); + assert_eq!( + provider.inner().metrics().set_neighbors, + 1, /* increased */ + ); + + let mut list = AdjacencyList::new(); + assert_eq!(provider.inner().metrics().get_neighbors, 0); + accessor.get_neighbors(0, &mut list).await.unwrap(); + assert_eq!( + provider.inner().metrics().get_neighbors, + 1, /* increased */ + ); + assert_eq!( + accessor.cache().graph.stats().get_local_misses(), + 1, /* increased */ + ); + assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0); + assert_eq!(&*list, &[1, 2, 3]); + + // Adjacency List From Cache + list.clear(); + accessor.get_neighbors(0, &mut list).await.unwrap(); + assert_eq!(&*list, &[1, 2, 3]); + assert_eq!(provider.inner().metrics().get_neighbors, 1); + assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1); + assert_eq!( + accessor.cache().graph.stats().get_local_hits(), + 1, /* increased */ + ); + + // If we invalidate the key - these elements should be retrieved from the backing + // provider instead. + provider.cache().invalidate(0); + + let element = accessor.get_element(0).await.unwrap(); + assert_eq!(element, &[1.0, 2.0]); + assert_eq!( + accessor.cache().stats.get_local_misses(), + 2, /* increased */ + ); + assert_eq!(accessor.cache().stats.get_local_hits(), 1); + + // Once more from the cache. + let element = accessor.get_element(0).await.unwrap(); + assert_eq!(element, &[1.0, 2.0]); + assert_eq!(accessor.cache().stats.get_local_misses(), 2); + assert_eq!( + accessor.cache().stats.get_local_hits(), + 2, /* increased */ + ); + + list.clear(); + accessor.get_neighbors(0, &mut list).await.unwrap(); + assert_eq!(&*list, &[1, 2, 3]); + assert_eq!( + provider.inner().metrics().get_neighbors, + 2, /* increased */ + ); + assert_eq!( + accessor.cache().graph.stats().get_local_misses(), + 2, /* increased */ + ); + assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); + + // Setting adjacency lists invalidates the cache. + accessor.set_neighbors(0, &[2, 3, 4]).await.unwrap(); + accessor.get_neighbors(0, &mut list).await.unwrap(); + assert_eq!(&*list, &[2, 3, 4]); + assert_eq!( + provider.inner().metrics().set_neighbors, + 2, /* increased */ + ); + assert_eq!( + provider.inner().metrics().get_neighbors, + 3, /* increased */ + ); + assert_eq!( + accessor.cache().graph.stats().get_local_misses(), + 3, /* increased */ + ); + assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); + + accessor.append_vector(0, &[1]).await.unwrap(); + accessor.get_neighbors(0, &mut list).await.unwrap(); + assert_eq!(&*list, &[2, 3, 4, 1]); + + assert_eq!( + provider.inner().metrics().set_neighbors, + 2, + "append_vector doesn't go through set_neighbors counter" + ); + assert_eq!( + provider.inner().metrics().get_neighbors, + 4, /* increased */ + ); + assert_eq!( + accessor.cache().graph.stats().get_local_misses(), + 4, /* increased */ + ); + assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); + + // Deletion. + assert_eq!( + provider.status_by_internal_id(&ctx, 0).await.unwrap(), + core_provider::ElementStatus::Valid + ); + assert_eq!( + provider.status_by_external_id(&ctx, &0).await.unwrap(), + core_provider::ElementStatus::Valid + ); + assert!(provider.status_by_internal_id(&ctx, 1).await.is_err()); + assert!(provider.status_by_external_id(&ctx, &1).await.is_err()); + + provider.delete(&ctx, &0).await.unwrap(); + + assert_eq!( + provider.status_by_internal_id(&ctx, 0).await.unwrap(), + core_provider::ElementStatus::Deleted + ); + assert_eq!( + provider.status_by_external_id(&ctx, &0).await.unwrap(), + core_provider::ElementStatus::Deleted + ); + assert!(provider.status_by_internal_id(&ctx, 1).await.is_err()); + assert!(provider.status_by_external_id(&ctx, &1).await.is_err()); + + // Accessing the deleted element is still valid. + let element = accessor.get_element(0).await.unwrap(); + assert_eq!(element, &[1.0, 2.0]); + assert_eq!(accessor.cache().stats.get_local_misses(), 2); + assert_eq!( + accessor.cache().stats.get_local_hits(), + 3, /* increased */ + ); + + accessor.get_neighbors(0, &mut list).await.unwrap(); + assert_eq!(&*list, &[2, 3, 4, 1]); + assert_eq!(provider.inner().metrics().get_neighbors, 4); + assert_eq!(accessor.cache().graph.stats().get_local_misses(), 4); + assert_eq!( + accessor.cache().graph.stats().get_local_hits(), + 2, /* increased */ + ); + + provider.release(&ctx, 0).await.unwrap(); + assert!(provider.status_by_internal_id(&ctx, 0).await.is_err()); + assert!(provider.status_by_external_id(&ctx, &0).await.is_err()); + + assert!(accessor.get_element(0).await.is_err()); + assert_eq!( + accessor.cache().stats.get_local_misses(), + 3 /* increased */ + ); + assert_eq!(accessor.cache().stats.get_local_hits(), 3); + + assert!(accessor.get_neighbors(0, &mut list).await.is_err()); + assert_eq!( + accessor.cache().graph.stats().get_local_misses(), + 5 /* increased */ + ); + assert_eq!(accessor.cache().graph.stats().get_local_hits(), 2); + + // Ensure that the stats get properly recorded when the accessor is dropped. + let vector_hits = accessor.cache().stats.get_local_hits(); + let vector_misses = accessor.cache().stats.get_local_misses(); + let neighbor_hits = accessor.cache().graph.stats().get_local_hits(); + let neighbor_misses = accessor.cache().graph.stats().get_local_misses(); + + std::mem::drop(accessor); + + assert_eq!(provider.cache().vector_stats.get_hits(), vector_hits); + assert_eq!(provider.cache().vector_stats.get_misses(), vector_misses); + + assert_eq!(provider.cache().neighbor_stats.get_hits(), neighbor_hits); + assert_eq!( + provider.cache().neighbor_stats.get_misses(), + neighbor_misses + ); + } + + #[tokio::test] + async fn test_uncacheable() { + // Test that returning `Uncacheable` for an adjacency list is handled correctly by + // the provider and a call to `set_neighbors` is not made. + let uncacheable = u32::MAX; + let provider = test_provider(Some(vec![uncacheable])); + let ctx = ctx(); + + let mut accessor = provider + .cache() + .as_cache_accessor_for(test_provider::Accessor::new(provider.inner())) + .unwrap(); + + provider.set_element(&ctx, &0, &[1.0, 2.0]).await.unwrap(); + + //---------------// + // Cacheable IDs // + //---------------// + + // Adjacency List from Underlying + assert_eq!(provider.inner().metrics().set_neighbors, 0); + accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap(); + assert_eq!( + provider.inner().metrics().set_neighbors, + 1, /* increased */ + ); + + let mut list = AdjacencyList::new(); + assert_eq!(provider.inner().metrics().get_neighbors, 0); + accessor.get_neighbors(0, &mut list).await.unwrap(); + assert_eq!( + provider.inner().metrics().get_neighbors, + 1, /* increased */ + ); + assert_eq!( + accessor.cache().graph.stats().get_local_misses(), + 1, /* increased */ + ); + assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0); + assert_eq!(&*list, &[1, 2, 3]); + + // Adjacency List From Cache + list.clear(); + accessor.get_neighbors(0, &mut list).await.unwrap(); + assert_eq!(&*list, &[1, 2, 3]); + assert_eq!(provider.inner().metrics().get_neighbors, 1); + assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1); + assert_eq!( + accessor.cache().graph.stats().get_local_hits(), + 1, /* increased */ + ); + + //-----------------// + // Uncacheable IDs // + //-----------------// + + assert_eq!(provider.inner().metrics().set_neighbors, 1); + accessor.set_neighbors(uncacheable, &[4, 5]).await.unwrap(); + assert_eq!( + provider.inner().metrics().set_neighbors, + 2, /* increased */ + ); + + // The retrieval is served by the inner provider. + assert_eq!(provider.inner().metrics().get_neighbors, 1); + accessor + .get_neighbors(uncacheable, &mut list) + .await + .unwrap(); + assert_eq!( + provider.inner().metrics().get_neighbors, + 2, /* increased */ + ); + assert_eq!( + accessor.cache().graph.stats().get_local_misses(), + 2, /* increased */ + ); + assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); + assert_eq!(&*list, &[4, 5]); + + // Again, retrieval is served by the inner provider. + assert_eq!(provider.inner().metrics().get_neighbors, 2); + accessor + .get_neighbors(uncacheable, &mut list) + .await + .unwrap(); + assert_eq!( + provider.inner().metrics().get_neighbors, + 3, /* increased */ + ); + assert_eq!( + accessor.cache().graph.stats().get_local_misses(), + 3, /* increased */ + ); + assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1); + assert_eq!(&*list, &[4, 5]); + } + + //----------------// + // Standard Tests // + //----------------// + + fn check_stats(caching: &CachingProvider) { + let metrics = caching.inner().metrics(); + let cache = caching.cache(); + + println!("neighbor reads: {}", metrics.get_neighbors); + println!("neighbor writes: {}", metrics.set_neighbors); + println!("vector reads: {}", metrics.get_vector); + println!("vector writes: {}", metrics.set_vector); + + println!("neighbor hits: {}", cache.neighbor_stats.get_hits()); + println!("neighbor misses: {}", cache.neighbor_stats.get_misses()); + println!("vector hits: {}", cache.vector_stats.get_hits()); + println!("vector misses: {}", cache.vector_stats.get_misses()); + + // Neighbors + assert_eq!(metrics.get_neighbors, cache.neighbor_stats.get_misses()); + + // Vectors + assert_eq!(metrics.get_vector, cache.vector_stats.get_misses()); + } + + #[rstest] + #[tokio::test] + async fn grid_search_with_build( + #[values((1, 100), (3, 7), (4, 5))] dim_and_size: (usize, usize), + ) { + let dim = dim_and_size.0; + let grid_size = dim_and_size.1; + let l = 10; + let start_id = u32::MAX; + let start_point = vec![grid_size as f32; dim]; + let metric = Metric::L2; + let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); + + // NOTE: Be careful changing `max_degree`. It needs to be high enough that the + // graph is navigable, but low enough that the batch parallel handling inside + // `multi_insert` is needed for the multi-insert graph to be navigable. + // + // With the current configured values, removing the other elements in the batch + // from the visited set during `multi_insert` results in a graph failure. + let max_degree = 2 * dim; + let num_points = (grid_size).pow(dim as u32); + + let vectors = ::generate_grid(dim, grid_size); + + let index_config = diskann::graph::config::Builder::new_with( + max_degree, + diskann::graph::config::MaxDegree::default_slack(), + l, + metric.into(), + |b| { + b.max_minibatch_par(10); + }, + ) + .build() + .unwrap(); + + let config = test_provider::Config::new( + metric, + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), + ) + .unwrap(); + assert_eq!(vectors.len(), num_points); + + // Initialize an index for a new round of building. + let init_index = || { + let provider = CachingProvider::new( + test_provider::Provider::new(config.clone()), + ExampleCache::new(cache_size, None), + ); + Arc::new(DiskANNIndex::new(index_config.clone(), provider, None)) + }; + + let strategy = cache_provider::Cached::new(test_provider::Strategy::new()); + let ctx = ctx(); + + // Build with full-precision single insert + { + let index = init_index(); + for (i, v) in vectors.iter().take(num_points).enumerate() { + index + .insert(strategy, &ctx, &(i as u32), v.as_slice()) + .await + .unwrap(); + } + + check_stats(index.provider()); + } + + // Build with full-precision multi-insert + { + let index = init_index(); + let batch: Box<[_]> = vectors + .iter() + .take(num_points) + .enumerate() + .map(|(id, v)| async_tools::VectorIdBoxSlice::new(id as u32, v.as_slice().into())) + .collect(); + + index.multi_insert(strategy, &ctx, batch).await.unwrap(); + + check_stats(index.provider()); + } + } +} diff --git a/diskann-providers/src/model/graph/provider/async_/caching/mod.rs b/diskann-providers/src/model/graph/provider/async_/caching/mod.rs index aeac402ef..dc93ca52c 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/mod.rs @@ -7,3 +7,6 @@ pub mod bf_cache; pub mod error; pub mod provider; pub mod utils; + +#[cfg(test)] +pub mod example; diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 6a1857fac..56bc7bdc5 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -893,6 +893,11 @@ impl<'a> Accessor<'a> { get_vector: provider.get_vector.local(), } } + + /// Return a reference to the underlying provider. + pub fn provider(&self) -> &Provider { + self.provider + } } impl provider::HasId for Accessor<'_> { @@ -966,6 +971,24 @@ impl glue::SearchExt for Accessor<'_> { impl glue::ExpandBeam<[f32]> for Accessor<'_> {} impl glue::FillSet for Accessor<'_> {} +impl provider::CacheableAccessor for Accessor<'_> { + type Map = diskann_utils::lifetime::Slice; + + fn as_cached<'a, 'b>(element: &'a &'b [f32]) -> &'a &'b [f32] + where + Self: 'a + 'b, + { + element + } + + fn from_cached<'a>(element: &'a [f32]) -> &'a [f32] + where + Self: 'a, + { + element + } +} + #[derive(Debug, Default, Clone, Copy)] pub struct Strategy { _phantom: (), From 08dd1f8ef581d01275071a9aac63d52d8b84fd0c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 19:46:57 +0000 Subject: [PATCH 4/4] Restore deleted grid_search and test_inplace_delete_2d tests Add DeletionAwareStrategy and FilterDeletedCopyIds to the test provider to support inplace_delete which requires post-processing that filters deleted IDs from search results. Rewrite grid_search to inline the search verification logic (the original check_grid_search helper requires DefaultContext, which the test provider doesn't use)." Co-authored-by: harsha-simhadri <5590673+harsha-simhadri@users.noreply.github.com> Agent-Logs-Url: https://github.com/microsoft/DiskANN/sessions/fe42fea3-8fb1-458c-a2c7-c2ed4d577b68 --- .../graph/provider/async_/caching/example.rs | 331 +++++++++++++++++- diskann/src/graph/test/provider.rs | 142 +++++++- 2 files changed, 470 insertions(+), 3 deletions(-) diff --git a/diskann-providers/src/model/graph/provider/async_/caching/example.rs b/diskann-providers/src/model/graph/provider/async_/caching/example.rs index a80b0bf26..c533cf4f9 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/example.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/example.rs @@ -230,18 +230,21 @@ mod tests { use std::sync::Arc; use diskann::{ - graph::DiskANNIndex, + graph::{DiskANNIndex, glue::SearchStrategy, search::Knn, search_output_buffer}, provider::{ Accessor, DataProvider, Delete, NeighborAccessor, NeighborAccessorMut, SetElement, }, utils::async_tools, }; use diskann_quantization::num::PowerOfTwo; + use diskann_utils::views::Matrix; + use diskann_vector::{PureDistanceFunction, distance::SquaredL2}; use rstest::rstest; use crate::{ index::diskann_async::tests as async_tests, model::graph::provider::async_::caching::provider::{AsCacheAccessorFor, CachingProvider}, + utils as crate_utils, }; fn test_provider( @@ -701,4 +704,330 @@ mod tests { check_stats(index.provider()); } } + + #[rstest] + #[case(1, 100)] + #[case(3, 7)] + #[case(4, 5)] + #[tokio::test] + async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) { + let l = 10; + let max_degree = 2 * dim; + let num_points = (grid_size).pow(dim as u32); + let start_id = u32::MAX; + let start_point = vec![grid_size as f32; dim]; + let metric = Metric::L2; + let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); + + let index_config = diskann::graph::config::Builder::new( + max_degree, + diskann::graph::config::MaxDegree::default_slack(), + l, + metric.into(), + ) + .build() + .unwrap(); + + let config = test_provider::Config::new( + metric, + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), + ) + .unwrap(); + + let mut vectors = ::generate_grid(dim, grid_size); + + let provider = CachingProvider::new( + test_provider::Provider::new(config), + ExampleCache::new(cache_size, None), + ); + let index = Arc::new(DiskANNIndex::new(index_config, provider, None)); + + let adjacency_lists = match dim { + 1 => crate_utils::generate_1d_grid_adj_list(grid_size as u32), + 3 => crate_utils::genererate_3d_grid_adj_list(grid_size as u32), + 4 => crate_utils::generate_4d_grid_adj_list(grid_size as u32), + _ => panic!("Unsupported number of dimensions"), + }; + assert_eq!(adjacency_lists.len(), num_points); + assert_eq!(vectors.len(), num_points); + + let strategy = cache_provider::Cached::new(test_provider::Strategy::new()); + let ctx = ctx(); + async_tests::populate_data(index.provider(), &ctx, &vectors).await; + { + // Note: Without the fully qualified syntax - this fails to compile. + let mut accessor = + as SearchStrategy< + cache_provider::CachingProvider, + [f32], + >>::search_accessor(&strategy, index.provider(), &ctx) + .unwrap(); + async_tests::populate_graph(&mut accessor, &adjacency_lists).await; + + accessor + .set_neighbors(start_id, &[num_points as u32 - 1]) + .await + .unwrap(); + } + + // Test with the zero query — the farthest point from the entry point. + let search_k = dim + 1; + let search_l = 10; + { + let query = vec![0.0f32; dim]; + let mut ids = vec![0u32; search_k]; + let mut distances = vec![0.0f32; search_k]; + let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let knn = Knn::new_default(search_k, search_l).unwrap(); + index + .search(knn, &strategy, &ctx, query.as_slice(), &mut output) + .await + .unwrap(); + + // The nearest neighbor should be id 0 with distance 0. + assert_eq!(ids[0], 0, "expected the nearest neighbor to be 0"); + assert_eq!(distances[0], 0.0, "expected the nearest distance to be 0"); + for &d in &distances[1..search_k] { + assert_eq!( + d, 1.0, + "expected corner query close neighbor to have distance 1.0" + ); + } + } + + // Test with the start point to ensure it is filtered out. + { + let query = start_point.clone(); + let mut ids = vec![0u32; search_k]; + let mut distances = vec![0.0f32; search_k]; + let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + let knn = Knn::new_default(search_k, search_l).unwrap(); + index + .search(knn, &strategy, &ctx, query.as_slice(), &mut output) + .await + .unwrap(); + + assert_ne!( + ids[0] as usize, num_points, + "start point should not be returned" + ); + } + + // Paged Search + let corpus: Matrix = async_tests::squish(vectors.iter(), dim); + + // Test paged search with the zero query. + { + let query = vec![0.0f32; dim]; + let mut gt = crate::test_utils::groundtruth(corpus.as_view(), &query, |a, b| { + SquaredL2::evaluate(a, b) + }); + let max_candidates = gt.len(); + let mut state = index + .start_paged_search(strategy, &ctx, query.as_slice(), search_l) + .await + .unwrap(); + + let mut buffer = vec![diskann::neighbor::Neighbor::::default(); search_k]; + let mut seen = 0; + while !gt.is_empty() { + let count = index + .next_search_results::, [f32]>( + &ctx, + &mut state, + search_k, + &mut buffer, + ) + .await + .unwrap(); + for b in buffer.iter().take(count) { + let m = gt.iter().position(|g| g.id == b.id); + if let Some(j) = m { + gt.remove(j); + } + seen += 1; + if seen == max_candidates { + break; + } + } + if seen == max_candidates { + break; + } + } + } + + // Test paged search with the start point. + { + let query = start_point.clone(); + let mut gt = crate::test_utils::groundtruth(corpus.as_view(), &query, |a, b| { + SquaredL2::evaluate(a, b) + }); + let max_candidates = gt.len(); + let mut state = index + .start_paged_search(strategy, &ctx, query.as_slice(), search_l) + .await + .unwrap(); + + let mut buffer = vec![diskann::neighbor::Neighbor::::default(); search_k]; + let mut seen = 0; + while !gt.is_empty() { + let count = index + .next_search_results::, [f32]>( + &ctx, + &mut state, + search_k, + &mut buffer, + ) + .await + .unwrap(); + for b in buffer.iter().take(count) { + let m = gt.iter().position(|g| g.id == b.id); + if let Some(j) = m { + gt.remove(j); + } + seen += 1; + if seen == max_candidates { + break; + } + } + if seen == max_candidates { + break; + } + } + } + + // Unfortunately - this is needed for the `check_grid_search` test. + vectors.push(start_point.clone()); + } + + #[tokio::test] + async fn test_inplace_delete_2d() { + // create small index instance + let metric = Metric::L2; + let num_points = 4; + let strategy = cache_provider::Cached::new(test_provider::DeletionAwareStrategy::new()); + let cache_size = PowerOfTwo::new(128 * 1024).unwrap(); + let start_id = num_points as u32; + let start_point = vec![0.5, 0.5]; + + let index_config = diskann::graph::config::Builder::new( + 4, // target_degree + diskann::graph::config::MaxDegree::default_slack(), + 10, // l_build + metric.into(), + ) + .build() + .unwrap(); + + let config = test_provider::Config::new( + metric, + index_config.max_degree().get(), + test_provider::StartPoint::new(start_id, start_point.clone()), + ) + .unwrap(); + + let index = DiskANNIndex::new( + index_config, + CachingProvider::new( + test_provider::Provider::new(config), + ExampleCache::new(cache_size, None), + ), + None, + ); + + let ctx = ctx(); + + // vectors are the four corners of a square, with the start point in the middle + // the middle point forms an edge to each corner, while corners form an edge + // to their opposite vertex vertically as well as the middle + let vectors = [ + vec![0.0, 0.0], + vec![0.0, 1.0], + vec![1.0, 0.0], + vec![1.0, 1.0], + ]; + let adjacency_lists = [ + AdjacencyList::from_iter_untrusted([4, 1]), + AdjacencyList::from_iter_untrusted([4, 0]), + AdjacencyList::from_iter_untrusted([4, 3]), + AdjacencyList::from_iter_untrusted([4, 2]), + AdjacencyList::from_iter_untrusted([0, 1, 2, 3]), + ]; + + // Note: Without the fully qualified syntax - this fails to compile. + let mut accessor = + as SearchStrategy< + cache_provider::CachingProvider, + [f32], + >>::search_accessor(&strategy, index.provider(), &ctx) + .unwrap(); + + async_tests::populate_data(index.provider(), &ctx, &vectors).await; + async_tests::populate_graph(&mut accessor, &adjacency_lists).await; + + index + .inplace_delete( + strategy, + &ctx, + &3, // id to delete + 3, // num_to_replace + diskann::graph::InplaceDeleteMethod::VisitedAndTopK { + k_value: 4, + l_value: 10, + }, + ) + .await + .unwrap(); + + // Check that the vertex was marked as deleted. + assert!( + index + .data_provider + .status_by_internal_id(&ctx, 3) + .await + .unwrap() + .is_deleted() + ); + + // expected outcome: + // vertex 4 (the start point) has its edge to 3 deleted + // vertex 2 (the other point with edge pointing to 3) should have its edge to point 3 deleted, + // and replaced with edges to points 0 and 1 + // vertices 0 and 1 should add an edge pointing to 2. + // vertex 3 should be dropped + { + let mut list = AdjacencyList::new(); + accessor.get_neighbors(4, &mut list).await.unwrap(); + list.sort(); + assert_eq!(&*list, &[0, 1, 2]); + } + + { + let mut list = AdjacencyList::new(); + accessor.get_neighbors(2, &mut list).await.unwrap(); + list.sort(); + assert_eq!(&*list, &[0, 1, 4]); + } + + { + let mut list = AdjacencyList::new(); + accessor.get_neighbors(0, &mut list).await.unwrap(); + list.sort(); + assert_eq!(&*list, &[1, 2, 4]); + } + + { + let mut list = AdjacencyList::new(); + accessor.get_neighbors(1, &mut list).await.unwrap(); + list.sort(); + assert_eq!(&*list, &[0, 2, 4]); + } + + { + let mut list = AdjacencyList::new(); + accessor.get_neighbors(3, &mut list).await.unwrap(); + assert!(list.is_empty()); + } + } } diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 56bc7bdc5..6f2a00683 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -18,9 +18,10 @@ use thiserror::Error; use crate::{ ANNError, ANNResult, error::{Infallible, message}, - graph::{AdjacencyList, glue, test::synthetic}, + graph::{AdjacencyList, SearchOutputBuffer, glue, test::synthetic}, internal::counter::{Counter, LocalCounter}, - provider, + neighbor::Neighbor, + provider::{self, BuildQueryComputer}, utils::VectorRepr, }; @@ -1089,6 +1090,143 @@ impl glue::InplaceDeleteStrategy for Strategy { } } +//--------------------------------------// +// Deletion-Aware Strategy and PostProc // +//--------------------------------------// + +/// A [`glue::SearchPostProcess`] that filters out deleted IDs before copying results. +/// +/// This is analogous to the `RemoveDeletedIdsAndCopy` post-processor in `diskann-providers`, +/// tailored for the test provider. +#[derive(Debug, Default, Clone, Copy)] +pub struct FilterDeletedCopyIds; + +impl<'a, T> glue::SearchPostProcess, T> for FilterDeletedCopyIds +where + T: ?Sized, + Accessor<'a>: BuildQueryComputer, +{ + type Error = std::convert::Infallible; + + fn post_process( + &self, + accessor: &mut Accessor<'a>, + _query: &T, + _computer: & as BuildQueryComputer>::QueryComputer, + candidates: I, + output: &mut B, + ) -> impl std::future::Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let provider = accessor.provider; + let count = output.extend(candidates.filter_map(|n| { + let is_deleted = provider.is_deleted(n.id).unwrap_or(true); + if is_deleted { + None + } else { + Some((n.id, n.distance)) + } + })); + std::future::ready(Ok(count)) + } +} + +/// A strategy variant that filters deleted IDs during search post-processing. +/// +/// This is needed for `inplace_delete` which relies on the post-processor to exclude +/// deleted items from search results. The base [`Strategy`] uses [`glue::CopyIds`] which +/// does not filter deletions. +#[derive(Debug, Default, Clone, Copy)] +pub struct DeletionAwareStrategy { + _phantom: (), +} + +impl DeletionAwareStrategy { + pub fn new() -> Self { + Self { _phantom: () } + } +} + +impl glue::SearchStrategy for DeletionAwareStrategy { + type QueryComputer = ::QueryDistance; + type PostProcessor = FilterDeletedCopyIds; + type SearchAccessorError = Infallible; + type SearchAccessor<'a> = Accessor<'a>; + + fn search_accessor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Infallible> { + Ok(Accessor::new(provider)) + } + + fn post_processor(&self) -> Self::PostProcessor { + FilterDeletedCopyIds + } +} + +impl glue::PruneStrategy for DeletionAwareStrategy { + type DistanceComputer = ::Distance; + type PruneAccessor<'a> = Accessor<'a>; + type PruneAccessorError = Infallible; + + fn prune_accessor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Self::PruneAccessorError> { + Ok(Accessor::new(provider)) + } +} + +impl glue::InsertStrategy for DeletionAwareStrategy { + type PruneStrategy = Self; + + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Self::SearchAccessorError> { + Ok(Accessor::new(provider)) + } +} + +impl glue::InplaceDeleteStrategy for DeletionAwareStrategy { + type DeleteElement<'a> = [f32]; + type DeleteElementGuard = Box<[f32]>; + type DeleteElementError = AccessedInvalidId; + type PruneStrategy = Self; + type SearchStrategy = Self; + + fn prune_strategy(&self) -> Self::PruneStrategy { + *self + } + + fn search_strategy(&self) -> Self::SearchStrategy { + *self + } + + async fn get_delete_element<'a>( + &'a self, + provider: &'a Provider, + _context: &'a ::Context, + id: ::InternalId, + ) -> Result { + provider + .terms + .get(&id) + .map(|v| (*v.data).into()) + .ok_or(AccessedInvalidId(id)) + } +} + /////////// // Tests // ///////////