From a0fceca6d52e324dcdbc9bd65fd7b52e613a8f40 Mon Sep 17 00:00:00 2001 From: Jack Moffitt Date: Tue, 10 Mar 2026 13:36:38 -0500 Subject: [PATCH 1/5] RFC for quantization bootstrapping --- rfcs/00000-quantizer-bootstrap.md | 158 ++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 rfcs/00000-quantizer-bootstrap.md diff --git a/rfcs/00000-quantizer-bootstrap.md b/rfcs/00000-quantizer-bootstrap.md new file mode 100644 index 000000000..18f45ba3d --- /dev/null +++ b/rfcs/00000-quantizer-bootstrap.md @@ -0,0 +1,158 @@ +# Quantization Bootstrapping + +| | | +|---|---| +| **Authors** | Jack Moffitt | +| **Contributors** | | +| **Created** | 2026-03-10 | +| **Updated** | 2026-03-10 | + +## Summary + +Indexes that use quantization must have a minimum number of vectors before they +can build quantization tables and start inserting vectors. Bootstrapping is the +process of incrementally building an index that starts empty, operates on +non-quantized vectors until enough vectors are present to build quantization +tables, and then transitions to normal operation. + +## Motivation + +### Background + +DiskANN's quantizers require some statistical information in order to build +quantization tables. For PQ, 10,000 vectors are generally required to build good +tables; for spherical, 100 are needed. In order to create an index, these +vectors must be provided at creation time in order to build the quantization +tables, at which point each vector is quantized as it is inserted. + +This requirement is easy to fulfill when building indexes from existing +datasets, but when starting from scratch, there is no ability for DiskANN to +build a quantized index since the quantization tables are a required part of the +constructor. 
+ +Current deployments of DiskANN work around this issue by not allowing index +creation until a dataset is sufficient large (pg_diskann), or operating a +separate flat index until sufficient vectors are collected at which point the +quantization tables are calculated and a graph index is built with DiskANN. + +### Problem Statement + +This RFC proposes changing DiskANN to operate in a quantization bootstrap mode +where it operates on full precision vectors until sufficient vectors exist to +create quantization tables, and then seamlessly transitions to a quantized +index. + +This means the index will operate in three different phases. In Phase 1, the +index operates in full precision mode only until sufficient vectors exist to +build quantization tables. During Phase 2, quantization tables will be built and +vectors will be quantized on insert; pre-existing vectors be quantized in the +background. Once all vectors are quantized, Phase Three begins the normal +operation of the quantized index. + +### Goals + +1. Allow quantized indices to start empty and use full precision data only to + operate until sufficient vectors are inserted. +2. Aside from allowing construction of `DiskANNIndex` without providing + quantization tables, there should user visible changes to using the index. +3. Performance should remain as high as possible during the three phases. +4. The quantization of previously inserted full vectors during Phase Two should + be controllable by the data provider. + +## Proposal + +Bootstrapping needs two changes to DiskANN. + +1. **Switching Strategies**: DiskANN needs to start by using full precision only + strategies during Phase One, and switching to a hybrid strategy for Phase + Two, and if the user's intent is to use quantized-only strategies, switching + to quantized-only for Phase Three. +2. **Quantization Backfill**: During Phase Two, previously inserted vectors will + need to be quantized. 
As background jobs are a performance concern, DiskANN + will need hooks for customizing this behavior. + +### Switching Strategies + +DiskANN already has the ability to run multiple strategies including hybrid full +precision and quantized ones. These should be sufficient for purposes of +bootstrapping, but we will need to orchestrate seamless transitions between +them. + +As the caller designates a strategy to use, we can implement new +`BootstrappedQuantized` and `BootstrappedHybrid` strategies that layer over +existing `FullPrecision`, `Quantized`, and `Hybrid` strategies. These new +bootstrapped strategies will delegate operation to the existing strategies +depending on the current phase. + +*Open question*: How exactly do we do this? + +### Quantization Backfill + +After quantization tables are built, newly inserted vectors will be quantized +before insertion, but previously inserted vectors won't have quantized +representations yet. During Phase Two, these previously inserted full precision +vectors will need to be quantized before the index enters Phase Three. + +Since integrators of DiskANN are sensitive to background jobs, how the index +manages backfilling quantized vectors should be controllable. + +The simplest way is to backfill all missing quantized vectors immediately during +the insert that starts Phase Two. This will cause a latency spike on that single +insert, but doesn't require any background processing. + +A more complicated solution would be to launch a background job that iterates +over full-precision only vectors and quantizes them. DiskANN should provide such +a job that integrators can use, but should also provide some callback that the +hosting database can pump to make incremental progress under its own control. 
+ +Both of these methods can be realized by having a new trait `QuantBackfill`: + +```rust + +pub enum QuantBackfillStatus { + Incomplete, + Complete, +} + +pub trait QuantBackfill { + type BackfillError: AsyncFriendly; + + /// Backfill quantization vectors for up to approximately `duration` amount of time. + fn backfill(duration: Duration) -> impl Future> + AsyncFriendly; +} +``` + +This trait would be implemented on the type that implements the +`BootstrappedQuantized` and `BootstrappedHybrid` strategies. + +*Open question*: How to implement the background task and make it overridable? + +## Trade-offs + +Currently the workarounds in use are either no index at all until sufficient +vectors exist or operating a side index until sufficient vectors exist and then +building a quantized graph. + +pg_diskann uses the former method, which means users are confused when they try +to create indexes on empty tables or insufficiently populated tables and get an +error. Cosmos DB uses the latter strategy and operates a flat index until an +asynchronous graph build is complete enough to use the graph index. This +requires the Cosmos DB team to maintain all their own infrastructure for the +flat index and the code around transitioning to the graph index. + +This proposal mitigates the downsides while still allowing the integrator to +retain control over key performance details. + +## Benchmark Results + +Since there is no way to build an index currently until quantization tables are +built, there is no way to benchmark the first two phases. There should be no +impact during Phase Three to performance. + +## Future Work + +None. + +## References + +None. 
\ No newline at end of file From 94050f0f83c559f18d67783cbfe0d8bee741124c Mon Sep 17 00:00:00 2001 From: Jack Moffitt Date: Thu, 19 Mar 2026 17:19:11 -0500 Subject: [PATCH 2/5] changes based on brainstorming with Mark --- rfcs/00000-quantizer-bootstrap.md | 140 +++++++++++++++++++----------- 1 file changed, 90 insertions(+), 50 deletions(-) diff --git a/rfcs/00000-quantizer-bootstrap.md b/rfcs/00000-quantizer-bootstrap.md index 18f45ba3d..a74c60100 100644 --- a/rfcs/00000-quantizer-bootstrap.md +++ b/rfcs/00000-quantizer-bootstrap.md @@ -3,7 +3,7 @@ | | | |---|---| | **Authors** | Jack Moffitt | -| **Contributors** | | +| **Contributors** | Mark Hildebrand | | **Created** | 2026-03-10 | | **Updated** | 2026-03-10 | @@ -31,7 +31,7 @@ build a quantized index since the quantization tables are a required part of the constructor. Current deployments of DiskANN work around this issue by not allowing index -creation until a dataset is sufficient large (pg_diskann), or operating a +creation until a dataset is sufficiently large (pg_diskann), or operating a separate flat index until sufficient vectors are collected at which point the quantization tables are calculated and a graph index is built with DiskANN. @@ -45,8 +45,8 @@ index. This means the index will operate in three different phases. In Phase 1, the index operates in full precision mode only until sufficient vectors exist to build quantization tables. During Phase 2, quantization tables will be built and -vectors will be quantized on insert; pre-existing vectors be quantized in the -background. Once all vectors are quantized, Phase Three begins the normal +vectors will be quantized on insert; pre-existing vectors will be quantized in +the background. Once all vectors are quantized, Phase Three begins the normal operation of the quantized index. ### Goals @@ -54,78 +54,113 @@ operation of the quantized index. 1. 
Allow quantized indices to start empty and use full precision data only to operate until sufficient vectors are inserted. 2. Aside from allowing construction of `DiskANNIndex` without providing - quantization tables, there should user visible changes to using the index. + quantization tables, there should be no user-visible changes to using the index. 3. Performance should remain as high as possible during the three phases. 4. The quantization of previously inserted full vectors during Phase Two should be controllable by the data provider. ## Proposal -Bootstrapping needs two changes to DiskANN. +Bootstrapping needs two changes to a DiskANN data provider implementation. 1. **Switching Strategies**: DiskANN needs to start by using full precision only - strategies during Phase One, and switching to a hybrid strategy for Phase - Two, and if the user's intent is to use quantized-only strategies, switching - to quantized-only for Phase Three. + strategies during Phase One, and switching to a quantized-only or hybrid + strategy for Phase Two. 2. **Quantization Backfill**: During Phase Two, previously inserted vectors will - need to be quantized. As background jobs are a performance concern, DiskANN - will need hooks for customizing this behavior. + need to be quantized. As background jobs are a performance concern, how + exactly this is accomplished must be customizable by the data provider. ### Switching Strategies DiskANN already has the ability to run multiple strategies including hybrid full -precision and quantized ones. These should be sufficient for purposes of -bootstrapping, but we will need to orchestrate seamless transitions between -them. - -As the caller designates a strategy to use, we can implement new -`BootstrappedQuantized` and `BootstrappedHybrid` strategies that layer over -existing `FullPrecision`, `Quantized`, and `Hybrid` strategies. These new -bootstrapped strategies will delegate operation to the existing strategies +precision and quantized ones. 
These strategies represent the high level intent, +but the data provider can choose alternate implementations depending on the data +available during the current phase. + +#### Insertion and Deletion + +Insertion and deletion can remain largely the same in the data provider +implementation. Inserts will need to write vectors, mappings, attributes, et al +into storage, and can track the current phase to gate writes to quantized +vectors. The search portion of these operations will return different objects depending on the current phase. -*Open question*: How exactly do we do this? +For example, consider a `DataProvider::set_element()` implementation: -### Quantization Backfill +```rust +struct ExampleProvider { + // other fields omitted + quantizer: Option, +} -After quantization tables are built, newly inserted vectors will be quantized -before insertion, but previously inserted vectors won't have quantized -representations yet. During Phase Two, these previously inserted full precision -vectors will need to be quantized before the index enters Phase Three. +impl SetElement<[f32]> for ExampleProvider { + // associated types omitted + + async fn set_element( + &self, + context: &Self::Context, + id: &Self::ExternalId, + element: &[T], + ) -> Result { + let internal_id = self.new_id()?; + self.write_vector(context, internal_id, element)?; + self.set_internal_map(internal_id, id)?; + self.set_external_map(id, internal_id)?; + + // Quantize and store the quant vector if we have a quantizer. + if let Some(quantizer) = self.quantizer { + let qv = quantizer.quantize(element)?; + self.write_quant_vector(context, internal_id, element)?; + } else { + // This function will check if we are ready for Phase Two, and if so, do or schedule the quantizer initialization. 
+ self.maybe_initialize_quantizer()?; + } + + Ok(NoopGuard::new(internal_id)) + } +} +``` -Since integrators of DiskANN are sensitive to background jobs, how the index -manages backfilling quantized vectors should be controllable. +Delete can similarly check the status of the quantizer, and delete quantized +vectors if they exist. -The simplest way is to backfill all missing quantized vectors immediately during -the insert that starts Phase Two. This will cause a latency spike on that single -insert, but doesn't require any background processing. +#### Searching -A more complicated solution would be to launch a background job that iterates -over full-precision only vectors and quantizes them. DiskANN should provide such -a job that integrators can use, but should also provide some callback that the -hosting database can pump to make incremental progress under its own control. +To avoid complexity of hybrid distance calculations, either full precision +distances will be used (Phase One and Two) or quantized distances will be used +(Phase Three). If a hybrid strategy is in use, then the hybrid distances will +not be used until Phase Three. -Both of these methods can be realized by having a new trait `QuantBackfill`: +Since vector data may be in one of two representations, the `Accessor::Element` +type should be `Poly` (this should be over-aligned to the correct alignment +for the primitive element type), and the data provider should interpret based on +data size. The distance and query computers will also need modifications to +accept both vector representations, and in the case of query computer the +representation must match that of the query. -```rust -pub enum QuantBackfillStatus { - Incomplete, - Complete, -} +### Quantization Backfill + +After quantization tables are built, newly inserted vectors will be quantized +before insertion, but previously inserted vectors won't have quantized +representations yet. 
During Phase Two, these previously inserted full precision +vectors will need to be quantized before the index enters Phase Three. -pub trait QuantBackfill { - type BackfillError: AsyncFriendly; +Since integrators of DiskANN are sensitive to background jobs, how the index +manages backfilling quantized vectors is controlled by the data provider +implementation. The data provider must have some way to track which vectors have +missing quantized representations so that it can generate them. - /// Backfill quantization vectors for up to approximately `duration` amount of time. - fn backfill(duration: Duration) -> impl Future> + AsyncFriendly; -} -``` +Once Phase Two is reached, the data provider can either pause during insertion +of the phase changing vector, or schedule the work to happen asynchronously +however it likes. -This trait would be implemented on the type that implements the -`BootstrappedQuantized` and `BootstrappedHybrid` strategies. +One possibility is to piggy-back on deletion tracking to track quantization +status of vectors. For example, in diskann-garnet, a free space map is kept that +tracks deletes. This could be expanded from 1-bit to 2-bits, and the second bit +used to track whether the vector is quantized. Alternatively, metadata about the +allocated range can be kept and used to iterate over the unquantized set. -*Open question*: How to implement the background task and make it overridable? ## Trade-offs @@ -134,7 +169,7 @@ vectors exist or operating a side index until sufficient vectors exist and then building a quantized graph. pg_diskann uses the former method, which means users are confused when they try -to create indexes on empty tables or insufficiently populated tables and get an +to create indexes on empty or insufficiently populated tables and get an error. Cosmos DB uses the latter strategy and operates a flat index until an asynchronous graph build is complete enough to use the graph index. 
This requires the Cosmos DB team to maintain all their own infrastructure for the @@ -143,6 +178,11 @@ flat index and the code around transitioning to the graph index. This proposal mitigates the downsides while still allowing the integrator to retain control over key performance details. +This proposal also entirely encapsulates this inside the `DataProvider` +implementation. Alternatively, one could attempt to solve this with some kind of +index or strategy layering, but the complexity this would introduce seems not +worth the cost. + ## Benchmark Results Since there is no way to build an index currently until quantization tables are From b5a3759e463e0356f38f5b98094bda85dc6ebaeb Mon Sep 17 00:00:00 2001 From: Jack Moffitt Date: Mon, 13 Apr 2026 15:08:04 -0500 Subject: [PATCH 3/5] Implement BIN and Q8 quantizers --- Cargo.lock | 1 + diskann-garnet/Cargo.toml | 5 +- diskann-garnet/src/alloc.rs | 2 +- diskann-garnet/src/dyn_index.rs | 36 +- diskann-garnet/src/fsm.rs | 101 ++- diskann-garnet/src/garnet.rs | 8 + diskann-garnet/src/lib.rs | 73 +- diskann-garnet/src/provider.rs | 799 +++++++++++++++++--- diskann-garnet/src/quantization.rs | 317 ++++++++ diskann-garnet/src/test_utils.rs | 19 +- diskann-providers/src/common/minmax_repr.rs | 6 + diskann-providers/src/common/mod.rs | 2 +- vectorset/src/loader.rs | 52 +- vectorset/src/main.rs | 61 +- 14 files changed, 1310 insertions(+), 172 deletions(-) create mode 100644 diskann-garnet/src/quantization.rs diff --git a/Cargo.lock b/Cargo.lock index 1713f4b87..81fcc6f78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -778,6 +778,7 @@ dependencies = [ "diskann-utils", "diskann-vector", "foldhash 0.2.0", + "rand 0.9.4", "thiserror 2.0.17", "tokio", ] diff --git a/diskann-garnet/Cargo.toml b/diskann-garnet/Cargo.toml index c16857e05..b81049947 100644 --- a/diskann-garnet/Cargo.toml +++ b/diskann-garnet/Cargo.toml @@ -16,8 +16,9 @@ dashmap = { workspace = true, features = ["inline"] } diskann.workspace = true 
diskann-quantization.workspace = true diskann-providers.workspace = true +diskann-utils.workspace = true diskann-vector.workspace = true foldhash = "0.2.0" +rand.workspace = true thiserror.workspace = true -tokio.workspace = true -diskann-utils.workspace = true +tokio = { workspace = true, features = ["sync"] } diff --git a/diskann-garnet/src/alloc.rs b/diskann-garnet/src/alloc.rs index 17db0775e..16e3cfb3c 100644 --- a/diskann-garnet/src/alloc.rs +++ b/diskann-garnet/src/alloc.rs @@ -9,7 +9,7 @@ use std::ptr::NonNull; /// Custom allocator that over-aligns to 8 bytes. This is needed since Garnet will hand us byte slices for f32 data /// that may be unaligned, so we need an allocator to make owned, aligned byte containers. #[derive(Debug, Clone, Copy)] -pub(crate) struct AlignToEight; +pub struct AlignToEight; unsafe impl AllocatorCore for AlignToEight { #[inline] diff --git a/diskann-garnet/src/dyn_index.rs b/diskann-garnet/src/dyn_index.rs index badedba51..77153dabe 100644 --- a/diskann-garnet/src/dyn_index.rs +++ b/diskann-garnet/src/dyn_index.rs @@ -7,7 +7,7 @@ use crate::{ SearchResults, garnet::{Context, GarnetId}, labels::GarnetQueryLabelProvider, - provider::{self, GarnetProvider}, + provider::{self, DynamicQuantization, GarnetProvider}, }; use diskann::{ ANNError, ANNResult, @@ -16,8 +16,7 @@ use diskann::{ utils::VectorRepr, }; use diskann_providers::{ - index::wrapped_async::DiskANNIndex, - model::graph::provider::{async_::common::FullPrecision, layers::BetaFilter}, + index::wrapped_async::DiskANNIndex, model::graph::provider::layers::BetaFilter, }; use std::sync::Arc; @@ -55,6 +54,10 @@ pub trait DynIndex: Send + Sync { fn internal_id_exists(&self, context: &Context, id: u32) -> bool; fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool; + + fn train_quantizer(&self, context: &Context) -> bool; + + fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize); } impl DynIndex for DiskANNIndex> { @@ -63,7 +66,7 
@@ impl DynIndex for DiskANNIndex> { /// The data slice here must be aligned to `T` or this will panic. fn insert(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()> { self.insert( - FullPrecision, + DynamicQuantization, context, id, bytemuck::cast_slice::(data), @@ -87,10 +90,10 @@ impl DynIndex for DiskANNIndex> { ) -> ANNResult { let query = bytemuck::cast_slice::(data); if let Some((labels, beta)) = filter { - let beta_filter = BetaFilter::new(FullPrecision, Arc::new(labels.clone()), beta); + let beta_filter = BetaFilter::new(DynamicQuantization, Arc::new(labels.clone()), beta); self.search(*params, &beta_filter, context, query, output) } else { - self.search(*params, &FullPrecision, context, query, output) + self.search(*params, &DynamicQuantization, context, query, output) } } @@ -105,9 +108,9 @@ impl DynIndex for DiskANNIndex> { let rt = tokio::runtime::Builder::new_current_thread() .build() .map_err(|e| ANNError::new(diskann::ANNErrorKind::Opaque, e))?; - let mut accessor: provider::FullAccessor<'_, T> = - >::search_accessor( - &FullPrecision, + let mut accessor: provider::DynamicAccessor<'_, T> = + >::search_accessor( + &DynamicQuantization, self.inner.provider(), context, )?; @@ -115,13 +118,12 @@ impl DynIndex for DiskANNIndex> { // Look up internal ID let iid = self.inner.provider().to_internal_id(context, id)?; let data = rt.block_on(accessor.get_element(iid))?; - let data_bytes = bytemuck::cast_slice::(&data); - self.search_vector(context, data_bytes, params, filter, output) + self.search_vector(context, &data, params, filter, output) } fn remove(&self, context: &Context, id: &GarnetId) -> ANNResult<()> { self.inplace_delete( - FullPrecision, + DynamicQuantization, context, id, 3, @@ -147,4 +149,14 @@ impl DynIndex for DiskANNIndex> { fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool { self.inner.provider().vector_id_exists(context, id) } + + fn train_quantizer(&self, context: &Context) -> bool { + 
self.inner.provider().train_quantizer(context) + } + + fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) { + self.inner + .provider() + .backfill_quant_vectors(context, task_idx, task_count); + } } diff --git a/diskann-garnet/src/fsm.rs b/diskann-garnet/src/fsm.rs index 37096278f..6b6e56130 100644 --- a/diskann-garnet/src/fsm.rs +++ b/diskann-garnet/src/fsm.rs @@ -26,7 +26,7 @@ use crossbeam::queue::ArrayQueue; use std::sync::{ RwLock, - atomic::{AtomicBool, AtomicU32, Ordering}, + atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}, }; use thiserror::Error; @@ -47,11 +47,23 @@ pub enum FsmError { } pub struct FreeSpaceMap { + /// Garnet callbacks for reading/writing FSM keys callbacks: Callbacks, + /// A flag to signal whether there are free IDs in the FSM. + /// This is set after a scan of the FSM, and is used to prevent extraneous reads + /// of FSM blocks. has_free_ids: AtomicBool, + /// A queue of previously deleted IDs to prevent excessive reads to the FSM fast_free_list: ArrayQueue, + /// The maximum block ID stored in the FSM max_block: RwLock, + /// The next ID that will be minted if reusing previously deleted IDs is unavailable next_id: AtomicU32, + /// The total number of IDs marked used in the FSM + total_used: AtomicUsize, + /// Lock that prevents reuse of previously used IDs. + /// This is used to disable ID reuse during quantization backfill. + reuse_lock: AtomicUsize, } impl FreeSpaceMap { @@ -60,6 +72,7 @@ impl FreeSpaceMap { let fast_free_list = ArrayQueue::new(FAST_SIZE); let max_block = RwLock::new(u32::MAX); let next_id = AtomicU32::new(0); + let total_used = AtomicUsize::new(0); let mut this = Self { callbacks, @@ -67,6 +80,8 @@ impl FreeSpaceMap { fast_free_list, max_block, next_id, + total_used, + reuse_lock: AtomicUsize::new(0), }; // Attempt to load state from Garnet. 
@@ -97,6 +112,7 @@ impl FreeSpaceMap { let mut block = vec![0u8; BLOCK_SIZE_BYTES]; let mut last_used_id = -1i64; + let mut total_used = 0usize; for block_id in (0..max_block_id).rev() { let block_key = Self::block_key(block_id); @@ -115,8 +131,9 @@ impl FreeSpaceMap { let used = bit_used(byte, bidx); if used { last_used_id = last_used_id.max(id as i64); - } else if (id as i64) < last_used_id && self.fast_free_list.push(id).is_err() { - break; + total_used += 1; + } else if (id as i64) < last_used_id { + let _ = self.fast_free_list.push(id); } id = id.saturating_sub(1); @@ -130,6 +147,8 @@ impl FreeSpaceMap { self.next_id .store((last_used_id + 1) as u32, Ordering::Release); + self.total_used.store(total_used, Ordering::Release); + if !self.fast_free_list.is_empty() { self.has_free_ids.store(true, Ordering::Release); } @@ -174,6 +193,14 @@ impl FreeSpaceMap { return Err(FsmError::Garnet(GarnetError::Write)); } + if changed { + if used { + self.total_used.fetch_add(1, Ordering::AcqRel); + } else { + self.total_used.fetch_sub(1, Ordering::AcqRel); + } + } + // NOTE: We don't modify the free list if the id was already free. if !used && changed { // Push the id onto the fast free list. If the queue is full, ignore it. @@ -214,7 +241,7 @@ impl FreeSpaceMap { /// This may be a a fresh ID larger than all the others, or it may be a reused ID that /// previously belonged to a deleted element. The returned ID is marked as used. pub fn next_id(&self, ctx: Context) -> Result { - if self.has_free_ids.load(Ordering::Acquire) { + if self.can_reuse() && self.has_free_ids.load(Ordering::Acquire) { // We retry reusing a freed ID until there are none or we get one and marking it used // succeeds in changing the value. loop { @@ -270,6 +297,10 @@ impl FreeSpaceMap { self.next_id.load(Ordering::Acquire).saturating_sub(1) } + pub fn total_used(&self) -> usize { + self.total_used.load(Ordering::Acquire) + } + /// Return the FSM block number, byte index, and bit index for a given ID. 
/// The block number is the block which stores this ID, the byte index is byte offset /// within the block which contains the status bits, and the bit index is the bit index @@ -301,9 +332,9 @@ impl FreeSpaceMap { let mut has_free_ids = false; let mut id = 0u32; + let mut block = vec![0u8; BLOCK_SIZE_BYTES]; 'scan: for block_id in 0..*max_block { let block_key = Self::block_key(block_id); - let mut block = vec![0u8; BLOCK_SIZE_BYTES]; if !self .callbacks .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) @@ -363,6 +394,66 @@ impl FreeSpaceMap { Ok(()) } + + /// Visit each used id in the FSM, invoking f on each id. + pub fn visit_used(&self, ctx: Context, mut f: F) -> Result<(), FsmError> + where + F: FnMut(u32) -> bool, + { + let max_block = { *self.max_block.read().unwrap() }; + let mut block = vec![0u8; BLOCK_SIZE_BYTES]; + let mut id = 0u32; + + for block_id in 0..max_block + 1 { + let block_key = Self::block_key(block_id); + if !self + .callbacks + .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + { + return Err(FsmError::Garnet(GarnetError::Read)); + } + + for &byte in &block { + if byte == 0x00 { + id += 8; + continue; + } + + for bidx in 0..8 { + if bit_used(byte, bidx) { + let keep_going = f(id); + if !keep_going { + return Ok(()); + } + } + id += 1; + } + } + } + + Ok(()) + } + + /// Returns whether previously deleted IDs may be reused. + fn can_reuse(&self) -> bool { + self.reuse_lock.load(Ordering::Acquire) == 0 + } + + /// Prevent the reuse of previously deleted IDs. + /// Each call to this increments a counter, and only once the counter is back to zero + /// will reuse be allowed again. + pub fn lock_reuse(&self) { + self.reuse_lock.fetch_add(1, Ordering::AcqRel); + } + + /// Resume reuse of previously deleted IDs. + /// Each call to this decrements a counter, and only once the counter is back to zero + /// will reuse be allowed. This returns whether reuse was actually enabled. 
+ pub fn unlock_reuse(&self) -> bool { + let prev = self.reuse_lock.fetch_sub(1, Ordering::AcqRel); + debug_assert_ne!(prev, 0); + prev == 1 + } } /// Return whether the `bidx`th bit is set in byte, where bits are labeled from left to right. diff --git a/diskann-garnet/src/garnet.rs b/diskann-garnet/src/garnet.rs index 85ba0cce5..566460af1 100644 --- a/diskann-garnet/src/garnet.rs +++ b/diskann-garnet/src/garnet.rs @@ -86,6 +86,10 @@ impl Callbacks { self.rmw_callback } + #[expect( + dead_code, + reason = "currently unused, but may be needed in the future" + )] pub fn exists_iid(&self, ctx: Context, id: u32) -> bool { let key = [4, id]; // SAFETY: Key bytes are preceded by 4 bytes of space. @@ -100,6 +104,10 @@ impl Callbacks { unsafe { self.exists_raw(ctx, &key_bytes[4..]) } } + #[expect( + dead_code, + reason = "currently unused, but may be needed in the future" + )] pub fn exists_eid(&self, ctx: Context, id: &GarnetId) -> bool { // SAFETY: GarnetId ensures there are 4 bytes preceding the key bytes. 
unsafe { self.exists_raw(ctx, id) } diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index a66cd7ab0..c6b40e37f 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -51,6 +51,7 @@ mod fsm; mod garnet; mod labels; mod provider; +mod quantization; #[cfg(test)] mod test_utils; @@ -175,7 +176,8 @@ fn create_index_impl( callbacks: Callbacks, context: Context, ) -> Result, GarnetProviderError> { - let provider = GarnetProvider::::new(dim, metric_type, max_degree, callbacks, context)?; + let provider = + GarnetProvider::::new(dim, quant_type, metric_type, max_degree, callbacks, context)?; let state = if provider.start_points_exist() { AtomicUsize::new(IndexState::Ready as usize) } else { @@ -217,7 +219,7 @@ pub unsafe extern "C" fn create_index( target_degree, config::MaxDegree::Value(max_degree as usize), l_build as usize, - config::PruneKind::TriangleInequality, + metric_type.into(), ) .build() { @@ -230,6 +232,7 @@ pub unsafe extern "C" fn create_index( let callbacks = Callbacks::new(read_callback, write_callback, delete_callback, rmw_callback); match quant_type { + VectorQuantType::Invalid => ptr::null(), VectorQuantType::XPreQ8 => { if let Ok(index) = create_index_impl::( quant_type, @@ -245,7 +248,7 @@ pub unsafe extern "C" fn create_index( ptr::null() } } - VectorQuantType::NoQuant => { + VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => { if let Ok(index) = create_index_impl::( quant_type, config, @@ -260,7 +263,6 @@ pub unsafe extern "C" fn create_index( ptr::null() } } - _ => ptr::null(), } } @@ -318,6 +320,9 @@ fn interpret_vector<'a>( let v = match vector_value_type { VectorValueType::Invalid => return None, VectorValueType::FP32 => match quant_type { + VectorQuantType::Invalid => { + return None; + } VectorQuantType::XPreQ8 => { let mut bp = if let Ok(bp) = Poly::broadcast(0u8, vector_len, AlignToEight) { bp @@ -332,11 +337,11 @@ fn interpret_vector<'a>( } PolyCow::from(bp) } - 
VectorQuantType::NoQuant if v.as_ptr().align_offset(4) == 0 => { + _ if v.as_ptr().align_offset(4) == 0 => { // pointer is correctly aligned to interpret as f32 PolyCow::from(v) } - VectorQuantType::NoQuant => { + _ => { // need to copy f32 data as it is unaligned let mut fp = if let Ok(fp) = Poly::broadcast(0u8, vector_len_bytes, AlignToEight) { fp @@ -346,9 +351,6 @@ fn interpret_vector<'a>( fp.copy_from_slice(v); PolyCow::from(fp) } - _ => { - return None; - } }, VectorValueType::XB8 => match quant_type { VectorQuantType::XPreQ8 => PolyCow::from(v), @@ -375,6 +377,10 @@ fn interpret_vector<'a>( Some(v) } +const INSERT_FAIL: u8 = 0; +const INSERT_SUCCESS: u8 = 1; +const INSERT_SUCCESS_START_TRAINING: u8 = 2; + /// # Safety /// /// FFI @@ -389,7 +395,7 @@ pub unsafe extern "C" fn insert( vector_len: usize, attribute_data: *const u8, attribute_len: usize, -) -> bool { +) -> u8 { let index = unsafe { &*index_ptr.cast::() }; let ctx = Context(ctx); @@ -404,13 +410,13 @@ pub unsafe extern "C" fn insert( ) { v } else { - return false; + return INSERT_FAIL; }; if let Some(_err) = ensure_index_ready_or_init(index, || index.inner.maybe_set_start_point(&ctx, &v).err()) { - return false; + return INSERT_FAIL; }; // Write attributes to garnet @@ -420,11 +426,22 @@ pub unsafe extern "C" fn insert( &[] }; if index.inner.set_attributes(&ctx, &id, attr_data).is_err() { - return false; + return INSERT_FAIL; } + let old_ready = provider::QUANTIZER_READY.with(|v| v.load(Ordering::Acquire)); + // Insert the vector - index.inner.insert(&ctx, &id, &v).is_ok() + if index.inner.insert(&ctx, &id, &v).is_ok() { + let ready = provider::QUANTIZER_READY.with(|v| v.load(Ordering::Acquire)); + if !old_ready && ready { + INSERT_SUCCESS_START_TRAINING + } else { + INSERT_SUCCESS + } + } else { + INSERT_FAIL + } } fn ensure_index_ready_or_init(index: &Index, init: F) -> Option @@ -466,6 +483,34 @@ where None } +/// # Safety +/// +/// FFI +#[unsafe(no_mangle)] +pub unsafe extern "C" fn 
build_quant_table(context: u64, index_ptr: *const c_void) -> bool { + let index = unsafe { &*index_ptr.cast::() }; + let ctx = Context(context); + + index.inner.train_quantizer(&ctx) +} + +/// # Safety +/// +/// FFI +#[unsafe(no_mangle)] +pub unsafe extern "C" fn backfill_quant_vectors( + context: u64, + index_ptr: *const c_void, + task_index: usize, + task_count: usize, +) { + let index = unsafe { &*index_ptr.cast::() }; + let ctx = Context(context); + index + .inner + .backfill_quant_vectors(&ctx, task_index, task_count); +} + /// # Safety /// /// FFI diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 0be5be9a4..6f8750135 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -11,7 +11,7 @@ use diskann::{ config::defaults::MAX_OCCLUSION_SIZE, glue::{ self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchPostProcessStep, SearchStrategy, }, workingset::{self, map::Entry}, }, @@ -22,21 +22,40 @@ use diskann::{ }, utils::VectorRepr, }; -use diskann_providers::model::graph::provider::async_::common::FullPrecision; -use diskann_utils::Reborrow; +use diskann_quantization::alloc::{AllocatorError, Poly}; use diskann_utils::object_pool::{AsPooled, ObjectPool, PooledRef, Undef}; -use diskann_vector::{PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric}; +use diskann_utils::{Reborrow, views::Matrix}; +use diskann_vector::{ + DistanceFunction, PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric, +}; use std::{ - future, mem, + any::TypeId, + future, + marker::PhantomData, + mem, ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, AtomicU64, Ordering}, + time::SystemTime, }; use thiserror::Error; use crate::{ + VectorQuantType, + alloc::AlignToEight, fsm::{FreeSpaceMap, FsmError}, garnet::{Callbacks, Context, GarnetError, GarnetId, Term}, + 
quantization::{ + self, DynDistanceComputer, DynQueryComputer, GarnetQuantizer, GarnetQuantizerError, + }, }; +thread_local! { + /// Thread local flag to detect when we've reached the quantization threshold. This is needed + /// to return the correct status at the end of insert() since the provider code that becomes + /// aware of the state change cannot directly communicate it back to the caller. + pub static QUANTIZER_READY: AtomicBool = const { AtomicBool::new(false) }; +} + #[derive(Clone)] struct AdjList(AdjacencyList); @@ -64,6 +83,44 @@ impl AsPooled for AdjList { } } +/// A type erased vector, properly aligned so that it can hold any size element +/// (up to 8 bytes long). +pub struct DynVector { + inner: Poly<[u8], AlignToEight>, + ty: PhantomData, +} + +impl DynVector { + fn new(inner: Poly<[u8], AlignToEight>) -> Self { + Self { + inner, + ty: PhantomData, + } + } +} + +impl Deref for DynVector { + type Target = Poly<[u8], AlignToEight>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for DynVector { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl<'a, T: VectorRepr> Reborrow<'a> for DynVector { + type Target = &'a [u8]; + + fn reborrow(&'a self) -> Self::Target { + &self.inner + } +} + #[derive(Debug, Error)] pub enum GarnetProviderError { #[error("Garnet operation failed")] @@ -72,6 +129,18 @@ pub enum GarnetProviderError { Fsm(#[from] FsmError), #[error("Start point invalid")] StartPoint, + #[error("Allocation failed")] + AllocFailed(#[from] AllocatorError), + #[error("Invalid quantizer for vector data")] + InvalidQuantizer, + #[error("Expected quantizer is missing")] + MissingQuantizer, + #[error("Failed to gather quantizer training data")] + MissingTrainingData, + #[error("Quantizer error: {0}")] + Quantizer(#[from] GarnetQuantizerError), + #[error("Post processing error: {0}")] + PostProcessing(Box), } impl From for ANNError { @@ -83,21 +152,32 @@ impl From for ANNError { 
diskann::always_escalate!(GarnetProviderError); +/// The Garnet DataProvider implementation. pub struct GarnetProvider { dim: usize, metric_type: Metric, max_degree: usize, callbacks: Callbacks, + /// The quantizer the index will use, or None if NOQUANT is used. + quantizer: Option>, + /// Tracks whether quantization backfill is complete. + all_quantized: AtomicBool, + /// Tracks whether training has already started. + training_started: AtomicU64, id_buffer_pool: ObjectPool, filtered_ids_pool: ObjectPool>, + quant_buffer_pool: ObjectPool>, neighbor_cache: DashMap, foldhash::fast::RandomState>, - start_point_cache: DashMap, foldhash::fast::RandomState>, + start_point_cache: DashMap, foldhash::fast::RandomState>, + start_point_quant_cache: DashMap, foldhash::fast::RandomState>, fsm: FreeSpaceMap, + _phantom: PhantomData, } impl GarnetProvider { pub fn new( dim: usize, + quant_type: VectorQuantType, metric_type: Metric, max_degree: usize, callbacks: Callbacks, @@ -114,11 +194,13 @@ impl GarnetProvider { let start_point_cache = DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); + let start_point_quant_cache = + DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); let neighbor_cache = DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); // Try to read the start point from Garnet - let mut v = vec![T::default(); dim]; + let mut v = Poly::broadcast(0u8, dim * mem::size_of::(), AlignToEight)?; if callbacks.read_single_iid(context.term(Term::Vector), 0, &mut v) { let mut neighbors = vec![0u32; max_degree + 1]; if !callbacks.read_single_iid(context.term(Term::Neighbors), 0, &mut neighbors) { @@ -134,16 +216,50 @@ impl GarnetProvider { let fsm = FreeSpaceMap::new(context, callbacks)?; + let (quantizer, canonical_bytes, all_quantized) = match quant_type { + VectorQuantType::NoQuant | VectorQuantType::XPreQ8 => (None, 0, false), + VectorQuantType::Invalid => return 
Err(GarnetProviderError::InvalidQuantizer), + VectorQuantType::Q8 => { + if TypeId::of::() != TypeId::of::() { + return Err(GarnetProviderError::InvalidQuantizer); + } + + let quantizer = Box::new(quantization::MinMax8Bit::new(dim, metric_type)?) + as Box; + let canonical_bytes = quantizer.canonical_bytes(); + // NOTE: Q8 needs no training, so it always starts with backfill complete. + (Some(quantizer), canonical_bytes, true) + } + VectorQuantType::Bin => { + if TypeId::of::() != TypeId::of::() { + return Err(GarnetProviderError::InvalidQuantizer); + } + + let quantizer = + Box::new(quantization::Spherical1Bit::new(dim)) as Box; + let canonical_bytes = quantizer.canonical_bytes(); + (Some(quantizer), canonical_bytes, false) + } + }; + let quant_buffer_pool = + ObjectPool::new(Undef::new(canonical_bytes), parallelism, Some(parallelism)); + Ok(Self { dim, metric_type, max_degree, callbacks, + quantizer, + all_quantized: AtomicBool::new(all_quantized), + training_started: AtomicU64::new(0), id_buffer_pool, filtered_ids_pool, + quant_buffer_pool, start_point_cache, + start_point_quant_cache, neighbor_cache, fsm, + _phantom: PhantomData, }) } @@ -152,7 +268,7 @@ impl GarnetProvider { context: &Context, point: &[T], ) -> Result<(), GarnetProviderError> { - let mut v = vec![T::default(); self.dim]; + let mut v = Poly::broadcast(0u8, self.dim * mem::size_of::(), AlignToEight)?; if self .callbacks .read_single_iid(context.term(Term::Vector), 0, &mut v) @@ -187,6 +303,25 @@ impl GarnetProvider { return Err(GarnetError::Write.into()); } + if self.is_quantized() + && let Some(quantizer) = self.quantizer() + { + // We are already able to quantize, so store the quantized start point + + let mut qpoint = vec![0u8; quantizer.canonical_bytes()]; + quantizer.compress(bytemuck::cast_slice::(point), &mut qpoint)?; + + if !self + .callbacks + .write_iid(context.term(Term::Quantized), 0, &qpoint) + { + return Err(GarnetError::Write.into()); + } + + self.start_point_quant_cache + 
.insert(0, Poly::from_iter(qpoint.iter().copied(), AlignToEight)?); + } + if !self .callbacks .write_iid(context.term(Term::Neighbors), 0, &neighbors) @@ -194,7 +329,13 @@ impl GarnetProvider { return Err(GarnetError::Write.into()); } - self.start_point_cache.insert(0, point.to_vec()); + self.start_point_cache.insert( + 0, + Poly::from_iter( + bytemuck::cast_slice::(point).iter().copied(), + AlignToEight, + )?, + ); self.neighbor_cache .insert(0, Vec::with_capacity(self.max_degree + 1)); } @@ -206,10 +347,6 @@ impl GarnetProvider { self.start_point_cache.get(&0).is_some() && self.neighbor_cache.get(&0).is_some() } - pub fn callbacks(&self) -> &Callbacks { - &self.callbacks - } - pub fn set_attributes( &self, context: &Context, @@ -240,6 +377,189 @@ impl GarnetProvider { pub fn max_internal_id(&self) -> u32 { self.fsm.max_id() } + + /// Train the quantizer. + /// + /// This should only be called when at least `quantizer.required_vectors()` vectors exist in + /// the provider. This will build quantization tables, but does not quantize any vectors. + pub fn train_quantizer(&self, context: &Context) -> bool { + // Collect up to `self.quantizer.required_vectors()` vectors and use them for quantizer training. + + let current_time: u64 = + match SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH) { + Ok(t) => { + let t = t.as_millis().try_into().unwrap_or(0); + if t == 0 { + return false; + } + t + } + Err(_) => return false, + }; + + // Ensure we don't kick off training twice. 
+ match self.training_started.compare_exchange( + 0, + current_time, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => (), + Err(_) => return false, + } + + let quantizer = match &self.quantizer { + Some(q) => q, + None => { + self.training_started.store(0, Ordering::Release); + return false; + } + }; + + debug_assert_eq!(std::any::TypeId::of::(), std::any::TypeId::of::()); + + let rows = quantizer.required_vectors(); + let mut data = Matrix::new(0f32, rows, self.dim); + let mut row_idx = 0usize; + + if self + .fsm + .visit_used(*context, |id| { + // Skip the start point. + if id == 0 { + return true; + } + + if row_idx >= rows { + return false; + } + + // Read the vector into the data matrix. + // Note that it's ok to read f32 instead of T here, because this can only get called when T == f32. + let row = data.row_mut(row_idx); + if !self + .callbacks + .read_single_iid(context.term(Term::Vector), id, row) + { + return false; + } + + row_idx += 1; + + true + }) + .is_err() + { + // Training failed + self.training_started.store(0, Ordering::Release); + return false; + } + + if row_idx < quantizer.required_vectors() { + self.training_started.store(0, Ordering::Release); + return false; + } + + let view = if let Some(view) = data.subview(0..row_idx) { + view + } else { + return false; + }; + + // Train the quantizer. + match quantizer.train(self.metric_type, view) { + Ok(()) => true, + Err(_e) => { + self.training_started.store(0, Ordering::Release); + false + } + } + } + + /// Bulk quantize previously inserted vectors. + /// + /// This function will be invoked on multiple threads. The total number of tasks and the ID of + /// the current task are given as inputs. 
+ pub fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) { + let quantizer = match &self.quantizer { + Some(q) => q, + None => return, + }; + + let max_id = self.fsm.max_id() as usize; + + let task_count = task_count.min(max_id + 1); + if task_idx >= task_count { + return; + } + + let work_count = (max_id + 1) / task_count; // will be >= 1 + let start_id = (work_count * task_idx) as u32; + let end_id = (work_count * (task_idx + 1)).min(max_id + 1) as u32; + + self.fsm.lock_reuse(); + + let mut v = vec![0f32; self.dim]; + let mut q = vec![0u8; quantizer.canonical_bytes()]; + for id in start_id..end_id { + if !self + .callbacks + .read_single_iid(context.term(Term::Vector), id, &mut v) + { + continue; + } + + if quantizer.compress(&v, &mut q).is_err() { + continue; + }; + + if !self + .callbacks + .write_iid(context.term(Term::Quantized), id, &q) + { + continue; + } + } + + let backfill_finished = self.fsm.unlock_reuse(); + + if backfill_finished { + // Finish by quantizing the start points + if let Some(v) = self.start_point_cache.get(&0) + && quantizer + .compress(bytemuck::cast_slice::(&v), &mut q) + .is_ok() + { + let _ = self + .callbacks + .write_iid(context.term(Term::Quantized), 0, &q); + + // set the cache + let point = if let Ok(p) = Poly::from_iter(q.iter().copied(), AlignToEight) { + p + } else { + return; + }; + self.start_point_quant_cache.insert(0, point); + } + + self.all_quantized.store(true, Ordering::Release); + } + } + + /// Returns the quantizer associated with the index. + fn quantizer(&self) -> Option<&dyn GarnetQuantizer> { + if let Some(quantizer) = &self.quantizer { + return Some(&**quantizer as &dyn GarnetQuantizer); + } + + None + } + + /// Returns quantization status. If this is true, the index is operating fully quantized. 
+ pub fn is_quantized(&self) -> bool { + self.quantizer.is_some() && self.all_quantized.load(Ordering::Acquire) + } } impl DataProvider for GarnetProvider { @@ -287,11 +607,31 @@ impl SetElement<&[T]> for GarnetProvider { ) -> Result { let internal_id = self.fsm.next_id(*context)?; + // Set quantization readiness + if let Some(quantizer) = &self.quantizer + && !quantizer.is_prepared() + && self.fsm.total_used() > quantizer.required_vectors() + { + QUANTIZER_READY.with(|v| v.store(true, Ordering::Release)); + } + let insert = || -> Result<(), Self::SetError> { self.callbacks .write_iid(context.term(Term::Vector), internal_id, element) .then_some(()) .ok_or(GarnetError::Write)?; + if let Some(quantizer) = &self.quantizer + && quantizer.is_prepared() + { + let mut quant = self + .quant_buffer_pool + .get_ref(Undef::new(quantizer.canonical_bytes())); + quantizer.compress(bytemuck::cast_slice::(element), &mut quant)?; + self.callbacks + .write_iid(context.term(Term::Quantized), internal_id, &quant) + .then_some(()) + .ok_or(GarnetError::Write)?; + } self.callbacks .write_iid(context.term(Term::ExtMap), internal_id, id) .then_some(()) @@ -345,6 +685,7 @@ impl Delete for GarnetProvider { // NOTE: Commented out until DiskANN fixes accessing neighbor data post-delete. //ok &= self.callbacks.delete_iid(context.term(Term::Neighbors), id); ok &= self.callbacks.delete_iid(context.term(Term::Vector), id); + ok &= self.callbacks.delete_iid(context.term(Term::Quantized), id); if !ok { return future::ready(Err(GarnetError::Delete.into())); @@ -386,20 +727,25 @@ impl Delete for GarnetProvider { } } -#[allow(dead_code)] -pub struct FullAccessor<'a, T: VectorRepr> { +/// Dynamic accessor that seamlessly transitions from full precision vector based operation to +/// quantized-only operation. 
+#[derive(Copy, Clone, Debug)] +pub struct DynamicQuantization; + +pub struct DynamicAccessor<'a, T: VectorRepr> { provider: &'a GarnetProvider, context: &'a Context, - is_search: bool, + /// Whether this accessor should use quantized vectors + quantized: bool, id_buffer: PooledRef<'a, AdjList>, filtered_ids: PooledRef<'a, Vec>, } -impl<'a, T: VectorRepr> FullAccessor<'a, T> { +impl<'a, T: VectorRepr> DynamicAccessor<'a, T> { pub(crate) fn new( provider: &'a GarnetProvider, context: &'a Context, - is_search: bool, + quantized: bool, ) -> Self { let id_buffer = provider .id_buffer_pool @@ -407,10 +753,10 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { let filtered_ids = provider .filtered_ids_pool .get_ref(Undef::new(MAX_OCCLUSION_SIZE.get() as usize * 2)); // x2 to allow for the length prefixes for garnet - FullAccessor { + DynamicAccessor { provider, context, - is_search, + quantized, id_buffer, filtered_ids, } @@ -434,6 +780,7 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { id, &mut guard, ) { + guard.finish(0); return false; } @@ -444,11 +791,11 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { } } -impl HasId for FullAccessor<'_, T> { +impl HasId for DynamicAccessor<'_, T> { type Id = u32; } -impl SearchExt for FullAccessor<'_, T> { +impl SearchExt for DynamicAccessor<'_, T> { fn starting_points(&self) -> impl Future>> + Send { let points = if self.provider.start_points_exist() { vec![0] @@ -466,7 +813,7 @@ impl SearchExt for FullAccessor<'_, T> { } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T> { +impl ExpandBeam<&[T]> for DynamicAccessor<'_, T> { fn expand_beam( &mut self, ids: Itr, @@ -490,14 +837,31 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { .filter(|id| pred.eval_mut(id)) { if id == 0 { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r + let dist = if self.quantized + && let Some(_quantizer) = self.provider.quantizer() + { + let guard = if let Some(r) = self.provider.start_point_quant_cache.get(&id) + { + r + } else { 
+ return future::ready(Err(GarnetProviderError::Garnet( + GarnetError::Read, + ) + .into())); + }; + computer.evaluate_similarity(&*guard) } else { - return future::ready(Err( - GarnetProviderError::Garnet(GarnetError::Read).into() - )); + let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetProviderError::Garnet( + GarnetError::Read, + ) + .into())); + }; + computer.evaluate_similarity(&*guard) }; - let dist = computer.evaluate_similarity(&*guard); + on_neighbors(dist, id); } else { self.filtered_ids.push(4); @@ -505,49 +869,78 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { } } - self.provider.callbacks.read_multi_lpiid( - self.context.term(Term::Vector), - &self.filtered_ids, - |i, v| { - let dist = computer.evaluate_similarity(v); - on_neighbors(dist, self.filtered_ids[i as usize * 2 + 1]); - }, - ); + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + + if !self.filtered_ids.is_empty() { + self.provider + .callbacks + .read_multi_lpiid(ctx, &self.filtered_ids, |i, v| { + let dist = computer.evaluate_similarity(v); + on_neighbors(dist, self.filtered_ids[i as usize * 2 + 1]); + }); + } } future::ready(Ok(())) } } -impl Accessor for FullAccessor<'_, T> { +impl Accessor for DynamicAccessor<'_, T> { type Element<'a> - = Vec + = DynVector where Self: 'a; - type ElementRef<'a> = &'a [T]; + type ElementRef<'a> = &'a [u8]; type GetError = GarnetProviderError; fn get_element( &mut self, id: Self::Id, ) -> impl Future, Self::GetError>> + Send { - let mut v = vec![T::default(); self.provider.dim]; + let v_len = if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + quantizer.canonical_bytes() + } else { + self.provider.dim * mem::size_of::() + }; + + let mut v = match Poly::broadcast(0u8, v_len, AlignToEight) { + Ok(v) => DynVector::new(v), + Err(e) => return future::ready(Err(GarnetProviderError::AllocFailed(e))), 
+ }; if id == 0 { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r + if self.quantized { + let guard = if let Some(r) = self.provider.start_point_quant_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetError::Read.into())); + }; + v.copy_from_slice(&guard); + return future::ready(Ok(v)); } else { - return future::ready(Err(GarnetError::Read.into())); - }; - v.copy_from_slice(&guard); - return future::ready(Ok(v)); + let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetError::Read.into())); + }; + v.copy_from_slice(&guard); + return future::ready(Ok(v)); + } } - if !self - .provider - .callbacks - .read_single_iid(self.context.term(Term::Vector), id, &mut v) - { + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + + if !self.provider.callbacks.read_single_iid(ctx, id, &mut v) { return future::ready(Err(GarnetError::Read.into())); } @@ -555,29 +948,102 @@ impl Accessor for FullAccessor<'_, T> { } } -impl BuildDistanceComputer for FullAccessor<'_, T> { - type DistanceComputer = T::Distance; +/// Wrapper for full precision distance computer. +pub struct FullPrecisionDistance(T::Distance); + +impl DynDistanceComputer for FullPrecisionDistance { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + self.0.evaluate_similarity( + bytemuck::cast_slice::(a), + bytemuck::cast_slice::(b), + ) + } +} + +/// Wrapper for full precision query computer. +pub struct FullPrecisionQueryDistance(T::QueryDistance); + +impl DynQueryComputer for FullPrecisionQueryDistance { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + self.0.evaluate_similarity(bytemuck::cast_slice::(a)) + } +} + +/// Type-erased distance computer. 
+pub struct GarnetDistanceComputer { + inner: Box, +} + +impl GarnetDistanceComputer { + pub fn new(computer: T) -> Self { + Self { + inner: Box::new(computer), + } + } +} +impl DistanceFunction<&[u8], &[u8]> for GarnetDistanceComputer { + fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { + self.inner.evaluate_similarity(x, y) + } +} + +/// Type-erased query computer. +pub struct GarnetQueryComputer { + inner: Box, +} + +impl GarnetQueryComputer { + pub fn new(computer: T) -> Self { + Self { + inner: Box::new(computer), + } + } +} + +impl PreprocessedDistanceFunction<&[u8]> for GarnetQueryComputer { + fn evaluate_similarity(&self, changing: &[u8]) -> f32 { + self.inner.evaluate_similarity(changing) + } +} + +impl BuildDistanceComputer for DynamicAccessor<'_, T> { + type DistanceComputer = GarnetDistanceComputer; type DistanceComputerError = GarnetProviderError; fn build_distance_computer( &self, ) -> Result { - Ok(T::distance( - self.provider.metric_type, - Some(self.provider.dim), - )) + if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + Ok(quantizer.distance_computer()?) + } else { + Ok(GarnetDistanceComputer::new(FullPrecisionDistance::( + T::distance(self.provider.metric_type, Some(self.provider.dim)), + ))) + } } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { - type QueryComputer = T::QueryDistance; +impl BuildQueryComputer<&[T]> for DynamicAccessor<'_, T> { + type QueryComputer = GarnetQueryComputer; type QueryComputerError = GarnetProviderError; fn build_query_computer( &self, from: &[T], ) -> Result { - Ok(T::query_distance(from, self.provider.metric_type)) + if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + Ok(quantizer + .query_computer(bytemuck::cast_slice::(from)) + .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?) 
+ } else { + Ok(GarnetQueryComputer::new(FullPrecisionQueryDistance::( + T::query_distance(from, self.provider.metric_type), + ))) + } } } @@ -594,61 +1060,106 @@ impl<'a, T> Reborrow<'a> for Escape { } } -type WorkingSet = workingset::Map>; -type WorkingSetView<'a, T> = workingset::map::View<'a, u32, Escape>; +pub struct WorkingSet { + map: workingset::Map>, + contains_unquantized: bool, +} + +impl WorkingSet { + pub fn new(capacity_type: workingset::map::Capacity, capacity: usize) -> Self { + Self { + map: workingset::map::Builder::new(capacity_type).build(capacity), + contains_unquantized: false, + } + } +} + +impl Deref for WorkingSet { + type Target = workingset::Map>; + + fn deref(&self) -> &Self::Target { + &self.map + } +} + +impl DerefMut for WorkingSet { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.map + } +} + +type WorkingSetView<'a> = workingset::map::View<'a, u32, Escape>; -impl workingset::Fill> for FullAccessor<'_, T> { +impl workingset::Fill for DynamicAccessor<'_, T> { type Error = GarnetProviderError; type View<'a> - = WorkingSetView<'a, T> + = WorkingSetView<'a> where Self: 'a; async fn fill<'a, Itr>( &'a mut self, - set: &'a mut WorkingSet, + set: &'a mut WorkingSet, itr: Itr, ) -> Result, Self::Error> where Itr: ExactSizeIterator + Clone + Send + Sync, Self: 'a, { + if !self.quantized { + // Mark this working set as having full vectors if we're not quantizing yet + set.contains_unquantized = true; + } else if set.contains_unquantized { + // Working set is polluted by full vectors, it must be cleared + set.clear(); + set.contains_unquantized = false; + } + // Evict items from the working set to make room if needed. 
set.prepare(itr.clone()); self.filtered_ids.clear(); for id in itr { if id == 0 { - if let Entry::Vacant(e) = set.entry(id) { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r - } else { - return Err(GarnetError::Read.into()); - }; - e.insert(Escape((&**guard).into())); - } + if self.quantized + && let Entry::Vacant(e) = set.entry(id) + { + if let Some(guard) = self.provider.start_point_quant_cache.get(&id) { + e.insert(Escape((&**guard).into())); + } + } else if let Entry::Vacant(e) = set.entry(id) { + if let Some(guard) = self.provider.start_point_cache.get(&id) { + e.insert(Escape((&**guard).into())); + } + } else { + continue; + }; } else if !set.contains_key(&id) { self.filtered_ids.push(4); self.filtered_ids.push(id); } } + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + if !self.filtered_ids.is_empty() { - self.provider.callbacks.read_multi_lpiid( - self.context.term(Term::Vector), - &self.filtered_ids, - |id, v| { + self.provider + .callbacks + .read_multi_lpiid(ctx, &self.filtered_ids, |id, v| { set.insert(self.filtered_ids[id as usize * 2 + 1], Escape(v.into())); - }, - ); + }); } Ok(set.view()) } } -pub struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut FullAccessor<'p, T>); +pub struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut DynamicAccessor<'p, T>); impl HasId for DelegateNeighborAccessor<'_, '_, T> { type Id = u32; @@ -668,7 +1179,7 @@ impl NeighborAccessor for DelegateNeighborAccessor<'_, '_, T> { } } -impl<'p, 'a, T: VectorRepr> DelegateNeighbor<'a> for FullAccessor<'p, T> { +impl<'p, 'a, T: VectorRepr> DelegateNeighbor<'a> for DynamicAccessor<'p, T> { type Delegate = DelegateNeighborAccessor<'p, 'a, T>; fn delegate_neighbor(&'a mut self) -> Self::Delegate { @@ -754,21 +1265,21 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> #[derive(Debug, Default, Clone, Copy)] pub struct CopyExternalIds; -impl<'a, 'b, 
T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> +impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> for CopyExternalIds { type Error = GarnetProviderError; fn post_process( &self, - accessor: &mut FullAccessor<'a, T>, + accessor: &mut DynamicAccessor<'a, T>, _query: &[T], - _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send where - I: Iterator as HasId>::Id>> + Send, + I: Iterator as HasId>::Id>> + Send, B: SearchOutputBuffer + Send + ?Sized, { let initial = output.current_len(); @@ -782,41 +1293,124 @@ impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], Garn break; } } + let count = output.current_len() - initial; future::ready(Ok(count)) } } -impl SearchStrategy, &[T]> for FullPrecision { - type SearchAccessor<'a> = FullAccessor<'a, T>; +/// A [`SearchPostProcess`] base object that reranks quantized vectors by full precision distance. 
+#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; + +impl<'a, 'b, T: VectorRepr> SearchPostProcessStep, &'b [T], GarnetId> + for Rerank +{ + type Error + = GarnetProviderError + where + NextError: diskann::error::StandardError; + + type NextAccessor = DynamicAccessor<'a, T>; + + async fn post_process_step( + &self, + next: &Next, + accessor: &mut DynamicAccessor<'a, T>, + query: &'b [T], + computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + candidates: I, + output: &mut B, + ) -> Result> + where + I: Iterator as HasId>::Id>> + Send, + B: SearchOutputBuffer + Send + ?Sized, + Next: SearchPostProcess + Sync, + { + if !accessor.quantized { + // Skip reranking if the accessor is working with full precision + return next + .post_process(accessor, query, computer, candidates, output) + .await + .map_err(|e| GarnetProviderError::PostProcessing(Box::new(e))); + } + + let provider = &accessor.provider; + let f = T::distance(provider.metric_type, Some(provider.dim)); + let mut v = Poly::broadcast(0u8, provider.dim * mem::size_of::(), AlignToEight)?; + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if !provider.vector_iid_exists(accessor.context, n.id) { + None + } else if provider.callbacks.read_single_iid( + accessor.context.term(Term::Vector), + n.id, + &mut v, + ) { + Some(( + n.id, + f.evaluate_similarity(query, bytemuck::cast_slice::(&v)), + )) + } else { + None + } + }) + .collect(); + + // Sort the full precision distances. 
+ reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + + next.post_process( + accessor, + query, + computer, + reranked.into_iter().map(|(id, d)| Neighbor::new(id, d)), + output, + ) + .await + .map_err(|e| GarnetProviderError::PostProcessing(Box::new(e))) + } +} + +impl SearchStrategy, &[T]> for DynamicQuantization { + type SearchAccessor<'a> = DynamicAccessor<'a, T>; type SearchAccessorError = GarnetProviderError; - type QueryComputer = T::QueryDistance; + type QueryComputer = GarnetQueryComputer; fn search_accessor<'a>( &'a self, provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, true)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } } -impl DefaultPostProcessor, &[T], GarnetId> for FullPrecision { - default_post_processor!(glue::Pipeline); +impl DefaultPostProcessor, &[T], GarnetId> + for DynamicQuantization +{ + default_post_processor!( + glue::Pipeline> + ); } -impl PruneStrategy> for FullPrecision { - type PruneAccessor<'a> = FullAccessor<'a, T>; +impl PruneStrategy> for DynamicQuantization { + type PruneAccessor<'a> = DynamicAccessor<'a, T>; type PruneAccessorError = GarnetProviderError; - type DistanceComputer<'a> = T::Distance; - type WorkingSet = WorkingSet; + type DistanceComputer<'a> = GarnetDistanceComputer; + type WorkingSet = WorkingSet; fn prune_accessor<'a>( &'a self, provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider, context, false)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { @@ -824,11 +1418,11 @@ impl PruneStrategy> for FullPrecision { // cache and persist up to `capacity` items across uses of the working set. 
// // This reuse is limited to a single collection of backedges for an insert or multi-insert. - workingset::map::Builder::new(workingset::map::Capacity::Default).build(capacity) + WorkingSet::new(workingset::map::Capacity::Default, capacity) } } -impl InsertStrategy, &[T]> for FullPrecision { +impl InsertStrategy, &[T]> for DynamicQuantization { type PruneStrategy = Self; fn insert_search_accessor<'a>( @@ -836,7 +1430,8 @@ impl InsertStrategy, &[T]> for FullPrecision { provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, false)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } fn prune_strategy(&self) -> Self::PruneStrategy { @@ -844,13 +1439,13 @@ impl InsertStrategy, &[T]> for FullPrecision { } } -impl InplaceDeleteStrategy> for FullPrecision { +impl InplaceDeleteStrategy> for DynamicQuantization { type DeleteElement<'a> = &'a [T]; type DeleteElementGuard = Box<[T]>; type DeleteElementError = GarnetProviderError; type PruneStrategy = Self; - type DeleteSearchAccessor<'a> = FullAccessor<'a, T>; + type DeleteSearchAccessor<'a> = DynamicAccessor<'a, T>; type SearchPostProcessor = glue::CopyIds; type SearchStrategy = Self; diff --git a/diskann-garnet/src/quantization.rs b/diskann-garnet/src/quantization.rs new file mode 100644 index 000000000..265bda056 --- /dev/null +++ b/diskann-garnet/src/quantization.rs @@ -0,0 +1,317 @@ +use std::{num::NonZero, sync::RwLock}; + +use diskann::utils::VectorRepr; +use diskann_quantization::{ + CompressInto, + algorithms::{Transform, TransformKind, transforms::NewTransformError}, + alloc::{GlobalAllocator, ScopedAllocator}, + minmax, + num::Positive, + spherical::{ + self, Data, PreScale, SphericalQuantizer, SupportedMetric, + iface::{self, Opaque, OpaqueMut, Quantizer}, + }, +}; +use diskann_utils::views::MatrixView; +use diskann_vector::{DistanceFunction, 
PreprocessedDistanceFunction, distance::Metric}; +use thiserror::Error; + +use crate::provider::{GarnetDistanceComputer, GarnetQueryComputer}; + +#[derive(Debug, Error)] +pub enum GarnetQuantizerError { + #[error("Quantization training error: {0}")] + Training(Box), + #[error("Quantization alloc error: {0}")] + Alloc(Box), + #[error("Query computer error: {0}")] + QueryComputer(Box), + #[error("Binary quantization error: {0}")] + Compression(Box), + #[error("No quantizer found")] + NoQuantizer, + #[error("Got zero dimension")] + ZeroDim, + #[error("Transform error: {0}")] + BadTransform(#[from] NewTransformError), +} + +/// Quantizer trait that all diskann-garnet quantizers must implement +pub trait GarnetQuantizer: Send + Sync { + /// Check whether the quantizer is ready to be used + fn is_prepared(&self) -> bool; + /// Returns the number of vectors needed before the quantizer can be trained + fn required_vectors(&self) -> usize; + /// Returns the size of a quantized vector + fn canonical_bytes(&self) -> usize; + /// Train the quantizer. + /// Each row of the matrix will be a vector + fn train(&self, metric: Metric, data: MatrixView) -> Result<(), GarnetQuantizerError>; + /// Quantize a vector + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError>; + /// Returns a distance computer for comparing quantized vectors + fn distance_computer(&self) -> Result; + /// Returns a query computer for comparing distances to a particular query + fn query_computer(&self, query: &[f32]) -> Result; +} + +/// Type-erased distance computer +pub trait DynDistanceComputer: Send + Sync { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32; +} + +/// Type-erased query computer +pub trait DynQueryComputer: Send + Sync { + fn evaluate_similarity(&self, a: &[u8]) -> f32; +} + +/// Spherical 1-bit quantization. +/// +/// This quantizer corresponds to `BIN` quantizer in the Redis protocol. 
It requires hundreds of +/// vectors (but not thousands) for training. Quantized vectors have 1 bit per dimension plus up +/// to 6 bytes of overhead. +pub struct Spherical1Bit { + dim: usize, + inner: RwLock>>, +} + +impl Spherical1Bit { + pub fn new(dim: usize) -> Self { + Self { + dim, + inner: RwLock::new(None), + } + } +} + +impl GarnetQuantizer for Spherical1Bit { + fn is_prepared(&self) -> bool { + self.inner.read().unwrap().is_some() + } + + fn required_vectors(&self) -> usize { + 1000 + } + + fn canonical_bytes(&self) -> usize { + Data::<1, GlobalAllocator>::canonical_bytes(self.dim) + } + + fn train( + &self, + metric_type: Metric, + data: MatrixView, + ) -> Result<(), GarnetQuantizerError> { + let mut rng = rand::rng(); + let quantizer = SphericalQuantizer::train( + data.as_view(), + TransformKind::DoubleHadamard { + target_dim: diskann_quantization::algorithms::transforms::TargetDim::Same, + }, + SupportedMetric::try_from(metric_type) + .map_err(|e| GarnetQuantizerError::Training(Box::new(e)))?, + PreScale::ReciprocalMeanNorm, + &mut rng, + GlobalAllocator, + ) + .map_err(|e| GarnetQuantizerError::Training(Box::new(e)))?; + + let mut inner = self.inner.write().unwrap(); + *inner = Some( + spherical::iface::Impl::<1>::new(quantizer) + .map_err(|e| GarnetQuantizerError::Alloc(Box::new(e)))?, + ); + + Ok(()) + } + + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError> { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + spherical::iface::Quantizer::::compress( + quantizer, + v, + OpaqueMut::new(into), + ScopedAllocator::global(), + ) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + Ok(()) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } + + fn distance_computer(&self) -> Result { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + let computer = quantizer + .distance_computer(GlobalAllocator) + .map_err(|e| 
GarnetQuantizerError::Alloc(Box::new(e)))?; + Ok(GarnetDistanceComputer::new(computer)) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } + + fn query_computer(&self, query: &[f32]) -> Result { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + let computer = quantizer + .fused_query_computer( + query, + iface::QueryLayout::FullPrecision, + true, + GlobalAllocator, + ScopedAllocator::global(), + ) + .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?; + Ok(GarnetQueryComputer::new(computer)) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } +} + +impl DynDistanceComputer for iface::DistanceComputer { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + , Opaque<'_>, _>>::evaluate_similarity( + self, + Opaque::new(a), + Opaque::new(b), + ) + .unwrap() + } +} + +impl DynQueryComputer for iface::QueryComputer { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + , _>>::evaluate_similarity( + self, + Opaque::new(a), + ) + .unwrap() + } +} + +/// 8-bit scalar quantizer using MinMax +/// +/// This quantizer requires no training at all and is usable immediately on the first first. Each +/// quantized vector has 8 bits per dimension and 20 bytes of overhead. 
+pub struct MinMax8Bit { + dim: usize, + metric: Metric, + inner: minmax::MinMaxQuantizer, +} + +impl MinMax8Bit { + pub fn new(dim: usize, metric: Metric) -> Result { + let dim = match NonZero::new(dim) { + Some(d) => d, + None => return Err(GarnetQuantizerError::ZeroDim), + }; + let mut rng = rand::rng(); + let transform = Transform::new( + TransformKind::DoubleHadamard { + target_dim: diskann_quantization::algorithms::transforms::TargetDim::Same, + }, + dim, + Some(&mut rng), + GlobalAllocator, + )?; + let grid_scale = Positive::new(1.0).unwrap(); + + Ok(Self { + dim: dim.get(), + metric, + inner: minmax::MinMaxQuantizer::new(transform, grid_scale), + }) + } +} + +impl GarnetQuantizer for MinMax8Bit { + fn is_prepared(&self) -> bool { + true + } + + fn required_vectors(&self) -> usize { + 0 + } + + fn canonical_bytes(&self) -> usize { + minmax::Data::<8>::canonical_bytes(self.dim) + } + + fn train(&self, _metric: Metric, _data: MatrixView) -> Result<(), GarnetQuantizerError> { + Ok(()) + } + + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError> { + let into = minmax::DataMutRef::<8>::from_canonical_front_mut(into, self.dim) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + self.inner + .compress_into(v, into) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + Ok(()) + } + + fn distance_computer(&self) -> Result { + let computer = GarnetDistanceComputer::new( + ::distance( + self.metric, + Some(self.dim), + ), + ); + Ok(computer) + } + + fn query_computer(&self, query: &[f32]) -> Result { + let computer = GarnetQueryComputer::new(MinMax8BitQueryComputer::new( + &self.inner, + query, + self.dim, + self.metric, + )?); + Ok(computer) + } +} + +impl DynDistanceComputer for diskann_providers::common::FnPtr { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + let a = diskann_providers::common::MinMax8::from_bytes(a); + let b = diskann_providers::common::MinMax8::from_bytes(b); + 
>::evaluate_similarity(self, a, b) + } +} +struct MinMax8BitQueryComputer( + diskann_providers::common::BufferedFnPtr, +); + +impl MinMax8BitQueryComputer { + fn new( + quantizer: &minmax::MinMaxQuantizer, + query: &[f32], + dim: usize, + metric: Metric, + ) -> Result { + let mut v = vec![Default::default(); minmax::Data::<8>::canonical_bytes(dim)]; + quantizer + .compress_into( + query, + minmax::DataMutRef::<8>::from_canonical_front_mut(&mut v, dim) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?, + ) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + let inner = diskann_providers::common::MinMax8::query_distance( + diskann_providers::common::MinMax8::from_bytes(&v), + metric, + ); + Ok(Self(inner)) + } +} + +impl DynQueryComputer for MinMax8BitQueryComputer { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + let a = diskann_providers::common::MinMax8::from_bytes(a); + self.0.evaluate_similarity(a) + } +} diff --git a/diskann-garnet/src/test_utils.rs b/diskann-garnet/src/test_utils.rs index 26b72f774..ba9caed10 100644 --- a/diskann-garnet/src/test_utils.rs +++ b/diskann-garnet/src/test_utils.rs @@ -131,6 +131,7 @@ mod tests { use std::collections::HashMap; use crate::{ + VectorQuantType, garnet::{Context, Term}, test_utils::Store, }; @@ -203,7 +204,14 @@ mod tests { // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; let max_degree = 32; - let provider = GarnetProvider::::new(dim, Metric::L2, max_degree, callbacks, ctx); + let provider = GarnetProvider::::new( + dim, + VectorQuantType::NoQuant, + Metric::L2, + max_degree, + callbacks, + ctx, + ); // Provider should be created successfully assert!(provider.is_ok()); @@ -215,7 +223,14 @@ mod tests { // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; let max_degree = 32; - let provider = GarnetProvider::::new(dim, Metric::L2, max_degree, callbacks, ctx); + let provider = GarnetProvider::::new( + dim, + VectorQuantType::NoQuant, + 
Metric::L2, + max_degree, + callbacks, + ctx, + ); // Provider should be created successfully assert!(provider.is_ok()); diff --git a/diskann-providers/src/common/minmax_repr.rs b/diskann-providers/src/common/minmax_repr.rs index 73c19b10d..19dc8ec96 100644 --- a/diskann-providers/src/common/minmax_repr.rs +++ b/diskann-providers/src/common/minmax_repr.rs @@ -117,6 +117,12 @@ impl num_traits::FromPrimitive for MinMaxElement { impl_from_primitive!(NBITS, from_u32, u32); } +impl MinMaxElement { + pub fn from_bytes(x: &[u8]) -> &[Self] { + bytemuck::must_cast_slice(x) + } +} + //////////////////////// /// DistanceProvider /// //////////////////////// diff --git a/diskann-providers/src/common/mod.rs b/diskann-providers/src/common/mod.rs index b5c4a42a4..ec39946f7 100644 --- a/diskann-providers/src/common/mod.rs +++ b/diskann-providers/src/common/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ mod minmax_repr; -pub use minmax_repr::{MinMax4, MinMax8, MinMaxElement}; +pub use minmax_repr::{BufferedFnPtr, FnPtr, MinMax4, MinMax8, MinMaxElement}; mod ignore_lock_poison; pub use ignore_lock_poison::IgnoreLockPoison; diff --git a/vectorset/src/loader.rs b/vectorset/src/loader.rs index f1e62444e..9e0f7a856 100644 --- a/vectorset/src/loader.rs +++ b/vectorset/src/loader.rs @@ -4,22 +4,14 @@ */ use anyhow::{Result, anyhow}; -use std::{ - collections::HashMap, - marker::PhantomData, - mem, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{collections::HashMap, marker::PhantomData, mem, path::Path, sync::Arc}; use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; const BATCH_SIZE: usize = 1024; /// Loads base or query vectors from a given path and allow iteration over them. 
-#[allow(dead_code)] pub struct DatasetLoader { file: Mutex<(File, usize)>, - path: PathBuf, num_vectors: usize, dim: usize, type_: PhantomData, @@ -36,7 +28,6 @@ impl DatasetLoader { Ok(Arc::new(Self { file: Mutex::new((file, 0)), - path, num_vectors, dim, type_: PhantomData, @@ -60,35 +51,33 @@ impl DatasetLoader { /// Returns (count, first_id) where `count` is the number of vectors loaded /// and `first_id` is the id of the first vector. pub async fn next(&self, buffer: &mut Vec) -> Result<(usize, usize)> { + let mut f = self.file.lock().await; + let mut count; let mut first_id; loop { - { - let mut f = self.file.lock().await; - - first_id = f.1; - if f.1 >= self.num_vectors { - buffer.clear(); - return Ok((0, first_id)); - } - - buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); + first_id = f.1; + if f.1 >= self.num_vectors { + buffer.clear(); + return Ok((0, first_id)); + } - let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); - while let bytes_read = f.0.read(buf).await? - && bytes_read > 0 - { - buf = &mut buf[bytes_read..]; - } + buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); - let elements_left = buf.len() / mem::size_of::(); - if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { - return Err(anyhow!("unexpected EOF")); - } + let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); + while let bytes_read = f.0.read(buf).await? 
+ && bytes_read > 0 + { + buf = &mut buf[bytes_read..]; + } - count = BATCH_SIZE - elements_left / self.dim; + let elements_left = buf.len() / mem::size_of::(); + if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { + return Err(anyhow!("unexpected EOF")); } + count = BATCH_SIZE - elements_left / self.dim; + if count == 0 { continue; } @@ -96,7 +85,6 @@ impl DatasetLoader { break; } - let mut f = self.file.lock().await; f.1 += count; Ok((count, first_id)) diff --git a/vectorset/src/main.rs b/vectorset/src/main.rs index 4ba864d56..8333e4f86 100644 --- a/vectorset/src/main.rs +++ b/vectorset/src/main.rs @@ -107,6 +107,14 @@ struct IngestArgs { #[arg(long)] no_header_with_dim: Option, + /// Quantizer + #[arg(long)] + quantizer: Option, + + /// Metric + #[arg(long)] + metric: Option, + /// Paths to base vectors base_path: PathBuf, } @@ -161,6 +169,50 @@ enum DataType { Float32, } +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum Quantizer { + None, + Bin, + Q8, +} + +impl ToRedisArgs for Quantizer { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + redis::RedisWrite, + { + let q = match self { + Quantizer::None => b"NOQUANT".as_slice(), + Quantizer::Bin => b"BIN".as_slice(), + Quantizer::Q8 => b"Q8".as_slice(), + }; + out.write_arg(q); + } +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum DistanceMetric { + L2, + Cosine, + CosineNormalized, + InnerProduct, +} + +impl ToRedisArgs for DistanceMetric { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + redis::RedisWrite, + { + let q = match self { + DistanceMetric::L2 => b"L2".as_slice(), + DistanceMetric::Cosine => b"COSINE".as_slice(), + DistanceMetric::CosineNormalized => b"COSINE_NORMALIZED".as_slice(), + DistanceMetric::InnerProduct => b"INNER_PRODUCT".as_slice(), + }; + out.write_arg(q); + } +} + struct VectorId(u32); impl ToRedisArgs for VectorId { @@ -362,6 +414,8 @@ async fn ingest( let degree = args.degree; let mut cred = 
cred.clone(); let data_type = opts.data_type; + let quantizer = args.quantizer; + let metric = args.metric; tasks.spawn(async move { let mut buf = vec![T::zeroed(); ds.batch_size() * ds.dim()]; @@ -387,6 +441,7 @@ async fn ingest( let element = VectorId((first_id + i) as u32); let buf_start = i * ds.dim(); let buf_end = buf_start + ds.dim(); + pipeline.cmd("VADD").arg(&vset); match data_type { @@ -407,10 +462,14 @@ async fn ingest( pipeline.arg(b"XPREQ8"); } DataType::Float32 => { - pipeline.arg(b"NOQUANT"); + pipeline.arg(quantizer); } } + if let Some(metric) = metric { + pipeline.arg(b"XDISTANCE_METRIC").arg(metric); + } + pipeline .arg(b"EF") .arg(l_build.to_string().as_bytes()) From 6f397cb33175d88a69cb3671059ef9364f2976b5 Mon Sep 17 00:00:00 2001 From: Jack Moffitt Date: Mon, 11 May 2026 15:17:20 -0500 Subject: [PATCH 4/5] address feedback --- diskann-garnet/Cargo.toml | 2 +- diskann-garnet/src/alloc.rs | 2 +- diskann-garnet/src/dyn_index.rs | 2 +- diskann-garnet/src/ffi_recall_tests.rs | 12 +- diskann-garnet/src/ffi_tests.rs | 140 ++++++++++--------- diskann-garnet/src/fsm.rs | 97 +++++++------- diskann-garnet/src/garnet.rs | 179 +++++++++++++++++-------- diskann-garnet/src/labels.rs | 6 +- diskann-garnet/src/lib.rs | 36 ++--- diskann-garnet/src/provider.rs | 156 +++++++++++---------- diskann-garnet/src/quantization.rs | 16 +-- diskann-garnet/src/test_utils.rs | 40 +++--- rfcs/00000-quantizer-bootstrap.md | 10 +- 13 files changed, 387 insertions(+), 311 deletions(-) diff --git a/diskann-garnet/Cargo.toml b/diskann-garnet/Cargo.toml index b81049947..35ee4f3d2 100644 --- a/diskann-garnet/Cargo.toml +++ b/diskann-garnet/Cargo.toml @@ -14,7 +14,7 @@ bytemuck.workspace = true crossbeam = "0.8.4" dashmap = { workspace = true, features = ["inline"] } diskann.workspace = true -diskann-quantization.workspace = true +diskann-quantization = { workspace = true, features = ["flatbuffers"] } diskann-providers.workspace = true diskann-utils.workspace = true 
diskann-vector.workspace = true diff --git a/diskann-garnet/src/alloc.rs b/diskann-garnet/src/alloc.rs index 16e3cfb3c..17db0775e 100644 --- a/diskann-garnet/src/alloc.rs +++ b/diskann-garnet/src/alloc.rs @@ -9,7 +9,7 @@ use std::ptr::NonNull; /// Custom allocator that over-aligns to 8 bytes. This is needed since Garnet will hand us byte slices for f32 data /// that may be unaligned, so we need an allocator to make owned, aligned byte containers. #[derive(Debug, Clone, Copy)] -pub struct AlignToEight; +pub(crate) struct AlignToEight; unsafe impl AllocatorCore for AlignToEight { #[inline] diff --git a/diskann-garnet/src/dyn_index.rs b/diskann-garnet/src/dyn_index.rs index 77153dabe..300201e68 100644 --- a/diskann-garnet/src/dyn_index.rs +++ b/diskann-garnet/src/dyn_index.rs @@ -22,7 +22,7 @@ use std::sync::Arc; /// Type-erased version of `DiskANNIndex`. /// All vector data is passed as untyped byte slices. -pub trait DynIndex: Send + Sync { +pub(crate) trait DynIndex: Send + Sync { fn insert(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()>; fn set_attributes(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()>; diff --git a/diskann-garnet/src/ffi_recall_tests.rs b/diskann-garnet/src/ffi_recall_tests.rs index fdfaac349..a70ed0ce5 100644 --- a/diskann-garnet/src/ffi_recall_tests.rs +++ b/diskann-garnet/src/ffi_recall_tests.rs @@ -25,7 +25,7 @@ mod tests { let vector_bytes: &[u8] = bytemuck::cast_slice(vector); unsafe { insert( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -34,7 +34,7 @@ mod tests { vector.len(), b"".as_ptr(), 0, - ) + ) > 0 } } @@ -156,14 +156,14 @@ mod tests { ) -> f64 { store.clear(); let callbacks = store.callbacks(); - let ctx = Context(0); + let ctx = Context::new(0); let reduce_dimensions = 0; let l_build = 100; let max_degree = 32; let index_ptr = unsafe { create_index( - ctx.0, + ctx.get(), dimensions, reduce_dimensions, VectorQuantType::NoQuant, @@ -204,7 +204,7 @@ mod tests { 
let count = unsafe { search_vector( - ctx.0, + ctx.get(), index_ptr, VectorValueType::FP32, query_bytes.as_ptr(), @@ -237,7 +237,7 @@ mod tests { total_expected += expected_ids.len(); } - unsafe { drop_index(ctx.0, index_ptr) }; + unsafe { drop_index(ctx.get(), index_ptr) }; total_matches as f64 / total_expected as f64 } diff --git a/diskann-garnet/src/ffi_tests.rs b/diskann-garnet/src/ffi_tests.rs index 6fc2022a4..efc5ce308 100644 --- a/diskann-garnet/src/ffi_tests.rs +++ b/diskann-garnet/src/ffi_tests.rs @@ -30,7 +30,7 @@ mod tests { store.clear(); let callbacks = store.callbacks(); - let ctx = Context(0); + let ctx = Context::new(0); let dim: u32 = 2; let reduce_dim = 0; @@ -40,7 +40,7 @@ mod tests { let index_ptr = unsafe { create_index( - ctx.0, + ctx.get(), dim, reduce_dim, quant_type, @@ -64,7 +64,7 @@ mod tests { assert!(!index_ptr.is_null()); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -105,7 +105,7 @@ mod tests { valid_metric ); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } } @@ -126,9 +126,9 @@ mod tests { let attributes_bytes = b"wololo"; let attributes_len = attributes_bytes.len(); - let result: bool = unsafe { + let result = unsafe { insert( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -140,25 +140,26 @@ mod tests { ) }; - assert!(result); + assert!(result > 0); // Confirm vector exists using FFI function - let exists = - unsafe { check_external_id_valid(ctx.0, index_ptr, id_bytes.as_ptr(), id_bytes.len()) }; + let exists = unsafe { + check_external_id_valid(ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len()) + }; assert!(exists); - let mut cardinality = unsafe { card(ctx.0, index_ptr) }; + let mut cardinality = unsafe { card(ctx.get(), index_ptr) }; assert_eq!(cardinality, 1); - let removed = unsafe { remove(ctx.0, index_ptr, id_bytes.as_ptr(), id_bytes.len()) }; + let removed = unsafe { remove(ctx.get(), index_ptr, id_bytes.as_ptr(), 
id_bytes.len()) }; assert!(removed); // Currently we're not tracking deletions for cardinality, so it should still be 1 - cardinality = unsafe { card(ctx.0, index_ptr) }; + cardinality = unsafe { card(ctx.get(), index_ptr) }; assert_eq!(cardinality, 1); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -175,7 +176,7 @@ mod tests { let attributes1 = b"wololo"; let set_attribute_result1 = unsafe { set_attribute( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -188,9 +189,9 @@ mod tests { let vector: [f32; 2] = [1.0, 2.0]; let vector_bytes = bytemuck::cast_slice(&vector); - let result: bool = unsafe { + let result = unsafe { insert( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -201,13 +202,13 @@ mod tests { attributes1.len(), ) }; - assert!(result); + assert!(result > 0); // Set attributes after insertion let attributes2 = b"new_attributes"; let set_attribute_result2 = unsafe { set_attribute( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -220,7 +221,7 @@ mod tests { // Set attributes to empty using null ptr let set_attribute_result3 = unsafe { set_attribute( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -233,7 +234,7 @@ mod tests { let empty_attribute = b""; let set_attribute_result4 = unsafe { set_attribute( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -244,7 +245,7 @@ mod tests { assert!(set_attribute_result4); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -269,14 +270,14 @@ mod tests { // Check external_id exists with EID 1 (should not exist initially) let exists1 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(!exists1, "EID 1 should not exist initially"); // Add vector with EID 1 let insert_result1 = unsafe { insert( - ctx.0, + 
ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len(), @@ -287,28 +288,29 @@ mod tests { attributes_bytes.len(), ) }; - assert!(insert_result1, "Insert with EID 1 should succeed"); + assert!(insert_result1 > 0, "Insert with EID 1 should succeed"); // Check external_id exists with EID 1 (should exist after insert) let exists2 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(exists2, "EID 1 should exist after insert"); // Remove vector with EID 1 - let removed = unsafe { remove(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; + let removed = + unsafe { remove(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(removed, "Remove with EID 1 should succeed"); // Check external_id exists with EID 1 (should not exist after removal) let exists3 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(!exists3, "EID 1 should not exist after removal"); // Add vector with EID 2 let insert_result2 = unsafe { insert( - ctx.0, + ctx.get(), index_ptr, eid2_bytes.as_ptr(), eid2_bytes.len(), @@ -319,22 +321,22 @@ mod tests { attributes_bytes.len(), ) }; - assert!(insert_result2, "Insert with EID 2 should succeed"); + assert!(insert_result2 > 0, "Insert with EID 2 should succeed"); // Check external_id exists with EID 2 (should exist after insert) let exists4 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid2_bytes.as_ptr(), eid2_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid2_bytes.as_ptr(), eid2_bytes.len()) }; assert!(exists4, "EID 2 should exist after insert"); // Check external_id exists with EID 1 (should still not exist) let exists5 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) + 
check_external_id_valid(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(!exists5, "EID 1 should still not exist"); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -352,33 +354,37 @@ mod tests { let id1_bytes = bytemuck::bytes_of(&id1); let id2_bytes = bytemuck::bytes_of(&id2); - assert!(unsafe { - insert( - ctx.0, - index_ptr, - id1_bytes.as_ptr(), - id1_bytes.len(), - VectorValueType::XB8, - v1.as_ptr(), - v1.len(), - b"".as_ptr(), - 0, - ) - }); + assert!( + unsafe { + insert( + ctx.get(), + index_ptr, + id1_bytes.as_ptr(), + id1_bytes.len(), + VectorValueType::XB8, + v1.as_ptr(), + v1.len(), + b"".as_ptr(), + 0, + ) + } > 0 + ); - assert!(unsafe { - insert( - ctx.0, - index_ptr, - id2_bytes.as_ptr(), - id2_bytes.len(), - VectorValueType::XB8, - v2.as_ptr(), - v2.len(), - b"".as_ptr(), - 0, - ) - }); + assert!( + unsafe { + insert( + ctx.get(), + index_ptr, + id2_bytes.as_ptr(), + id2_bytes.len(), + VectorValueType::XB8, + v2.as_ptr(), + v2.len(), + b"".as_ptr(), + 0, + ) + } > 0 + ); let qv = &[0u8, 0u8]; let mut output_id_buffer = vec![0u8; 2 * (mem::size_of::() + mem::size_of::())]; @@ -386,7 +392,7 @@ mod tests { let count = unsafe { search_vector( - ctx.0, + ctx.get(), index_ptr, VectorValueType::XB8, qv.as_ptr(), @@ -442,7 +448,7 @@ mod tests { } unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -457,7 +463,7 @@ mod tests { let vector_bytes = bytemuck::cast_slice(vector); unsafe { insert( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -466,7 +472,7 @@ mod tests { vector.len(), b"".as_ptr(), 0, - ) + ) > 0 } } @@ -489,7 +495,7 @@ mod tests { let count = unsafe { search_vector( - ctx.0, + ctx.get(), index_ptr, VectorValueType::FP32, query_bytes.as_ptr(), @@ -543,7 +549,7 @@ mod tests { // Closest to [1,0] should be id=10 (exact match) assert_eq!(ids[0], 10); - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -566,7 
+572,7 @@ mod tests { assert!(ids.len() >= 2, "should return at least 2 matching vectors"); assert_eq!(ids[0], 10, "closest should still be id=10"); - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -595,7 +601,7 @@ mod tests { "filtered vector EID 20 should be in results" ); - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -617,7 +623,7 @@ mod tests { assert_eq!(ids_null, ids_empty); assert_eq!(dists_null, dists_empty); - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } } diff --git a/diskann-garnet/src/fsm.rs b/diskann-garnet/src/fsm.rs index 6b6e56130..6837e5c46 100644 --- a/diskann-garnet/src/fsm.rs +++ b/diskann-garnet/src/fsm.rs @@ -39,14 +39,14 @@ const FAST_SIZE: usize = 1024; const FSM_KEY_PREFIX: u32 = u32::from_be_bytes(*b"_fsm"); #[derive(Debug, Error, PartialEq)] -pub enum FsmError { +pub(crate) enum FsmError { #[error("Garnet operation failed")] Garnet(#[from] GarnetError), #[error("requested ID is out of range {0}")] IdOutOfRange(u32), } -pub struct FreeSpaceMap { +pub(crate) struct FreeSpaceMap { /// Garnet callbacks for reading/writing FSM keys callbacks: Callbacks, /// A flag to signal whether there are free IDs in the FSM. @@ -67,7 +67,7 @@ pub struct FreeSpaceMap { } impl FreeSpaceMap { - pub fn new(ctx: Context, callbacks: Callbacks) -> Result { + pub(crate) fn new(ctx: &Context, callbacks: Callbacks) -> Result { let has_free_ids = AtomicBool::new(false); let fast_free_list = ArrayQueue::new(FAST_SIZE); let max_block = RwLock::new(u32::MAX); @@ -88,7 +88,7 @@ impl FreeSpaceMap { let block_key = Self::block_key(0); if this .callbacks - .exists_wid(ctx.term(Term::Metadata), block_key) + .exists_wid(&ctx.term(Term::Metadata), block_key) { this.load_state(ctx)?; } else { @@ -101,11 +101,11 @@ impl FreeSpaceMap { } /// Load all state from Garnet by scanning the FSM blocks. 
- fn load_state(&mut self, ctx: Context) -> Result<(), FsmError> { + fn load_state(&mut self, ctx: &Context) -> Result<(), FsmError> { let mut max_block_id = 0; while self .callbacks - .exists_wid(ctx.term(Term::Metadata), Self::block_key(max_block_id)) + .exists_wid(&ctx.term(Term::Metadata), Self::block_key(max_block_id)) { max_block_id += 1; } @@ -119,7 +119,7 @@ impl FreeSpaceMap { if !self .callbacks - .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + .read_single_wid(&ctx.term(Term::Metadata), block_key, &mut block) { break; } @@ -157,20 +157,20 @@ impl FreeSpaceMap { } /// Mark an ID as free. - pub fn mark_free(&self, ctx: Context, id: u32) -> Result<(), FsmError> { + pub(crate) fn mark_free(&self, ctx: &Context, id: u32) -> Result<(), FsmError> { // We don't care about the changed status on free. self.mark_id(ctx, id, false).map(|_| ()) } /// Mark an ID as occupied. - fn mark_used(&self, ctx: Context, id: u32) -> Result { + fn mark_used(&self, ctx: &Context, id: u32) -> Result { self.mark_id(ctx, id, true) } /// Mark an ID according to value (true = used, false = free). /// Side effects only happen if the value actually changed. The return value reflects whether /// data changed or not. 
- fn mark_id(&self, ctx: Context, id: u32, used: bool) -> Result { + fn mark_id(&self, ctx: &Context, id: u32, used: bool) -> Result { if id >= self.next_id.load(Ordering::Acquire) { return Err(FsmError::IdOutOfRange(id)); } @@ -185,7 +185,7 @@ impl FreeSpaceMap { let mut changed = false; if !self.callbacks.rmw_wid( - ctx.term(Term::Metadata), + &ctx.term(Term::Metadata), block_key, BLOCK_SIZE_BYTES, |data: &mut [u8]| changed = update_status(used, &mut data[byte_idx], bit_idx), @@ -211,7 +211,7 @@ impl FreeSpaceMap { Ok(changed) } - pub fn is_free(&self, ctx: Context, id: u32) -> Result { + pub(crate) fn is_free(&self, ctx: &Context, id: u32) -> Result { if id >= self.next_id.load(Ordering::Acquire) { return Err(FsmError::IdOutOfRange(id)); } @@ -227,7 +227,7 @@ impl FreeSpaceMap { if !self .callbacks - .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + .read_single_wid(&ctx.term(Term::Metadata), block_key, &mut block) { return Err(FsmError::Garnet(GarnetError::Read)); } @@ -240,7 +240,7 @@ impl FreeSpaceMap { /// Return a a new ID. /// This may be a a fresh ID larger than all the others, or it may be a reused ID that /// previously belonged to a deleted element. The returned ID is marked as used. - pub fn next_id(&self, ctx: Context) -> Result { + pub(crate) fn next_id(&self, ctx: &Context) -> Result { if self.can_reuse() && self.has_free_ids.load(Ordering::Acquire) { // We retry reusing a freed ID until there are none or we get one and marking it used // succeeds in changing the value. @@ -293,11 +293,11 @@ impl FreeSpaceMap { Ok(id) } - pub fn max_id(&self) -> u32 { + pub(crate) fn max_id(&self) -> u32 { self.next_id.load(Ordering::Acquire).saturating_sub(1) } - pub fn total_used(&self) -> usize { + pub(crate) fn total_used(&self) -> usize { self.total_used.load(Ordering::Acquire) } @@ -320,7 +320,7 @@ impl FreeSpaceMap { } /// Scan the FSM blocks to fill up fast_free_list. 
- fn refill_fast_free_list(&self, ctx: Context) -> Result { + fn refill_fast_free_list(&self, ctx: &Context) -> Result { // NOTE: We take a write lock to prevent multiple refills happening simultaneously. #[allow(clippy::readonly_write_lock)] let max_block = self.max_block.write().unwrap(); @@ -337,7 +337,7 @@ impl FreeSpaceMap { let block_key = Self::block_key(block_id); if !self .callbacks - .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + .read_single_wid(&ctx.term(Term::Metadata), block_key, &mut block) { return Err(FsmError::Garnet(GarnetError::Read)); } @@ -369,7 +369,7 @@ impl FreeSpaceMap { } /// Ensure enough blocks exist in the FSM to hold `id`. - fn expand_to(&self, ctx: Context, id: u32) -> Result<(), FsmError> { + fn expand_to(&self, ctx: &Context, id: u32) -> Result<(), FsmError> { let (block_id, _, _) = self.indexes_for_id(id); let mut max_block = self.max_block.write().unwrap(); if *max_block == u32::MAX || block_id == *max_block + 1 { @@ -378,7 +378,7 @@ impl FreeSpaceMap { if !self .callbacks - .write_wid(ctx.term(Term::Metadata), block_key, &block_bytes) + .write_wid(&ctx.term(Term::Metadata), block_key, &block_bytes) { return Err(FsmError::Garnet(GarnetError::Write)); } @@ -396,7 +396,7 @@ impl FreeSpaceMap { } /// Visit each used id in the FSM, invoking f on each id. - pub fn visit_used(&self, ctx: Context, mut f: F) -> Result<(), FsmError> + pub(crate) fn visit_used(&self, ctx: &Context, mut f: F) -> Result<(), FsmError> where F: FnMut(u32) -> bool, { @@ -408,7 +408,7 @@ impl FreeSpaceMap { let block_key = Self::block_key(block_id); if !self .callbacks - .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + .read_single_wid(&ctx.term(Term::Metadata), block_key, &mut block) { return Err(FsmError::Garnet(GarnetError::Read)); } @@ -442,14 +442,14 @@ impl FreeSpaceMap { /// Prevent the reuse of previously deleted IDs. 
/// Each call to this increments a counter, and only once the counter is back to zero /// will reuse be allowed again. - pub fn lock_reuse(&self) { + pub(crate) fn lock_reuse(&self) { self.reuse_lock.fetch_add(1, Ordering::AcqRel); } /// Resume reuse of previously deleted IDs. /// Each call to this decrements a counter, and only once the counter is back to zero /// will reuse be allowed. This returns whether reuse was actually enabled. - pub fn unlock_reuse(&self) -> bool { + pub(crate) fn unlock_reuse(&self) -> bool { let prev = self.reuse_lock.fetch_sub(1, Ordering::AcqRel); debug_assert_ne!(prev, 0); prev == 1 @@ -487,7 +487,10 @@ mod tests { { let key = FreeSpaceMap::block_key(block_id); let block = store - .get(Context(0).term(Term::Metadata).0, bytemuck::bytes_of(&key)) + .get( + Context::new(0).term(Term::Metadata).get(), + bytemuck::bytes_of(&key), + ) .unwrap(); f(&block); } @@ -497,7 +500,7 @@ mod tests { let store = Store; // A fresh FSM should result in full block of all zeroes. - let _fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let _fsm = FreeSpaceMap::new(&Context::new(0), store.callbacks()).unwrap(); verify_block(&store, 0, |block| { assert_eq!(block.len(), BLOCK_SIZE_BYTES); assert_eq!(block.iter().filter(|b| **b != 0).count(), 0); @@ -508,29 +511,29 @@ mod tests { fn basic_next_id() { let store = Store; - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&Context::new(0), store.callbacks()).unwrap(); // A fresh FSM should not have free IDs. assert!(!fsm.has_free_ids.load(std::sync::atomic::Ordering::Acquire)); // It should return sequentials IDs starting from 0. 
- assert_eq!(fsm.next_id(Context(0)).unwrap(), 0); - assert_eq!(fsm.next_id(Context(0)).unwrap(), 1); + assert_eq!(fsm.next_id(&Context::new(0)).unwrap(), 0); + assert_eq!(fsm.next_id(&Context::new(0)).unwrap(), 1); } #[test] fn basic_delete() { let store = Store; - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&Context::new(0), store.callbacks()).unwrap(); // Deleting a ID out of range should return an error. - let res = fsm.mark_free(Context(0), 0); + let res = fsm.mark_free(&Context::new(0), 0); assert!(res.is_err()); assert_eq!(res.err().unwrap(), FsmError::IdOutOfRange(0)); for _ in 0u32..64 { - let _ = fsm.next_id(Context(0)).unwrap(); + let _ = fsm.next_id(&Context::new(0)).unwrap(); } verify_block(&store, 0, |block| { @@ -542,8 +545,8 @@ mod tests { } }); - fsm.mark_free(Context(0), 37).unwrap(); - fsm.mark_free(Context(0), 9).unwrap(); + fsm.mark_free(&Context::new(0), 37).unwrap(); + fsm.mark_free(&Context::new(0), 9).unwrap(); verify_block(&store, 0, |block| { assert_eq!(block[0], 0b11111111); @@ -566,20 +569,20 @@ mod tests { fn basic_id_reuse() { let store = Store; - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&Context::new(0), store.callbacks()).unwrap(); for _ in 0u32..64 { - let _ = fsm.next_id(Context(0)).unwrap(); + let _ = fsm.next_id(&Context::new(0)).unwrap(); } // Next ID should be 64 since none are free before that. - assert_eq!(fsm.next_id(Context(0)).unwrap(), 64); + assert_eq!(fsm.next_id(&Context::new(0)).unwrap(), 64); // After deleting an ID, it should be returned from `next_id()`. - fsm.mark_free(Context(0), 37).unwrap(); - assert_eq!(fsm.next_id(Context(0)).unwrap(), 37); + fsm.mark_free(&Context::new(0), 37).unwrap(); + assert_eq!(fsm.next_id(&Context::new(0)).unwrap(), 37); // Once all free IDs are used, fresh ones should be returned. 
- assert_eq!(fsm.next_id(Context(0)).unwrap(), 65); + assert_eq!(fsm.next_id(&Context::new(0)).unwrap(), 65); // And has_free_ids should now be false. assert!(!fsm.has_free_ids.load(Ordering::Acquire)); } @@ -588,32 +591,32 @@ mod tests { fn basic_recovery() { let store = Store; - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&Context::new(0), store.callbacks()).unwrap(); for _ in 0u32..64 { - let _ = fsm.next_id(Context(0)).unwrap(); + let _ = fsm.next_id(&Context::new(0)).unwrap(); } - fsm.mark_free(Context(0), 37).unwrap(); + fsm.mark_free(&Context::new(0), 37).unwrap(); // Loading FSM from store should recover all the state. - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&Context::new(0), store.callbacks()).unwrap(); assert_eq!(*fsm.max_block.read().unwrap(), 0); assert_eq!(fsm.next_id.load(Ordering::Acquire), 64); assert!(fsm.has_free_ids.load(Ordering::Acquire)); assert_eq!(fsm.fast_free_list.len(), 1); - assert_eq!(fsm.next_id(Context(0)).unwrap(), 37); + assert_eq!(fsm.next_id(&Context::new(0)).unwrap(), 37); } #[test] fn dynamic_expansion() { let store = Store; - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&Context::new(0), store.callbacks()).unwrap(); // Asking for more than BLOCK_SIZE_IDS will force another FSM block to be allocated. for _ in 0u32..BLOCK_SIZE_IDS as u32 + 1 { - let _ = fsm.next_id(Context(0)).unwrap(); + let _ = fsm.next_id(&Context::new(0)).unwrap(); } } } diff --git a/diskann-garnet/src/garnet.rs b/diskann-garnet/src/garnet.rs index 566460af1..526cee264 100644 --- a/diskann-garnet/src/garnet.rs +++ b/diskann-garnet/src/garnet.rs @@ -3,17 +3,26 @@ * Licensed under the MIT license. 
*/ -use std::{ffi::c_void, fmt, mem, ops::Deref, slice}; +use std::{ + ffi::c_void, + fmt, mem, + ops::Deref, + slice, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, +}; use diskann::provider::ExecutionContext; use thiserror::Error; /// Bitmask for extracting the Term bits from a Context. /// Must have enough bits to represent all Term variants (max value is 6, needs 3 bits). -pub const TERM_BITMASK: u64 = (1 << 3) - 1; +pub(crate) const TERM_BITMASK: u64 = (1 << 3) - 1; #[derive(Debug)] -pub enum Term { +pub(crate) enum Term { Vector = 0, Neighbors = 1, Quantized = 2, @@ -23,28 +32,62 @@ pub enum Term { ExtMap = 6, } -#[derive(Copy, Clone)] -pub struct Context(pub u64); +#[derive(Clone, Debug)] +pub(crate) struct Context { + inner: u64, + quantizer_ready: Arc, +} impl Context { - pub fn term(&self, kind: Term) -> Self { - Context(self.0 | (kind as u64 & TERM_BITMASK)) + pub(crate) fn new(inner: u64) -> Self { + Self { + inner, + quantizer_ready: Arc::new(AtomicBool::new(false)), + } + } + + #[cfg(test)] + pub(crate) fn get(&self) -> u64 { + self.inner + } + + pub(crate) fn term(&self, kind: Term) -> Self { + let Context { + inner, + quantizer_ready, + } = self; + let inner = *inner | (kind as u64 & TERM_BITMASK); + let quantizer_ready = quantizer_ready.clone(); + + Self { + inner, + quantizer_ready, + } + } + + pub(crate) fn quantizer_ready(&self) -> bool { + self.quantizer_ready.load(Ordering::Acquire) + } + + pub(crate) fn set_quantizer_ready(&self) { + self.quantizer_ready.store(true, Ordering::Release); } } impl ExecutionContext for Context {} -pub type ReadCallback = +pub(crate) type ReadCallback = unsafe extern "C" fn(u64, u32, *const u8, usize, ReadDataCallback, *mut c_void); -pub type WriteCallback = unsafe extern "C" fn(u64, *const u8, usize, *const u8, usize) -> bool; -pub type DeleteCallback = unsafe extern "C" fn(u64, *const u8, usize) -> bool; -pub type ReadModifyWriteCallback = +pub(crate) type WriteCallback = + unsafe extern "C" fn(u64, 
*const u8, usize, *const u8, usize) -> bool; +pub(crate) type DeleteCallback = unsafe extern "C" fn(u64, *const u8, usize) -> bool; +pub(crate) type ReadModifyWriteCallback = unsafe extern "C" fn(u64, *const u8, usize, usize, RmwDataCallback, *mut c_void) -> bool; -pub type ReadDataCallback = unsafe extern "C" fn(u32, *mut c_void, *const u8, usize); -pub type RmwDataCallback = unsafe extern "C" fn(*mut c_void, *mut u8, usize); +pub(crate) type ReadDataCallback = unsafe extern "C" fn(u32, *mut c_void, *const u8, usize); +pub(crate) type RmwDataCallback = unsafe extern "C" fn(*mut c_void, *mut u8, usize); #[derive(Copy, Clone)] -pub struct Callbacks { +pub(crate) struct Callbacks { read_callback: ReadCallback, write_callback: WriteCallback, delete_callback: DeleteCallback, @@ -52,7 +95,7 @@ pub struct Callbacks { } impl Callbacks { - pub fn new( + pub(crate) fn new( read_callback: ReadCallback, write_callback: WriteCallback, delete_callback: DeleteCallback, @@ -67,36 +110,33 @@ impl Callbacks { } #[cfg(test)] - pub fn read_callback(&self) -> ReadCallback { + pub(crate) fn read_callback(&self) -> ReadCallback { self.read_callback } #[cfg(test)] - pub fn write_callback(&self) -> WriteCallback { + pub(crate) fn write_callback(&self) -> WriteCallback { self.write_callback } #[cfg(test)] - pub fn delete_callback(&self) -> DeleteCallback { + pub(crate) fn delete_callback(&self) -> DeleteCallback { self.delete_callback } #[cfg(test)] - pub fn rmw_callback(&self) -> ReadModifyWriteCallback { + pub(crate) fn rmw_callback(&self) -> ReadModifyWriteCallback { self.rmw_callback } - #[expect( - dead_code, - reason = "currently unused, but may be needed in the future" - )] - pub fn exists_iid(&self, ctx: Context, id: u32) -> bool { + #[cfg(test)] + pub(crate) fn exists_iid(&self, ctx: &Context, id: u32) -> bool { let key = [4, id]; // SAFETY: Key bytes are preceded by 4 bytes of space. 
unsafe { self.exists_raw(ctx, bytemuck::bytes_of(&key)) } } - pub fn exists_wid(&self, ctx: Context, key: u64) -> bool { + pub(crate) fn exists_wid(&self, ctx: &Context, key: u64) -> bool { // NOTE: the length is bit-shifted so that we have a u32 in the lower half of the u64. let mut key = [8 << 32, key]; let key_bytes = bytemuck::bytes_of_mut(&mut key); @@ -108,7 +148,7 @@ impl Callbacks { dead_code, reason = "currently unused, but may be needed in the future" )] - pub fn exists_eid(&self, ctx: Context, id: &GarnetId) -> bool { + pub(crate) fn exists_eid(&self, ctx: &Context, id: &GarnetId) -> bool { // SAFETY: GarnetId ensures there are 4 bytes preceding the key bytes. unsafe { self.exists_raw(ctx, id) } } @@ -117,7 +157,7 @@ impl Callbacks { /// /// NOTE: The key bytes must be preceded by 4 valid bytes that Garnet can write into. /// This invariant must be checked by the caller. - unsafe fn exists_raw(&self, ctx: Context, key: &[u8]) -> bool { + unsafe fn exists_raw(&self, ctx: &Context, key: &[u8]) -> bool { let mut called = false; let mut cb = |_, _: &[u8]| { called = true; @@ -125,7 +165,7 @@ impl Callbacks { unsafe { (self.read_callback)( - ctx.0, + ctx.inner, 1, key.as_ptr(), key.len(), @@ -138,9 +178,9 @@ impl Callbacks { } #[must_use] - pub fn read_single_iid( + pub(crate) fn read_single_iid( &self, - ctx: Context, + ctx: &Context, id: u32, value: &mut [D], ) -> bool { @@ -156,9 +196,9 @@ impl Callbacks { } #[must_use] - pub fn read_single_wid( + pub(crate) fn read_single_wid( &self, - ctx: Context, + ctx: &Context, key: u64, value: &mut [D], ) -> bool { @@ -176,9 +216,9 @@ impl Callbacks { } #[must_use] - pub fn read_single_eid( + pub(crate) fn read_single_eid( &self, - ctx: Context, + ctx: &Context, id: &GarnetId, value: &mut [D], ) -> bool { @@ -197,7 +237,7 @@ impl Callbacks { /// NOTE: The key bytes must be preceded by 4 valid bytes that Garnet can write into. /// This invariant must be checked by the caller. 
#[must_use] - unsafe fn read_single_raw(&self, ctx: Context, key: &[u8], value: &mut [u8]) -> bool { + unsafe fn read_single_raw(&self, ctx: &Context, key: &[u8], value: &mut [u8]) -> bool { let mut found = false; let mut cb = |_, data: &[u8]| { found = true; @@ -206,7 +246,7 @@ impl Callbacks { unsafe { (self.read_callback)( - ctx.0, + ctx.inner, 1, key.as_ptr(), key.len(), @@ -219,8 +259,12 @@ impl Callbacks { } // ids must be passed as 4-byte length prefixed u32s. so [4, I1_u32, 4, I2_u32, ...] - pub fn read_multi_lpiid<'a, F, T: bytemuck::Pod>(&self, ctx: Context, ids: &[u32], mut f: F) - where + pub(crate) fn read_multi_lpiid<'a, F, T: bytemuck::Pod>( + &self, + ctx: &Context, + ids: &[u32], + mut f: F, + ) where F: FnMut(u32, &'a [T]), { if ids.is_empty() { @@ -229,7 +273,7 @@ impl Callbacks { unsafe { (self.read_callback)( - ctx.0, + ctx.inner, ids.len() as u32 / 2, bytemuck::must_cast_slice::<_, u8>(ids).as_ptr(), mem::size_of_val(ids), @@ -244,7 +288,11 @@ impl Callbacks { /// This function allocations inside the read callback since it can't know the size /// of the value up front. #[must_use] - pub fn read_varsize_iid(&self, ctx: Context, id: u32) -> Option> { + pub(crate) fn read_varsize_iid( + &self, + ctx: &Context, + id: u32, + ) -> Option> { let key = [4, id]; let mut result = None; let mut cb = |_, data: &[u8]| { @@ -262,7 +310,7 @@ impl Callbacks { // SAFETY: Key bytes are preceded by 4 bytes of extra space. unsafe { (self.read_callback)( - ctx.0, + ctx.inner, 1, bytemuck::bytes_of(&key).as_ptr(), mem::size_of_val(&key), @@ -275,7 +323,7 @@ impl Callbacks { } #[must_use] - pub fn write_iid(&self, ctx: Context, id: u32, value: &[D]) -> bool { + pub(crate) fn write_iid(&self, ctx: &Context, id: u32, value: &[D]) -> bool { let key = [0, id]; // SAFETY: Key bytes are preceded by 4 bytes of extra space. 
unsafe { @@ -288,7 +336,7 @@ impl Callbacks { } #[must_use] - pub fn write_wid(&self, ctx: Context, key: u64, value: &[D]) -> bool { + pub(crate) fn write_wid(&self, ctx: &Context, key: u64, value: &[D]) -> bool { let key = [0, key]; // SAFETY: Key bytes are preceded by 8 bytes of extra space. unsafe { @@ -301,7 +349,12 @@ impl Callbacks { } #[must_use] - pub fn write_eid(&self, ctx: Context, id: &GarnetId, value: &[D]) -> bool { + pub(crate) fn write_eid( + &self, + ctx: &Context, + id: &GarnetId, + value: &[D], + ) -> bool { // SAFETY: GarnetId ensures there are 4 bytes preceding the key bytes. unsafe { self.write_raw(ctx, id, bytemuck::must_cast_slice::(value)) } } @@ -311,22 +364,22 @@ impl Callbacks { /// NOTE: The key bytes must be preceded by 4 valid bytes that Garnet can write into. /// This invariant must be checked by the caller. #[must_use] - unsafe fn write_raw(&self, ctx: Context, key: &[u8], value: &[u8]) -> bool { + unsafe fn write_raw(&self, ctx: &Context, key: &[u8], value: &[u8]) -> bool { let value_ptr = value.as_ptr(); let value_len = value.len(); - unsafe { (self.write_callback)(ctx.0, key.as_ptr(), key.len(), value_ptr, value_len) } + unsafe { (self.write_callback)(ctx.inner, key.as_ptr(), key.len(), value_ptr, value_len) } } #[must_use] - pub fn delete_iid(&self, ctx: Context, id: u32) -> bool { + pub(crate) fn delete_iid(&self, ctx: &Context, id: u32) -> bool { let key = [0, id]; - unsafe { (self.delete_callback)(ctx.0, bytemuck::bytes_of(&key[1]).as_ptr(), 4) } + unsafe { (self.delete_callback)(ctx.inner, bytemuck::bytes_of(&key[1]).as_ptr(), 4) } } #[must_use] - pub fn delete_eid(&self, ctx: Context, id: &GarnetId) -> bool { + pub(crate) fn delete_eid(&self, ctx: &Context, id: &GarnetId) -> bool { let id: &[u8] = id; - unsafe { (self.delete_callback)(ctx.0, id.as_ptr(), id.len()) } + unsafe { (self.delete_callback)(ctx.inner, id.as_ptr(), id.len()) } } /// Modify a value in Garnet by internal ID. 
@@ -336,7 +389,13 @@ impl Callbacks { /// /// `f` should not panic. #[must_use] - pub fn rmw_iid<'a, F, T>(&self, ctx: Context, id: u32, write_len: usize, mut f: F) -> bool + pub(crate) fn rmw_iid<'a, F, T>( + &self, + ctx: &Context, + id: u32, + write_len: usize, + mut f: F, + ) -> bool where F: FnMut(&'a mut [T]), T: bytemuck::Pod, @@ -365,7 +424,13 @@ impl Callbacks { /// /// `f` should not panic. #[must_use] - pub fn rmw_wid<'a, F, T>(&self, ctx: Context, key: u64, write_len: usize, mut f: F) -> bool + pub(crate) fn rmw_wid<'a, F, T>( + &self, + ctx: &Context, + key: u64, + write_len: usize, + mut f: F, + ) -> bool where F: FnMut(&'a mut [T]), T: bytemuck::Pod, @@ -397,13 +462,13 @@ impl Callbacks { /// /// `f` should not panic. #[must_use] - unsafe fn rmw_raw<'a, F>(&self, ctx: Context, key: &[u8], write_len: usize, mut f: F) -> bool + unsafe fn rmw_raw<'a, F>(&self, ctx: &Context, key: &[u8], write_len: usize, mut f: F) -> bool where F: FnMut(&'a mut [u8]), { unsafe { (self.rmw_callback)( - ctx.0, + ctx.inner, key.as_ptr(), key.len(), write_len, @@ -467,7 +532,7 @@ where } #[derive(Debug, Error, PartialEq)] -pub enum GarnetError { +pub(crate) enum GarnetError { #[error("garnet read failed")] Read, #[error("garnet write failed")] @@ -485,12 +550,12 @@ pub enum GarnetError { /// Dereferencing this type will return a slice to the actual ID bytes, without the padding, /// which makes this interchangeable in most respects with using a raw `Box<[u8]>`. #[derive(Clone, PartialEq)] -pub struct GarnetId { +pub(crate) struct GarnetId { inner: Box<[u8]>, } impl GarnetId { - pub fn as_prefixed_key_bytes(&self) -> &[u8] { + pub(crate) fn as_prefixed_key_bytes(&self) -> &[u8] { &self.inner } } diff --git a/diskann-garnet/src/labels.rs b/diskann-garnet/src/labels.rs index f569cb30d..8e0a0b482 100644 --- a/diskann-garnet/src/labels.rs +++ b/diskann-garnet/src/labels.rs @@ -36,7 +36,7 @@ use diskann::graph::index::QueryLabelProvider; /// this struct. 
In practice, the struct is created and dropped within a /// single FFI search call. #[derive(Debug)] -pub struct GarnetQueryLabelProvider { +pub(crate) struct GarnetQueryLabelProvider { data: *const u8, len: usize, } @@ -68,7 +68,7 @@ impl GarnetQueryLabelProvider { /// The caller must ensure `data` is valid for reads of `len` bytes /// and that the memory remains valid and unmodified for the lifetime /// of this struct. - pub unsafe fn from_raw(data: *const u8, len: usize) -> Self { + pub(crate) unsafe fn from_raw(data: *const u8, len: usize) -> Self { if data.is_null() || len == 0 { return Self { data: std::ptr::null(), @@ -80,7 +80,7 @@ impl GarnetQueryLabelProvider { /// Construct a `GarnetQueryLabelProvider` from a byte slice. #[allow(dead_code)] - pub fn from_bytes(bytes: &[u8]) -> Self { + pub(crate) fn from_bytes(bytes: &[u8]) -> Self { if bytes.is_empty() { Self { data: std::ptr::null(), diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index c6b40e37f..73067ca0a 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -72,7 +72,7 @@ impl From for IndexState { } } -pub struct Index { +pub(crate) struct Index { inner: Box, quant_type: VectorQuantType, state: AtomicUsize, @@ -176,8 +176,14 @@ fn create_index_impl( callbacks: Callbacks, context: Context, ) -> Result, GarnetProviderError> { - let provider = - GarnetProvider::::new(dim, quant_type, metric_type, max_degree, callbacks, context)?; + let provider = GarnetProvider::::new( + dim, + quant_type, + metric_type, + max_degree, + callbacks, + &context, + )?; let state = if provider.start_points_exist() { AtomicUsize::new(IndexState::Ready as usize) } else { @@ -228,7 +234,7 @@ pub unsafe extern "C" fn create_index( return ptr::null(); }; - let context = Context(ctx); + let context = Context::new(ctx); let callbacks = Callbacks::new(read_callback, write_callback, delete_callback, rmw_callback); match quant_type { @@ -397,7 +403,7 @@ pub unsafe extern "C" fn insert( 
attribute_len: usize, ) -> u8 { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); @@ -429,11 +435,11 @@ pub unsafe extern "C" fn insert( return INSERT_FAIL; } - let old_ready = provider::QUANTIZER_READY.with(|v| v.load(Ordering::Acquire)); + let old_ready = ctx.quantizer_ready(); // Insert the vector if index.inner.insert(&ctx, &id, &v).is_ok() { - let ready = provider::QUANTIZER_READY.with(|v| v.load(Ordering::Acquire)); + let ready = ctx.quantizer_ready(); if !old_ready && ready { INSERT_SUCCESS_START_TRAINING } else { @@ -489,7 +495,7 @@ where #[unsafe(no_mangle)] pub unsafe extern "C" fn build_quant_table(context: u64, index_ptr: *const c_void) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(context); + let ctx = Context::new(context); index.inner.train_quantizer(&ctx) } @@ -505,7 +511,7 @@ pub unsafe extern "C" fn backfill_quant_vectors( task_count: usize, ) { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(context); + let ctx = Context::new(context); index .inner .backfill_quant_vectors(&ctx, task_index, task_count); @@ -524,7 +530,7 @@ pub unsafe extern "C" fn set_attribute( attribute_len: usize, ) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(context); + let ctx = Context::new(context); let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); @@ -577,7 +583,7 @@ pub unsafe extern "C" fn search_vector( return -1; }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let mut output = SearchResults::new( output_ids, @@ -642,7 +648,7 @@ pub unsafe extern "C" fn search_element( let index = unsafe { &*index_ptr.cast::() }; let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); - let ctx = Context(ctx); + let ctx = Context::new(ctx); let mut output 
= SearchResults::new( output_ids, @@ -710,7 +716,7 @@ pub unsafe extern "C" fn remove( id_len: usize, ) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); @@ -742,7 +748,7 @@ pub unsafe extern "C" fn check_internal_id_valid( internal_id_len: usize, ) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let internal_id_bytes = unsafe { slice::from_raw_parts(internal_id_data, internal_id_len) }; if internal_id_bytes.len() != mem::size_of::() { return false; @@ -765,7 +771,7 @@ pub unsafe extern "C" fn check_external_id_valid( id_len: usize, ) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 6f8750135..76c9a0636 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -49,13 +49,6 @@ use crate::{ }, }; -thread_local! { - /// Thread local flag to detect when we've reached the quantization threshold. This is needed - /// to return the correct status at the end of insert() since the provider code that becomes - /// aware of the state change cannot directly communicate it back to the caller. - pub static QUANTIZER_READY: AtomicBool = const { AtomicBool::new(false) }; -} - #[derive(Clone)] struct AdjList(AdjacencyList); @@ -85,7 +78,7 @@ impl AsPooled for AdjList { /// A type erased vector, properly aligned so that it can hold any size element /// (up to 8 bytes long). 
-pub struct DynVector { +pub(crate) struct DynVector { inner: Poly<[u8], AlignToEight>, ty: PhantomData, } @@ -122,7 +115,7 @@ impl<'a, T: VectorRepr> Reborrow<'a> for DynVector { } #[derive(Debug, Error)] -pub enum GarnetProviderError { +pub(crate) enum GarnetProviderError { #[error("Garnet operation failed")] Garnet(#[from] GarnetError), #[error("FSM error")] @@ -133,10 +126,6 @@ pub enum GarnetProviderError { AllocFailed(#[from] AllocatorError), #[error("Invalid quantizer for vector data")] InvalidQuantizer, - #[error("Expected quantizer is missing")] - MissingQuantizer, - #[error("Failed to gather quantizer training data")] - MissingTrainingData, #[error("Quantizer error: {0}")] Quantizer(#[from] GarnetQuantizerError), #[error("Post processing error: {0}")] @@ -153,7 +142,7 @@ impl From for ANNError { diskann::always_escalate!(GarnetProviderError); /// The Garnet DataProvider implementation. -pub struct GarnetProvider { +pub(crate) struct GarnetProvider { dim: usize, metric_type: Metric, max_degree: usize, @@ -175,13 +164,13 @@ pub struct GarnetProvider { } impl GarnetProvider { - pub fn new( + pub(crate) fn new( dim: usize, quant_type: VectorQuantType, metric_type: Metric, max_degree: usize, callbacks: Callbacks, - context: Context, + context: &Context, ) -> Result { let parallelism = std::thread::available_parallelism().unwrap().get() * 2; let id_buffer_pool = @@ -201,9 +190,9 @@ impl GarnetProvider { // Try to read the start point from Garnet let mut v = Poly::broadcast(0u8, dim * mem::size_of::(), AlignToEight)?; - if callbacks.read_single_iid(context.term(Term::Vector), 0, &mut v) { + if callbacks.read_single_iid(&context.term(Term::Vector), 0, &mut v) { let mut neighbors = vec![0u32; max_degree + 1]; - if !callbacks.read_single_iid(context.term(Term::Neighbors), 0, &mut neighbors) { + if !callbacks.read_single_iid(&context.term(Term::Neighbors), 0, &mut neighbors) { return Err(GarnetError::Read.into()); } @@ -263,7 +252,7 @@ impl GarnetProvider { }) } - 
pub fn maybe_set_start_point( + pub(crate) fn maybe_set_start_point( &self, context: &Context, point: &[T], @@ -271,13 +260,13 @@ impl GarnetProvider { let mut v = Poly::broadcast(0u8, self.dim * mem::size_of::(), AlignToEight)?; if self .callbacks - .read_single_iid(context.term(Term::Vector), 0, &mut v) + .read_single_iid(&context.term(Term::Vector), 0, &mut v) { // Garnet already has a start point, so use that instead of `point` let mut neighbors = vec![0u32; self.max_degree + 1]; if !self .callbacks - .read_single_iid(context.term(Term::Neighbors), 0, &mut neighbors) + .read_single_iid(&context.term(Term::Neighbors), 0, &mut neighbors) { return Err(GarnetError::Read.into()); } @@ -290,15 +279,15 @@ impl GarnetProvider { let neighbors = vec![0u32; self.max_degree + 1]; // Grab the start point id, which must be zero. - let id = self.fsm.next_id(*context)?; + let id = self.fsm.next_id(context)?; if id != 0 { - self.fsm.mark_free(*context, id)?; + self.fsm.mark_free(context, id)?; return Err(GarnetProviderError::StartPoint); } if !self .callbacks - .write_iid(context.term(Term::Vector), 0, point) + .write_iid(&context.term(Term::Vector), 0, point) { return Err(GarnetError::Write.into()); } @@ -313,7 +302,7 @@ impl GarnetProvider { if !self .callbacks - .write_iid(context.term(Term::Quantized), 0, &qpoint) + .write_iid(&context.term(Term::Quantized), 0, &qpoint) { return Err(GarnetError::Write.into()); } @@ -324,7 +313,7 @@ impl GarnetProvider { if !self .callbacks - .write_iid(context.term(Term::Neighbors), 0, &neighbors) + .write_iid(&context.term(Term::Neighbors), 0, &neighbors) { return Err(GarnetError::Write.into()); } @@ -343,11 +332,11 @@ impl GarnetProvider { Ok(()) } - pub fn start_points_exist(&self) -> bool { + pub(crate) fn start_points_exist(&self) -> bool { self.start_point_cache.get(&0).is_some() && self.neighbor_cache.get(&0).is_some() } - pub fn set_attributes( + pub(crate) fn set_attributes( &self, context: &Context, id: &GarnetId, @@ -355,26 
+344,26 @@ impl GarnetProvider { ) -> Result<(), GarnetProviderError> { if self .callbacks - .write_eid(context.term(Term::Attributes), id, data) + .write_eid(&context.term(Term::Attributes), id, data) { Ok(()) } else { Err(GarnetError::Write.into()) } } - pub fn vector_id_exists(&self, context: &Context, id: &GarnetId) -> bool { + pub(crate) fn vector_id_exists(&self, context: &Context, id: &GarnetId) -> bool { let iid = match self.to_internal_id(context, id) { Ok(iid) => iid, Err(_) => return false, }; - !self.fsm.is_free(*context, iid).unwrap_or(true) + !self.fsm.is_free(context, iid).unwrap_or(true) } - pub fn vector_iid_exists(&self, context: &Context, id: u32) -> bool { - !self.fsm.is_free(*context, id).unwrap_or(true) + pub(crate) fn vector_iid_exists(&self, context: &Context, id: u32) -> bool { + !self.fsm.is_free(context, id).unwrap_or(true) } - pub fn max_internal_id(&self) -> u32 { + pub(crate) fn max_internal_id(&self) -> u32 { self.fsm.max_id() } @@ -382,7 +371,7 @@ impl GarnetProvider { /// /// This should only be called when at least `quantizer.required_vectors()` vectors exist in /// the provider. This will build quantization tables, but does not quantize any vectors. - pub fn train_quantizer(&self, context: &Context) -> bool { + pub(crate) fn train_quantizer(&self, context: &Context) -> bool { // Collect up to `self.quantizer.required_vectors()` vectors and use them for quantizer training. let current_time: u64 = @@ -424,7 +413,7 @@ impl GarnetProvider { if self .fsm - .visit_used(*context, |id| { + .visit_used(context, |id| { // Skip the start point. if id == 0 { return true; @@ -439,7 +428,7 @@ impl GarnetProvider { let row = data.row_mut(row_idx); if !self .callbacks - .read_single_iid(context.term(Term::Vector), id, row) + .read_single_iid(&context.term(Term::Vector), id, row) { return false; } @@ -480,7 +469,12 @@ impl GarnetProvider { /// /// This function will be invoked on multiple threads. 
The total number of tasks and the ID of /// the current task are given as inputs. - pub fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) { + pub(crate) fn backfill_quant_vectors( + &self, + context: &Context, + task_idx: usize, + task_count: usize, + ) { let quantizer = match &self.quantizer { Some(q) => q, None => return, @@ -504,7 +498,7 @@ impl GarnetProvider { for id in start_id..end_id { if !self .callbacks - .read_single_iid(context.term(Term::Vector), id, &mut v) + .read_single_iid(&context.term(Term::Vector), id, &mut v) { continue; } @@ -515,7 +509,7 @@ impl GarnetProvider { if !self .callbacks - .write_iid(context.term(Term::Quantized), id, &q) + .write_iid(&context.term(Term::Quantized), id, &q) { continue; } @@ -532,7 +526,7 @@ impl GarnetProvider { { let _ = self .callbacks - .write_iid(context.term(Term::Quantized), 0, &q); + .write_iid(&context.term(Term::Quantized), 0, &q); // set the cache let point = if let Ok(p) = Poly::from_iter(q.iter().copied(), AlignToEight) { @@ -557,7 +551,7 @@ impl GarnetProvider { } /// Returns quantization status. If this is true, the index is operating fully quantized. 
- pub fn is_quantized(&self) -> bool { + pub(crate) fn is_quantized(&self) -> bool { self.quantizer.is_some() && self.all_quantized.load(Ordering::Acquire) } } @@ -576,7 +570,7 @@ impl DataProvider for GarnetProvider { ) -> Result { let mut id = 0u32; if !self.callbacks.read_single_eid( - context.term(Term::IntMap), + &context.term(Term::IntMap), gid, bytemuck::bytes_of_mut(&mut id), ) { @@ -588,7 +582,7 @@ impl DataProvider for GarnetProvider { fn to_external_id(&self, context: &Context, id: u32) -> Result { match self .callbacks - .read_varsize_iid(context.term(Term::ExtMap), id) + .read_varsize_iid(&context.term(Term::ExtMap), id) { Some(eid) => Ok(eid.into()), None => Err(GarnetProviderError::Garnet(GarnetError::Read)), @@ -605,19 +599,19 @@ impl SetElement<&[T]> for GarnetProvider { id: &Self::ExternalId, element: &[T], ) -> Result { - let internal_id = self.fsm.next_id(*context)?; + let internal_id = self.fsm.next_id(context)?; // Set quantization readiness if let Some(quantizer) = &self.quantizer && !quantizer.is_prepared() && self.fsm.total_used() > quantizer.required_vectors() { - QUANTIZER_READY.with(|v| v.store(true, Ordering::Release)); + context.set_quantizer_ready(); } let insert = || -> Result<(), Self::SetError> { self.callbacks - .write_iid(context.term(Term::Vector), internal_id, element) + .write_iid(&context.term(Term::Vector), internal_id, element) .then_some(()) .ok_or(GarnetError::Write)?; if let Some(quantizer) = &self.quantizer @@ -628,17 +622,17 @@ impl SetElement<&[T]> for GarnetProvider { .get_ref(Undef::new(quantizer.canonical_bytes())); quantizer.compress(bytemuck::cast_slice::(element), &mut quant)?; self.callbacks - .write_iid(context.term(Term::Quantized), internal_id, &quant) + .write_iid(&context.term(Term::Quantized), internal_id, &quant) .then_some(()) .ok_or(GarnetError::Write)?; } self.callbacks - .write_iid(context.term(Term::ExtMap), internal_id, id) + .write_iid(&context.term(Term::ExtMap), internal_id, id) .then_some(()) 
.ok_or(GarnetError::Write)?; self.callbacks .write_eid( - context.term(Term::IntMap), + &context.term(Term::IntMap), id, bytemuck::bytes_of(&internal_id), ) @@ -650,7 +644,7 @@ impl SetElement<&[T]> for GarnetProvider { match insert() { Ok(()) => (), Err(e) => { - self.fsm.mark_free(*context, internal_id)?; + self.fsm.mark_free(context, internal_id)?; return Err(e); } } @@ -671,21 +665,23 @@ impl Delete for GarnetProvider { }; // Mark the ID free in the FSM. - if let Err(e) = self.fsm.mark_free(*context, id) { + if let Err(e) = self.fsm.mark_free(context, id) { return future::ready(Err(e.into())); }; // Delete all the data associated with the vector. let mut ok = true; - ok &= self.callbacks.delete_iid(context.term(Term::ExtMap), id); - ok &= self.callbacks.delete_eid(context.term(Term::IntMap), gid); + ok &= self.callbacks.delete_iid(&context.term(Term::ExtMap), id); + ok &= self.callbacks.delete_eid(&context.term(Term::IntMap), gid); ok &= self .callbacks - .delete_eid(context.term(Term::Attributes), gid); + .delete_eid(&context.term(Term::Attributes), gid); // NOTE: Commented out until DiskANN fixes accessing neighbor data post-delete. 
//ok &= self.callbacks.delete_iid(context.term(Term::Neighbors), id); - ok &= self.callbacks.delete_iid(context.term(Term::Vector), id); - ok &= self.callbacks.delete_iid(context.term(Term::Quantized), id); + ok &= self.callbacks.delete_iid(&context.term(Term::Vector), id); + ok &= self + .callbacks + .delete_iid(&context.term(Term::Quantized), id); if !ok { return future::ready(Err(GarnetError::Delete.into())); @@ -708,7 +704,7 @@ impl Delete for GarnetProvider { context: &Self::Context, id: Self::InternalId, ) -> impl Future> + Send { - let status = match self.fsm.is_free(*context, id) { + let status = match self.fsm.is_free(context, id) { Ok(true) => ElementStatus::Deleted, Ok(false) => ElementStatus::Valid, Err(e) => return future::ready(Err(e.into())), @@ -730,9 +726,9 @@ impl Delete for GarnetProvider { /// Dynamic accessor that seamlessly transitions from full precision vector based operation to /// quantized-only operation. #[derive(Copy, Clone, Debug)] -pub struct DynamicQuantization; +pub(crate) struct DynamicQuantization; -pub struct DynamicAccessor<'a, T: VectorRepr> { +pub(crate) struct DynamicAccessor<'a, T: VectorRepr> { provider: &'a GarnetProvider, context: &'a Context, /// Whether this accessor should use quantized vectors @@ -776,7 +772,7 @@ impl<'a, T: VectorRepr> DynamicAccessor<'a, T> { } if !self.provider.callbacks.read_single_iid( - self.context.term(Term::Neighbors), + &self.context.term(Term::Neighbors), id, &mut guard, ) { @@ -878,7 +874,7 @@ impl ExpandBeam<&[T]> for DynamicAccessor<'_, T> { if !self.filtered_ids.is_empty() { self.provider .callbacks - .read_multi_lpiid(ctx, &self.filtered_ids, |i, v| { + .read_multi_lpiid(&ctx, &self.filtered_ids, |i, v| { let dist = computer.evaluate_similarity(v); on_neighbors(dist, self.filtered_ids[i as usize * 2 + 1]); }); @@ -940,7 +936,7 @@ impl Accessor for DynamicAccessor<'_, T> { self.context.term(Term::Vector) }; - if !self.provider.callbacks.read_single_iid(ctx, id, &mut v) { + if 
!self.provider.callbacks.read_single_iid(&ctx, id, &mut v) { return future::ready(Err(GarnetError::Read.into())); } @@ -949,7 +945,7 @@ impl Accessor for DynamicAccessor<'_, T> { } /// Wrapper for full precision distance computer. -pub struct FullPrecisionDistance(T::Distance); +pub(crate) struct FullPrecisionDistance(T::Distance); impl DynDistanceComputer for FullPrecisionDistance { fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { @@ -961,7 +957,7 @@ impl DynDistanceComputer for FullPrecisionDistance { } /// Wrapper for full precision query computer. -pub struct FullPrecisionQueryDistance(T::QueryDistance); +pub(crate) struct FullPrecisionQueryDistance(T::QueryDistance); impl DynQueryComputer for FullPrecisionQueryDistance { fn evaluate_similarity(&self, a: &[u8]) -> f32 { @@ -970,12 +966,12 @@ impl DynQueryComputer for FullPrecisionQueryDistance { } /// Type-erased distance computer. -pub struct GarnetDistanceComputer { +pub(crate) struct GarnetDistanceComputer { inner: Box, } impl GarnetDistanceComputer { - pub fn new(computer: T) -> Self { + pub(crate) fn new(computer: T) -> Self { Self { inner: Box::new(computer), } @@ -988,12 +984,12 @@ impl DistanceFunction<&[u8], &[u8]> for GarnetDistanceComputer { } /// Type-erased query computer. -pub struct GarnetQueryComputer { +pub(crate) struct GarnetQueryComputer { inner: Box, } impl GarnetQueryComputer { - pub fn new(computer: T) -> Self { + pub(crate) fn new(computer: T) -> Self { Self { inner: Box::new(computer), } @@ -1051,7 +1047,7 @@ impl BuildQueryComputer<&[T]> for DynamicAccessor<'_, T> { /// /// Without an `&[T]: Into>`, the blanket implementation for `workingset::Map` /// is not applicable, allowing customization of `Fill`. 
-pub struct Escape(Box<[T]>); +pub(crate) struct Escape(Box<[T]>); impl<'a, T> Reborrow<'a> for Escape { type Target = &'a [T]; @@ -1060,13 +1056,13 @@ impl<'a, T> Reborrow<'a> for Escape { } } -pub struct WorkingSet { +pub(crate) struct WorkingSet { map: workingset::Map>, contains_unquantized: bool, } impl WorkingSet { - pub fn new(capacity_type: workingset::map::Capacity, capacity: usize) -> Self { + pub(crate) fn new(capacity_type: workingset::map::Capacity, capacity: usize) -> Self { Self { map: workingset::map::Builder::new(capacity_type).build(capacity), contains_unquantized: false, @@ -1150,7 +1146,7 @@ impl workingset::Fill for DynamicAccessor<'_, T> { if !self.filtered_ids.is_empty() { self.provider .callbacks - .read_multi_lpiid(ctx, &self.filtered_ids, |id, v| { + .read_multi_lpiid(&ctx, &self.filtered_ids, |id, v| { set.insert(self.filtered_ids[id as usize * 2 + 1], Escape(v.into())); }); } @@ -1159,7 +1155,7 @@ impl workingset::Fill for DynamicAccessor<'_, T> { } } -pub struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut DynamicAccessor<'p, T>); +pub(crate) struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut DynamicAccessor<'p, T>); impl HasId for DelegateNeighborAccessor<'_, '_, T> { type Id = u32; @@ -1201,7 +1197,7 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> .0 .provider .callbacks - .write_iid(self.0.context.term(Term::Neighbors), id, &guard) + .write_iid(&self.0.context.term(Term::Neighbors), id, &guard) { return future::ready(Err(GarnetProviderError::Garnet(GarnetError::Write).into())); } @@ -1225,7 +1221,7 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> ) -> impl Future> + Send { let max_degree = self.0.provider.max_degree; if !self.0.provider.callbacks.rmw_iid( - self.0.context.term(Term::Neighbors), + &self.0.context.term(Term::Neighbors), id, (self.0.provider.max_degree + 1) * mem::size_of::(), |data: &mut [u32]| { @@ -1263,7 +1259,7 @@ impl NeighborAccessorMut for 
DelegateNeighborAccessor<'_, '_, T> /// A [`SearchPostProcess`] base object that copies each `Neighbor` to a `(ExternalId, f32)` pair /// and writes as many as possible to the output buffer. #[derive(Debug, Default, Clone, Copy)] -pub struct CopyExternalIds; +pub(crate) struct CopyExternalIds; impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> for CopyExternalIds @@ -1301,7 +1297,7 @@ impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], G /// A [`SearchPostProcess`] base object that reranks quantized vectors by full precision distance. #[derive(Debug, Default, Clone, Copy)] -pub struct Rerank; +pub(crate) struct Rerank; impl<'a, 'b, T: VectorRepr> SearchPostProcessStep, &'b [T], GarnetId> for Rerank @@ -1345,7 +1341,7 @@ impl<'a, 'b, T: VectorRepr> SearchPostProcessStep, &'b [T if !provider.vector_iid_exists(accessor.context, n.id) { None } else if provider.callbacks.read_single_iid( - accessor.context.term(Term::Vector), + &accessor.context.term(Term::Vector), n.id, &mut v, ) { @@ -1469,7 +1465,7 @@ impl InplaceDeleteStrategy> for DynamicQuantiza ) -> impl Future> + Send { let mut v = vec![T::default(); provider.dim]; - if !provider.callbacks.read_single_iid(*context, id, &mut v) { + if !provider.callbacks.read_single_iid(context, id, &mut v) { return future::ready(Err(GarnetError::Read.into())); } future::ready(Ok(v.into())) diff --git a/diskann-garnet/src/quantization.rs b/diskann-garnet/src/quantization.rs index 265bda056..3979ddf48 100644 --- a/diskann-garnet/src/quantization.rs +++ b/diskann-garnet/src/quantization.rs @@ -19,7 +19,7 @@ use thiserror::Error; use crate::provider::{GarnetDistanceComputer, GarnetQueryComputer}; #[derive(Debug, Error)] -pub enum GarnetQuantizerError { +pub(crate) enum GarnetQuantizerError { #[error("Quantization training error: {0}")] Training(Box), #[error("Quantization alloc error: {0}")] @@ -37,7 +37,7 @@ pub enum GarnetQuantizerError { } /// Quantizer trait that all diskann-garnet quantizers must implement 
-pub trait GarnetQuantizer: Send + Sync { +pub(crate) trait GarnetQuantizer: Send + Sync { /// Check whether the quantizer is ready to be used fn is_prepared(&self) -> bool; /// Returns the number of vectors needed before the quantizer can be trained @@ -56,12 +56,12 @@ pub trait GarnetQuantizer: Send + Sync { } /// Type-erased distance computer -pub trait DynDistanceComputer: Send + Sync { +pub(crate) trait DynDistanceComputer: Send + Sync { fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32; } /// Type-erased query computer -pub trait DynQueryComputer: Send + Sync { +pub(crate) trait DynQueryComputer: Send + Sync { fn evaluate_similarity(&self, a: &[u8]) -> f32; } @@ -70,13 +70,13 @@ pub trait DynQueryComputer: Send + Sync { /// This quantizer corresponds to `BIN` quantizer in the Redis protocol. It requires hundreds of /// vectors (but not thousands) for training. Quantized vectors have 1 bit per dimension plus up /// to 6 bytes of overhead. -pub struct Spherical1Bit { +pub(crate) struct Spherical1Bit { dim: usize, inner: RwLock>>, } impl Spherical1Bit { - pub fn new(dim: usize) -> Self { + pub(crate) fn new(dim: usize) -> Self { Self { dim, inner: RwLock::new(None), @@ -197,14 +197,14 @@ impl DynQueryComputer for iface::QueryComputer { /// /// This quantizer requires no training at all and is usable immediately on the first first. Each /// quantized vector has 8 bits per dimension and 20 bytes of overhead. 
-pub struct MinMax8Bit { +pub(crate) struct MinMax8Bit { dim: usize, metric: Metric, inner: minmax::MinMaxQuantizer, } impl MinMax8Bit { - pub fn new(dim: usize, metric: Metric) -> Result { + pub(crate) fn new(dim: usize, metric: Metric) -> Result { let dim = match NonZero::new(dim) { Some(d) => d, None => return Err(GarnetQuantizerError::ZeroDim), diff --git a/diskann-garnet/src/test_utils.rs b/diskann-garnet/src/test_utils.rs index ba9caed10..cdc59897f 100644 --- a/diskann-garnet/src/test_utils.rs +++ b/diskann-garnet/src/test_utils.rs @@ -141,42 +141,42 @@ mod tests { let store = Store; store.clear(); let callbacks = store.callbacks(); - let ctx = Context(0); + let ctx = Context::new(0); // Reading a non-existant key should fail. - assert!(!callbacks.exists_iid(ctx, 0)); + assert!(!callbacks.exists_iid(&ctx, 0)); // Round tripping a write should work. - assert!(callbacks.write_iid(ctx, 0, b"test")); + assert!(callbacks.write_iid(&ctx, 0, b"test")); let mut val = vec![0u8; 4]; - assert!(callbacks.read_single_iid(ctx, 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx, 0, &mut val)); assert_eq!(val, b"test"); // Overwriting a key should work. - assert!(callbacks.write_iid(ctx, 0, b"again")); + assert!(callbacks.write_iid(&ctx, 0, b"again")); let mut val = vec![0u8; 5]; - assert!(callbacks.read_single_iid(ctx, 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx, 0, &mut val)); assert_eq!(val, b"again"); // Exists and delete should work. - assert!(callbacks.exists_iid(ctx, 0)); - assert!(callbacks.delete_iid(ctx, 0)); - assert!(!callbacks.exists_iid(ctx, 0)); + assert!(callbacks.exists_iid(&ctx, 0)); + assert!(callbacks.delete_iid(&ctx, 0)); + assert!(!callbacks.exists_iid(&ctx, 0)); // Different contexts should stay separate. 
- assert!(callbacks.write_iid(ctx.term(Term::Vector), 0, b"0000")); - assert!(callbacks.write_iid(ctx.term(Term::Neighbors), 0, b"nnnn")); + assert!(callbacks.write_iid(&ctx.term(Term::Vector), 0, b"0000")); + assert!(callbacks.write_iid(&ctx.term(Term::Neighbors), 0, b"nnnn")); let mut val = vec![0u8; 4]; - assert!(callbacks.read_single_iid(ctx.term(Term::Vector), 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx.term(Term::Vector), 0, &mut val)); assert_eq!(val, b"0000"); - assert!(callbacks.read_single_iid(ctx.term(Term::Neighbors), 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx.term(Term::Neighbors), 0, &mut val)); assert_eq!(val, b"nnnn"); // Multi-read should work. - assert!(callbacks.write_iid(ctx.term(Term::Vector), 1, b"2222")); + assert!(callbacks.write_iid(&ctx.term(Term::Vector), 1, b"2222")); let ids = [4u32, 0, 4, 1, 4, 2]; let mut results = HashMap::new(); - callbacks.read_multi_lpiid(ctx.term(Term::Vector), &ids, |i, v| { + callbacks.read_multi_lpiid(&ctx.term(Term::Vector), &ids, |i, v| { results.insert(i, v.to_owned()); }); assert_eq!(results.get(&0), Some(b"0000".to_vec()).as_ref()); @@ -184,10 +184,10 @@ mod tests { assert_eq!(results.get(&2), None); // RMW should work. 
- assert!(callbacks.rmw_iid(ctx.term(Term::Vector), 0, 4, |data| { + assert!(callbacks.rmw_iid(&ctx.term(Term::Vector), 0, 4, |data| { data.copy_from_slice(b"0rmw"); })); - assert!(callbacks.read_single_iid(ctx.term(Term::Vector), 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx.term(Term::Vector), 0, &mut val)); assert_eq!(val, b"0rmw"); } @@ -199,7 +199,7 @@ mod tests { let store: Store = Store; store.clear(); let callbacks = store.callbacks(); - let ctx: Context = Context(0); + let ctx: Context = Context::new(0); // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; @@ -210,7 +210,7 @@ mod tests { Metric::L2, max_degree, callbacks, - ctx, + &ctx, ); // Provider should be created successfully @@ -229,7 +229,7 @@ mod tests { Metric::L2, max_degree, callbacks, - ctx, + &ctx, ); // Provider should be created successfully diff --git a/rfcs/00000-quantizer-bootstrap.md b/rfcs/00000-quantizer-bootstrap.md index a74c60100..6577c3c06 100644 --- a/rfcs/00000-quantizer-bootstrap.md +++ b/rfcs/00000-quantizer-bootstrap.md @@ -74,7 +74,7 @@ Bootstrapping needs two changes to a DiskANN data provider implementation. DiskANN already has the ability to run multiple strategies including hybrid full precision and quantized ones. These strategies represent the high level intent, -but the data provider can choose alternate implementions depending on the data +but the data provider can choose alternate implementations depending on the data available during the current phase. #### Insertion and Deletion @@ -85,7 +85,7 @@ into storage, and can track the current phase to gate writes to quantized vectors. The search portion of these operations will return different objects depending on the current phase. 
-For example, consider a `DataProvder::set_element()` implementation: +For example, consider a `DataProvider::set_element()` implementation: ```rust struct ExampleProvider { @@ -94,7 +94,7 @@ struct ExampleProvider { } impl SetElement<[f32]> for ExampleProvider { - // associated types ommitted + // associated types omitted async fn set_element( &self, @@ -107,12 +107,12 @@ impl SetElement<[f32]> for ExampleProvider { self.set_internal_map(internal_id, id)?; self.set_external_map(id, internal_id)?; - // Quantize and storage quant vector if we have a quantizer. + // Quantize and store quant vector if we have a quantizer. if let Some(quantizer) = self.quantizer { let qv = quantizer.quantize(element)?; self.write_quant_vector(context, internal_id, element)?; } else { - // This function will check if we are ready for Phase Two, and if so, do or schedule the quantizer intialization. + // This function will check if we are ready for Phase Two, and if so, do or schedule the quantizer initialization. 
self.maybe_initialize_quantizer()?; } From 0e936b72f1010f4d492c2ea85e24bb0ee3166066 Mon Sep 17 00:00:00 2001 From: Jack Moffitt Date: Tue, 19 May 2026 18:00:30 -0500 Subject: [PATCH 5/5] new FFI for i8 --- Cargo.lock | 123 ++++++++++++------------- diskann-garnet/diskann-garnet.nuspec | 2 +- diskann-garnet/src/ffi_recall_tests.rs | 6 +- diskann-garnet/src/ffi_tests.rs | 13 +-- diskann-garnet/src/lib.rs | 105 ++++++++------------- diskann-garnet/src/provider.rs | 61 ++++++++---- vectorset/src/main.rs | 81 +++++++++++----- 7 files changed, 204 insertions(+), 187 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 81fcc6f78..451d298a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -134,7 +134,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -177,7 +177,7 @@ checksum = "4a2df13e12ead1bb9f3b45cebec7e2b36eb6c2e88305745adc34828c1e457f92" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "tracing", "typespec_client_core", ] @@ -207,14 +207,14 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bf-tree" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07155f3f780c77c9ae58fcd1668140a6ac46a9b1e85bddf81489061291b3e80f" +checksum = "11a6082b72699bf88f431c4ebab0e9e2ab6c31cf0576a3b115653fb8a0a0018e" dependencies = [ "cfg-if", "io-uring", "libc", - "rand 0.8.5", + "rand 0.8.6", "rand 0.9.4", "serde", "serde_json", @@ -282,7 +282,7 @@ checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -377,7 +377,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -616,7 +616,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.113", + "syn 
2.0.117", ] [[package]] @@ -972,7 +972,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1018,7 +1018,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1047,7 +1047,7 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1058,7 +1058,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1128,7 +1128,7 @@ checksum = "2cc4b8cd876795d3b19ddfd59b03faa303c0b8adb9af6e188e81fc647c485bb9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1291,7 +1291,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1682,7 +1682,7 @@ dependencies = [ "quote", "serde", "serde_json", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1907,10 +1907,12 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "js-sys" -version = "0.3.83" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" +checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" dependencies = [ + "cfg-if", + "futures-util", "once_cell", "wasm-bindgen", ] @@ -2227,7 +2229,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2345,7 +2347,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", 
"quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2441,7 +2443,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2484,14 +2486,14 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "proc-macro2" -version = "1.0.104" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9695f8df41bb4f3d222c95a67532365f569318332d03d5f3f67f37b20e6ebdf0" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] @@ -2535,7 +2537,7 @@ dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2572,9 +2574,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.42" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -2587,9 +2589,9 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "rand" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", "rand_chacha 0.3.1", @@ -2861,7 +2863,7 @@ dependencies = [ "regex", "relative-path 1.9.3", "rustc_version", - "syn 2.0.113", + "syn 2.0.117", "unicode-ident", ] @@ -3028,14 +3030,14 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "serde_json" -version = "1.0.148" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -3171,9 +3173,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.113" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "678faa00651c9eb72dd2020cbdf275d92eccb2400d568e419efdd64838145cb4" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -3197,7 +3199,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3268,7 +3270,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3279,7 +3281,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3368,7 +3370,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3555,7 +3557,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3666,7 +3668,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3838,9 +3840,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.106" +version = "0.2.121" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" +checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" dependencies = [ "cfg-if", "once_cell", @@ -3851,22 +3853,19 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.56" +version = "0.4.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" +checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" dependencies = [ - "cfg-if", "js-sys", - "once_cell", "wasm-bindgen", - "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.106" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" +checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3874,22 +3873,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.106" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" +checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.106" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" +checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" dependencies = [ "unicode-ident", ] @@ -3943,9 +3942,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.83" +version = "0.3.98" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" +checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" dependencies = [ "js-sys", "wasm-bindgen", @@ -4177,7 +4176,7 @@ dependencies = [ "heck", "indexmap", "prettyplease", - "syn 2.0.113", + "syn 2.0.117", "wasm-metadata", "wit-bindgen-core", "wit-component", @@ -4193,7 +4192,7 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "wit-bindgen-core", "wit-bindgen-rust", ] @@ -4260,7 +4259,7 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "synstructure", ] @@ -4281,7 +4280,7 @@ checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -4301,7 +4300,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "synstructure", ] @@ -4341,7 +4340,7 @@ checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] diff --git a/diskann-garnet/diskann-garnet.nuspec b/diskann-garnet/diskann-garnet.nuspec index f6997c410..e77e73275 100644 --- a/diskann-garnet/diskann-garnet.nuspec +++ b/diskann-garnet/diskann-garnet.nuspec @@ -2,7 +2,7 @@ diskann-garnet - 1.0.26 + 1.0.28 docs/README.md Microsoft https://github.com/microsoft/DiskANN diff --git a/diskann-garnet/src/ffi_recall_tests.rs b/diskann-garnet/src/ffi_recall_tests.rs index a70ed0ce5..b15178346 100644 --- a/diskann-garnet/src/ffi_recall_tests.rs +++ b/diskann-garnet/src/ffi_recall_tests.rs @@ -10,8 +10,8 @@ mod tests { use diskann_vector::distance::{Cosine, Metric, SquaredL2}; use crate::{ - VectorQuantType, 
VectorValueType, create_index, drop_index, garnet::Context, insert, - search_vector, test_utils::Store, + VectorQuantType, create_index, drop_index, garnet::Context, insert, search_vector, + test_utils::Store, }; /// Helper to insert a vector with a string external ID and FP32 data. @@ -29,7 +29,6 @@ mod tests { index_ptr, id_bytes.as_ptr(), id_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector.len(), b"".as_ptr(), @@ -206,7 +205,6 @@ mod tests { search_vector( ctx.get(), index_ptr, - VectorValueType::FP32, query_bytes.as_ptr(), vec.len(), delta, diff --git a/diskann-garnet/src/ffi_tests.rs b/diskann-garnet/src/ffi_tests.rs index efc5ce308..251b1b02b 100644 --- a/diskann-garnet/src/ffi_tests.rs +++ b/diskann-garnet/src/ffi_tests.rs @@ -9,8 +9,8 @@ mod tests { use diskann_vector::distance::Metric; use crate::{ - VectorQuantType, VectorValueType, card, check_external_id_valid, create_index, drop_index, - garnet::Context, insert, remove, search_vector, set_attribute, test_utils::Store, + VectorQuantType, card, check_external_id_valid, create_index, drop_index, garnet::Context, + insert, remove, search_vector, set_attribute, test_utils::Store, }; /// Creates an index with default test values and returns (index_ptr, Context). 
@@ -132,7 +132,6 @@ mod tests { index_ptr, id_bytes.as_ptr(), id_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector_len, attributes_bytes.as_ptr(), @@ -195,7 +194,6 @@ mod tests { index_ptr, id_bytes.as_ptr(), id_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector_bytes.len() / 4, attributes1.as_ptr(), @@ -281,7 +279,6 @@ mod tests { index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector_len, attributes_bytes.as_ptr(), @@ -314,7 +311,6 @@ mod tests { index_ptr, eid2_bytes.as_ptr(), eid2_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector_len, attributes_bytes.as_ptr(), @@ -361,7 +357,6 @@ mod tests { index_ptr, id1_bytes.as_ptr(), id1_bytes.len(), - VectorValueType::XB8, v1.as_ptr(), v1.len(), b"".as_ptr(), @@ -377,7 +372,6 @@ mod tests { index_ptr, id2_bytes.as_ptr(), id2_bytes.len(), - VectorValueType::XB8, v2.as_ptr(), v2.len(), b"".as_ptr(), @@ -394,7 +388,6 @@ mod tests { search_vector( ctx.get(), index_ptr, - VectorValueType::XB8, qv.as_ptr(), qv.len(), 2.0, @@ -467,7 +460,6 @@ mod tests { index_ptr, id_bytes.as_ptr(), id_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector.len(), b"".as_ptr(), @@ -497,7 +489,6 @@ mod tests { search_vector( ctx.get(), index_ptr, - VectorValueType::FP32, query_bytes.as_ptr(), query.len(), 0.0, diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index 73067ca0a..1fcddc376 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -93,7 +93,10 @@ pub enum VectorQuantType { NoQuant, Bin, Q8, - XPreQ8, + XNoQuantU8, + XNoQuantI8, + XBinI8, + XBinU8, } struct SearchResults<'a> { @@ -221,6 +224,7 @@ pub unsafe extern "C" fn create_index( }; let target_degree = (max_degree as f32 / GRAPH_SLACK_FACTOR) as usize; + let config = if let Ok(config) = config::Builder::new( target_degree, config::MaxDegree::Value(max_degree as usize), @@ -239,7 +243,7 @@ pub unsafe extern "C" fn create_index( match 
quant_type { VectorQuantType::Invalid => ptr::null(), - VectorQuantType::XPreQ8 => { + VectorQuantType::XNoQuantU8 | VectorQuantType::XBinU8 => { if let Ok(index) = create_index_impl::( quant_type, config, @@ -254,6 +258,21 @@ pub unsafe extern "C" fn create_index( ptr::null() } } + VectorQuantType::XNoQuantI8 | VectorQuantType::XBinI8 => { + if let Ok(index) = create_index_impl::( + quant_type, + config, + dim as usize, + metric_type, + max_degree as usize, + callbacks, + context, + ) { + Arc::into_raw(index).cast::() + } else { + ptr::null() + } + } VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => { if let Ok(index) = create_index_impl::( quant_type, @@ -311,43 +330,27 @@ impl<'a> From> for PolyCow<'a> { fn interpret_vector<'a>( quant_type: VectorQuantType, - vector_value_type: VectorValueType, vector_data: &'a *const u8, vector_len: usize, ) -> Option> { - let vector_len_bytes = match vector_value_type { - VectorValueType::FP32 => vector_len * 4, - VectorValueType::XB8 => vector_len, - VectorValueType::Invalid => return None, + let vector_len_bytes = match quant_type { + VectorQuantType::Invalid => return None, + VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => vector_len * 4, + VectorQuantType::XNoQuantU8 + | VectorQuantType::XNoQuantI8 + | VectorQuantType::XBinU8 + | VectorQuantType::XBinI8 => vector_len, }; let v = unsafe { slice::from_raw_parts(*vector_data, vector_len_bytes) }; - let v = match vector_value_type { - VectorValueType::Invalid => return None, - VectorValueType::FP32 => match quant_type { - VectorQuantType::Invalid => { - return None; - } - VectorQuantType::XPreQ8 => { - let mut bp = if let Ok(bp) = Poly::broadcast(0u8, vector_len, AlignToEight) { - bp - } else { - return None; - }; - for (idx, e) in bp.iter_mut().enumerate() { - let el_size = mem::size_of::(); - *e = f32::from_le_bytes( - v[idx * el_size..(idx + 1) * el_size].try_into().unwrap(), - ) as u8; - } - PolyCow::from(bp) - } - _ if 
v.as_ptr().align_offset(4) == 0 => { + let v = match quant_type { + VectorQuantType::Invalid => return None, + VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => { + if v.as_ptr().align_offset(mem::align_of::()) == 0 { // pointer is correctly aligned to interpret as f32 PolyCow::from(v) - } - _ => { + } else { // need to copy f32 data as it is unaligned let mut fp = if let Ok(fp) = Poly::broadcast(0u8, vector_len_bytes, AlignToEight) { fp @@ -357,27 +360,11 @@ fn interpret_vector<'a>( fp.copy_from_slice(v); PolyCow::from(fp) } - }, - VectorValueType::XB8 => match quant_type { - VectorQuantType::XPreQ8 => PolyCow::from(v), - VectorQuantType::NoQuant => { - let mut fp = if let Ok(p) = - Poly::broadcast(0u8, vector_len_bytes * mem::size_of::(), AlignToEight) - { - p - } else { - return None; - }; - for (fe, be) in bytemuck::cast_slice_mut::(&mut fp) - .iter_mut() - .zip(v) - { - *fe = *be as f32; - } - PolyCow::from(fp) - } - _ => return None, - }, + } + VectorQuantType::XNoQuantU8 + | VectorQuantType::XNoQuantI8 + | VectorQuantType::XBinU8 + | VectorQuantType::XBinI8 => PolyCow::from(v), }; Some(v) @@ -396,7 +383,6 @@ pub unsafe extern "C" fn insert( index_ptr: *const c_void, id_data: *const u8, id_len: usize, - vector_value_type: VectorValueType, vector_data: *const u8, vector_len: usize, attribute_data: *const u8, @@ -408,12 +394,7 @@ pub unsafe extern "C" fn insert( let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); - let v = if let Some(v) = interpret_vector( - index.quant_type, - vector_value_type, - &vector_data, - vector_len, - ) { + let v = if let Some(v) = interpret_vector(index.quant_type, &vector_data, vector_len) { v } else { return INSERT_FAIL; @@ -556,7 +537,6 @@ pub unsafe extern "C" fn set_attribute( pub unsafe extern "C" fn search_vector( ctx: u64, index_ptr: *const c_void, - vector_value_type: VectorValueType, vector_data: *const u8, vector_len: usize, _delta: f32, @@ -572,12 
+552,7 @@ pub unsafe extern "C" fn search_vector( ) -> i32 { let index = unsafe { &*index_ptr.cast::() }; - let v = if let Some(v) = interpret_vector( - index.quant_type, - vector_value_type, - &vector_data, - vector_len, - ) { + let v = if let Some(v) = interpret_vector(index.quant_type, &vector_data, vector_len) { v } else { return -1; diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 76c9a0636..45ef5420e 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -23,8 +23,11 @@ use diskann::{ utils::VectorRepr, }; use diskann_quantization::alloc::{AllocatorError, Poly}; -use diskann_utils::object_pool::{AsPooled, ObjectPool, PooledRef, Undef}; use diskann_utils::{Reborrow, views::Matrix}; +use diskann_utils::{ + object_pool::{AsPooled, ObjectPool, PooledRef, Undef}, + views::MatrixView, +}; use diskann_vector::{ DistanceFunction, PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric, }; @@ -206,7 +209,9 @@ impl GarnetProvider { let fsm = FreeSpaceMap::new(context, callbacks)?; let (quantizer, canonical_bytes, all_quantized) = match quant_type { - VectorQuantType::NoQuant | VectorQuantType::XPreQ8 => (None, 0, false), + VectorQuantType::NoQuant + | VectorQuantType::XNoQuantU8 + | VectorQuantType::XNoQuantI8 => (None, 0, false), VectorQuantType::Invalid => return Err(GarnetProviderError::InvalidQuantizer), VectorQuantType::Q8 => { if TypeId::of::() != TypeId::of::() { @@ -219,11 +224,7 @@ impl GarnetProvider { // NOTE: Q8 needs no training, so it always starts with backfill complete. 
(Some(quantizer), canonical_bytes, true) } - VectorQuantType::Bin => { - if TypeId::of::() != TypeId::of::() { - return Err(GarnetProviderError::InvalidQuantizer); - } - + VectorQuantType::Bin | VectorQuantType::XBinU8 | VectorQuantType::XBinI8 => { let quantizer = Box::new(quantization::Spherical1Bit::new(dim)) as Box; let canonical_bytes = quantizer.canonical_bytes(); @@ -298,7 +299,9 @@ impl GarnetProvider { // We are already able to quantize, so store the quantized start point let mut qpoint = vec![0u8; quantizer.canonical_bytes()]; - quantizer.compress(bytemuck::cast_slice::(point), &mut qpoint)?; + let point_f32 = + T::as_f32(point).map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + quantizer.compress(&point_f32, &mut qpoint)?; if !self .callbacks @@ -405,10 +408,8 @@ impl GarnetProvider { } }; - debug_assert_eq!(std::any::TypeId::of::(), std::any::TypeId::of::()); - let rows = quantizer.required_vectors(); - let mut data = Matrix::new(0f32, rows, self.dim); + let mut data = Matrix::::new(T::default(), rows, self.dim); let mut row_idx = 0usize; if self @@ -424,7 +425,6 @@ impl GarnetProvider { } // Read the vector into the data matrix. - // Note that it's ok to read f32 instead of T here, because this can only get called when T == f32. let row = data.row_mut(row_idx); if !self .callbacks @@ -456,6 +456,20 @@ impl GarnetProvider { }; // Train the quantizer. 
+ let converted = match T::as_f32(view.as_slice()) { + Ok(v) => v, + Err(_) => { + self.training_started.store(0, Ordering::Release); + return false; + } + }; + let view = match MatrixView::try_from(&*converted, view.nrows(), view.ncols()) { + Ok(v) => v, + Err(_) => { + self.training_started.store(0, Ordering::Release); + return false; + } + }; match quantizer.train(self.metric_type, view) { Ok(()) => true, Err(_e) => { @@ -493,7 +507,8 @@ impl GarnetProvider { self.fsm.lock_reuse(); - let mut v = vec![0f32; self.dim]; + let mut v = vec![T::default(); self.dim]; + let mut f = vec![0f32; self.dim]; let mut q = vec![0u8; quantizer.canonical_bytes()]; for id in start_id..end_id { if !self @@ -503,7 +518,10 @@ impl GarnetProvider { continue; } - if quantizer.compress(&v, &mut q).is_err() { + if T::as_f32_into(&v, &mut f).is_err() { + continue; + } + if quantizer.compress(&f, &mut q).is_err() { continue; }; @@ -520,9 +538,8 @@ impl GarnetProvider { if backfill_finished { // Finish by quantizing the start points if let Some(v) = self.start_point_cache.get(&0) - && quantizer - .compress(bytemuck::cast_slice::(&v), &mut q) - .is_ok() + && let Ok(v_f32) = T::as_f32(bytemuck::cast_slice::(&v)) + && quantizer.compress(&v_f32, &mut q).is_ok() { let _ = self .callbacks @@ -620,7 +637,10 @@ impl SetElement<&[T]> for GarnetProvider { let mut quant = self .quant_buffer_pool .get_ref(Undef::new(quantizer.canonical_bytes())); - quantizer.compress(bytemuck::cast_slice::(element), &mut quant)?; + let element_f32 = T::as_f32(element).map_err(|e| { + GarnetProviderError::Quantizer(GarnetQuantizerError::Compression(Box::new(e))) + })?; + quantizer.compress(&element_f32, &mut quant)?; self.callbacks .write_iid(&context.term(Term::Quantized), internal_id, &quant) .then_some(()) @@ -1032,8 +1052,11 @@ impl BuildQueryComputer<&[T]> for DynamicAccessor<'_, T> { if self.quantized && let Some(quantizer) = self.provider.quantizer() { + let from_f32 = T::as_f32(from).map_err(|e| { + 
GarnetProviderError::Quantizer(GarnetQuantizerError::Compression(Box::new(e))) + })?; Ok(quantizer - .query_computer(bytemuck::cast_slice::(from)) + .query_computer(&from_f32) .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?) } else { Ok(GarnetQueryComputer::new(FullPrecisionQueryDistance::( diff --git a/vectorset/src/main.rs b/vectorset/src/main.rs index 8333e4f86..04a2e3bb4 100644 --- a/vectorset/src/main.rs +++ b/vectorset/src/main.rs @@ -32,6 +32,7 @@ const DEFAULT_PORT: u16 = 6379; trait Element: bytemuck::Pod + std::fmt::Debug + Send + Sync + 'static {} impl Element for u8 {} +impl Element for i8 {} impl Element for f32 {} #[derive(Deserialize)] @@ -53,9 +54,9 @@ struct Options { #[arg(short, long)] threads: Option, - /// Data type for vector elements - #[arg(long)] - data_type: DataType, + /// Quantizer + #[arg(long, default_value_t)] + quantizer: Quantizer, #[command(subcommand)] command: Commands, @@ -107,10 +108,6 @@ struct IngestArgs { #[arg(long)] no_header_with_dim: Option, - /// Quantizer - #[arg(long)] - quantizer: Option, - /// Metric #[arg(long)] metric: Option, @@ -166,14 +163,31 @@ struct QueryArgs { #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] enum DataType { Uint8, + Int8, Float32, } -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Default)] enum Quantizer { - None, + #[default] + NoQuant, Bin, Q8, + XNoQuant_U8, + XNoQuant_I8, + XBin_U8, + XBin_I8, +} + +impl Quantizer { + pub fn data_type(&self) -> DataType { + match self { + Quantizer::NoQuant | Quantizer::Bin | Quantizer::Q8 => DataType::Float32, + Quantizer::XNoQuant_I8 | Quantizer::XBin_I8 => DataType::Int8, + Quantizer::XNoQuant_U8 | Quantizer::XBin_U8 => DataType::Uint8, + } + } } impl ToRedisArgs for Quantizer { @@ -182,14 +196,33 @@ impl ToRedisArgs for Quantizer { W: ?Sized + redis::RedisWrite, { let q = match self { - Quantizer::None 
=> b"NOQUANT".as_slice(), + Quantizer::NoQuant => b"NOQUANT".as_slice(), Quantizer::Bin => b"BIN".as_slice(), Quantizer::Q8 => b"Q8".as_slice(), + Quantizer::XNoQuant_U8 => b"XNOQUANT_U8", + Quantizer::XNoQuant_I8 => b"XNOQUANT_I8", + Quantizer::XBin_U8 => b"XBIN_U8", + Quantizer::XBin_I8 => b"XBIN_I8", }; out.write_arg(q); } } +impl std::fmt::Display for Quantizer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + Quantizer::NoQuant => "NOQUANT", + Quantizer::Bin => "BIN", + Quantizer::Q8 => "Q8", + Quantizer::XNoQuant_U8 => "XNOQUANT_U8", + Quantizer::XNoQuant_I8 => "XNOQUANT_I8", + Quantizer::XBin_U8 => "XBIN_U8", + Quantizer::XBin_I8 => "XBIN_I8", + }; + f.write_str(s) + } +} + #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] enum DistanceMetric { L2, @@ -207,7 +240,7 @@ impl ToRedisArgs for DistanceMetric { DistanceMetric::L2 => b"L2".as_slice(), DistanceMetric::Cosine => b"COSINE".as_slice(), DistanceMetric::CosineNormalized => b"COSINE_NORMALIZED".as_slice(), - DistanceMetric::InnerProduct => b"INNER_PRODUCT".as_slice(), + DistanceMetric::InnerProduct => b"IP".as_slice(), }; out.write_arg(q); } @@ -337,8 +370,9 @@ async fn async_main(opts: Options) -> Result<()> { None }; - match opts.data_type { + match opts.quantizer.data_type() { DataType::Uint8 => dispatch::(&opts.command, &opts, infos, cred).await, + DataType::Int8 => dispatch::(&opts.command, &opts, infos, cred).await, DataType::Float32 => dispatch::(&opts.command, &opts, infos, cred).await, } } @@ -413,8 +447,8 @@ async fn ingest( let l_build = args.l_build; let degree = args.degree; let mut cred = cred.clone(); - let data_type = opts.data_type; - let quantizer = args.quantizer; + let data_type = opts.quantizer.data_type(); + let quantizer = opts.quantizer; let metric = args.metric; tasks.spawn(async move { @@ -446,7 +480,10 @@ async fn ingest( match data_type { DataType::Uint8 => { - pipeline.arg(b"XB8"); + pipeline.arg(b"XU8"); + } + 
DataType::Int8 => { + pipeline.arg(b"XI8"); } DataType::Float32 => { pipeline.arg(b"FP32"); @@ -457,14 +494,7 @@ async fn ingest( .arg(bytemuck::cast_slice::<_, u8>(&buf[buf_start..buf_end])) .arg(element); - match data_type { - DataType::Uint8 => { - pipeline.arg(b"XPREQ8"); - } - DataType::Float32 => { - pipeline.arg(quantizer); - } - } + pipeline.arg(quantizer); if let Some(metric) = metric { pipeline.arg(b"XDISTANCE_METRIC").arg(metric); @@ -567,7 +597,7 @@ async fn query( let n = args.n; let l_search = args.l_search; let mut cred = cred.clone(); - let data_type = opts.data_type; + let data_type = opts.quantizer.data_type(); tasks.spawn(async move { let mut pipeline = Pipeline::with_capacity(pipeline_size); @@ -591,7 +621,8 @@ async fn query( pipeline.cmd("VSIM").arg(&vset); match data_type { - DataType::Uint8 => pipeline.arg(b"XB8"), + DataType::Uint8 => pipeline.arg(b"XU8"), + DataType::Int8 => pipeline.arg(b"XI8"), DataType::Float32 => pipeline.arg(b"FP32"), };