diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index c972df2de..9bfa21d74 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -6,7 +6,7 @@ use std::{num::NonZeroUsize, sync::Arc}; use diskann::{ - ANNResult, + ANNError, ANNResult, graph::{ self, ConsolidateKind, InplaceDeleteMethod, glue::{ @@ -22,50 +22,66 @@ use diskann::{ utils::{ONE, async_tools::VectorIdBoxSlice}, }; +use crate::storage::{LoadWith, StorageReadProvider}; + +/// Synchronous wrapper around [`graph::DiskANNIndex`] that owns or borrows a tokio runtime. pub struct DiskANNIndex { /// The underlying async DiskANNIndex. pub inner: Arc>, + /// Keeps the runtime alive when `Self` owns it; `None` when using an external handle. _runtime: Option, handle: tokio::runtime::Handle, } +/// Create a multi-threaded tokio runtime and return it together with its handle. +fn create_multi_thread_runtime() -> (tokio::runtime::Runtime, tokio::runtime::Handle) { + #[allow(clippy::expect_used)] + let rt = tokio::runtime::Builder::new_multi_thread() + .build() + .expect("failed to create tokio runtime"); + let handle = rt.handle().clone(); + (rt, handle) +} + +/// Create a current-thread tokio runtime and return it together with its handle. +fn create_current_thread_runtime() -> (tokio::runtime::Runtime, tokio::runtime::Handle) { + #[allow(clippy::expect_used)] + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .expect("failed to create tokio runtime"); + let handle = rt.handle().clone(); + (rt, handle) +} + impl DiskANNIndex where DP: DataProvider, { - /// Construct a synchronous `DiskANNIndex` with its own `tokio::runtime::Runtime`. + /// Construct a synchronous `DiskANNIndex` with its own multi-threaded `tokio::runtime::Runtime`. /// - /// A default configured multi-threaded runtime will be created and used behind the scenes. To use - /// a specific Toktio runtime, use `DiskANNIndex::new_with_multi_thread_runtime()` or `DiskANNIndex::new_with_handle()`. + /// A default multi-threaded runtime will be created and owned by `Self`. For a single-threaded + /// runtime use [`new_with_current_thread_runtime`](Self::new_with_current_thread_runtime), or + /// to supply an external runtime handle use [`new_with_handle`](Self::new_with_handle). pub fn new_with_multi_thread_runtime(config: graph::Config, data_provider: DP) -> Self { - #[allow(clippy::expect_used)] - let rt = tokio::runtime::Builder::new_multi_thread() - .build() - .expect("failed to create tokio runtime"); - - let handle = rt.handle().clone(); - + let (rt, handle) = create_multi_thread_runtime(); Self::new_internal(config, data_provider, Some(rt), handle, Some(ONE)) } - /// Construct a synchronous `DiskANNIndex` with its own `tokio::runtime::Runtime`. + /// Construct a synchronous `DiskANNIndex` with its own single-threaded `tokio::runtime::Runtime`. /// - /// A default configured runtime that uses the curren thread will be created and used behind the scenes. To use - /// a specific Toktio runtime, use `DiskANNIndex::new_with_multi_thread_runtime()` or `DiskANNIndex::new_with_handle()`. + /// A default current-thread runtime will be created and owned by `Self`. For a multi-threaded + /// runtime use [`new_with_multi_thread_runtime`](Self::new_with_multi_thread_runtime), or + /// to supply an external runtime handle use [`new_with_handle`](Self::new_with_handle). pub fn new_with_current_thread_runtime(config: graph::Config, data_provider: DP) -> Self { - #[allow(clippy::expect_used)] - let rt = tokio::runtime::Builder::new_current_thread() - .build() - .expect("failed to create tokio runtime"); - - let handle = rt.handle().clone(); - + let (rt, handle) = create_current_thread_runtime(); Self::new_internal(config, data_provider, Some(rt), handle, Some(ONE)) } /// Construct a synchronous `DiskANNIndex` that uses a provided `tokio::runtime::Handle`. /// /// The `tokio::runtime::Runtime` is owned externally and we just keep a `Handle` to it. + /// `thread_hint` is forwarded to [`graph::DiskANNIndex::new`] to size internal thread pools; + /// pass `None` to let it choose a default. pub fn new_with_handle( config: graph::Config, data_provider: DP, @@ -83,13 +99,99 @@ where thread_hint: Option, ) -> Self { let inner = Arc::new(graph::DiskANNIndex::new(config, data_provider, thread_hint)); - Self { inner, _runtime: runtime, handle, } } + + /// Run an arbitrary async operation against the underlying + /// [`graph::DiskANNIndex`] using this wrapper's tokio runtime. + /// + /// This is a catch-all escape hatch for async methods on the inner index + /// that do not (yet) have a dedicated synchronous wrapper. The closure + /// receives an `&Arc>` and should return a future. + /// + /// # Example + /// + /// ```ignore + /// let stats = index.run(|inner| inner.some_async_method(&ctx))?; + /// ``` + pub fn run(&self, f: F) -> R + where + F: FnOnce(&Arc>) -> Fut, + Fut: core::future::Future, + { + self.handle.block_on(f(&self.inner)) + } + + /// Load a prebuilt index from storage with its own multi-threaded `tokio::runtime::Runtime`. + /// + /// This is the synchronous equivalent of + /// [`LoadWith::load_with`](crate::storage::LoadWith::load_with). + /// A default multi-threaded runtime is created and owned by `Self`. + /// For a single-threaded runtime use [`load_with_current_thread_runtime`](Self::load_with_current_thread_runtime), + /// or to supply an external runtime handle use [`load_with_handle`](Self::load_with_handle). + pub fn load_with_multi_thread_runtime(provider: &P, auxiliary: &T) -> ANNResult + where + graph::DiskANNIndex: LoadWith, + P: StorageReadProvider, + { + let (rt, handle) = create_multi_thread_runtime(); + let inner = handle.block_on(graph::DiskANNIndex::::load_with(provider, auxiliary))?; + Ok(Self { + inner: Arc::new(inner), + _runtime: Some(rt), + handle, + }) + } + + /// Load a prebuilt index from storage with its own single-threaded `tokio::runtime::Runtime`. + /// + /// This is the synchronous equivalent of + /// [`LoadWith::load_with`](crate::storage::LoadWith::load_with). + /// A default current-thread runtime is created and owned by `Self`. + /// For a multi-threaded runtime use [`load_with_multi_thread_runtime`](Self::load_with_multi_thread_runtime), + /// or to supply an external runtime handle use [`load_with_handle`](Self::load_with_handle). + pub fn load_with_current_thread_runtime(provider: &P, auxiliary: &T) -> ANNResult + where + graph::DiskANNIndex: LoadWith, + P: StorageReadProvider, + { + let (rt, handle) = create_current_thread_runtime(); + let inner = handle.block_on(graph::DiskANNIndex::::load_with(provider, auxiliary))?; + Ok(Self { + inner: Arc::new(inner), + _runtime: Some(rt), + handle, + }) + } + + /// Load a prebuilt index from storage using a provided `tokio::runtime::Handle`. + /// + /// This is the synchronous equivalent of + /// [`LoadWith::load_with`](crate::storage::LoadWith::load_with). + /// The `tokio::runtime::Runtime` is owned externally and we just keep a `Handle` to it. + /// For an owned runtime use [`load_with_multi_thread_runtime`](Self::load_with_multi_thread_runtime) + /// or [`load_with_current_thread_runtime`](Self::load_with_current_thread_runtime). + pub fn load_with_handle( + provider: &P, + auxiliary: &T, + handle: tokio::runtime::Handle, + ) -> ANNResult + where + graph::DiskANNIndex: LoadWith, + P: StorageReadProvider, + { + let inner = handle.block_on(graph::DiskANNIndex::::load_with(provider, auxiliary))?; + Ok(Self { + inner: Arc::new(inner), + _runtime: None, + handle, + }) + } + pub fn insert( &self, strategy: S, @@ -222,7 +324,6 @@ where .block_on(self.inner.consolidate_vector(strategy, context, vector_id)) } - #[allow(clippy::too_many_arguments, clippy::type_complexity)] pub fn search( &self, strategy: &S, @@ -320,3 +421,133 @@ where self.handle.block_on(self.inner.get_degree_stats(accessor)) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use diskann::{ + graph::{self, search_output_buffer}, + provider::DefaultContext, + utils::ONE, + }; + use diskann_utils::test_data_root; + use diskann_vector::distance::Metric; + + use super::DiskANNIndex; + use crate::{ + index::diskann_async, + model::{ + configuration::IndexConfiguration, + graph::provider::async_::{ + common::{FullPrecision, TableBasedDeletes}, + inmem::{self, CreateFullPrecision, DefaultProvider}, + }, + }, + storage::{AsyncIndexMetadata, SaveWith, StorageReadProvider, VirtualStorageProvider}, + utils::create_rnd_from_seed_in_tests, + }; + + #[test] + fn test_save_then_sync_load_round_trip() { + // -- Build an index in async context and save it ----------------------- + let save_path = "/index"; + let file_path = "/sift/siftsmall_learn_256pts.fbin"; + + let train_data = { + let storage = VirtualStorageProvider::new_overlay(test_data_root()); + let mut reader = storage.open_reader(file_path).unwrap(); + diskann_utils::io::read_bin::(&mut reader).unwrap() + }; + + let pq_bytes = 8; + let pq_table = diskann_async::train_pq( + train_data.as_view(), + pq_bytes, + &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade), + 2, + ) + .unwrap(); + + let (build_config, parameters) = diskann_async::simplified_builder( + 20, + 32, + Metric::L2, + train_data.ncols(), + train_data.nrows(), + |_| {}, + ) + .unwrap(); + + let fp_precursor = + CreateFullPrecision::new(parameters.dim, parameters.prefetch_cache_line_level); + let data_provider = + DefaultProvider::new_empty(parameters, fp_precursor, pq_table, TableBasedDeletes) + .unwrap(); + + let index = + DiskANNIndex::new_with_current_thread_runtime(build_config.clone(), data_provider); + + let storage = VirtualStorageProvider::new_memory(); + let ctx = DefaultContext; + for (i, v) in train_data.row_iter().enumerate() { + index.insert(FullPrecision, &ctx, &(i as u32), v).unwrap(); + } + + let save_metadata = AsyncIndexMetadata::new(save_path.to_string()); + let storage_ref = &storage; + let metadata_ref = &save_metadata; + index + .run(|inner| { + let inner = Arc::clone(inner); + async move { inner.save_with(storage_ref, metadata_ref).await } + }) + .unwrap(); + + // -- Reload via the synchronous wrapped_async API ---------------------- + let load_config = IndexConfiguration::new( + Metric::L2, + train_data.ncols(), + train_data.nrows(), + ONE, + 1, + build_config, + ); + + type TestProvider = inmem::FullPrecisionProvider< + f32, + crate::model::graph::provider::async_::FastMemoryQuantVectorProviderAsync, + crate::model::graph::provider::async_::TableDeleteProviderAsync, + >; + + let loaded: DiskANNIndex = + DiskANNIndex::load_with_current_thread_runtime(&storage, &(save_path, load_config)) + .unwrap(); + + // -- Verify the loaded index is functional ----------------------------- + // A single search call is enough to confirm the sync wrapper loaded a + // working index. Exhaustive search-correctness is tested elsewhere. + let top_k = 5; + let search_l = 20; + let mut ids = vec![0u32; top_k]; + let mut distances = vec![0.0f32; top_k]; + let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + + let query = train_data.row(0); + let search_params = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let stats = loaded + .search( + &FullPrecision, + &DefaultContext, + query, + &search_params, + &mut output, + ) + .unwrap(); + + assert_eq!(stats.result_count, top_k as u32); + // The query is itself in the dataset, so the nearest neighbor must be at distance 0. + assert_eq!(ids[0], 0); + assert_eq!(distances[0], 0.0); + } +}