diff --git a/Cargo.lock b/Cargo.lock index bead53768..4e3ca080a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -134,7 +134,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -177,7 +177,7 @@ checksum = "4a2df13e12ead1bb9f3b45cebec7e2b36eb6c2e88305745adc34828c1e457f92" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "tracing", "typespec_client_core", ] @@ -207,14 +207,14 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bf-tree" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07155f3f780c77c9ae58fcd1668140a6ac46a9b1e85bddf81489061291b3e80f" +checksum = "11a6082b72699bf88f431c4ebab0e9e2ab6c31cf0576a3b115653fb8a0a0018e" dependencies = [ "cfg-if", "io-uring", "libc", - "rand 0.8.5", + "rand 0.8.6", "rand 0.9.4", "serde", "serde_json", @@ -282,7 +282,7 @@ checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -377,7 +377,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -616,7 +616,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -800,6 +800,7 @@ dependencies = [ "diskann-utils", "diskann-vector", "foldhash 0.2.0", + "rand 0.9.4", "thiserror 2.0.17", "tokio", ] @@ -992,7 +993,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1038,7 +1039,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1067,7 +1068,7 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ 
"proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1078,7 +1079,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1148,7 +1149,7 @@ checksum = "2cc4b8cd876795d3b19ddfd59b03faa303c0b8adb9af6e188e81fc647c485bb9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1311,7 +1312,7 @@ checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1690,7 +1691,7 @@ dependencies = [ "quote", "serde", "serde_json", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -1915,10 +1916,12 @@ checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "js-sys" -version = "0.3.83" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" +checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" dependencies = [ + "cfg-if", + "futures-util", "once_cell", "wasm-bindgen", ] @@ -2212,7 +2215,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2330,7 +2333,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2416,7 +2419,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2459,14 +2462,14 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = 
"proc-macro2" -version = "1.0.104" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9695f8df41bb4f3d222c95a67532365f569318332d03d5f3f67f37b20e6ebdf0" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] @@ -2510,7 +2513,7 @@ dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -2547,9 +2550,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.42" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -2562,9 +2565,9 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "rand" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "libc", "rand_chacha 0.3.1", @@ -2836,7 +2839,7 @@ dependencies = [ "regex", "relative-path 1.9.3", "rustc_version", - "syn 2.0.113", + "syn 2.0.117", "unicode-ident", ] @@ -3003,14 +3006,14 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] name = "serde_json" -version = "1.0.148" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -3146,9 +3149,9 @@ dependencies = [ 
[[package]] name = "syn" -version = "2.0.113" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "678faa00651c9eb72dd2020cbdf275d92eccb2400d568e419efdd64838145cb4" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -3172,7 +3175,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3243,7 +3246,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3254,7 +3257,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3343,7 +3346,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3530,7 +3533,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3641,7 +3644,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -3813,9 +3816,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.106" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" +checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" dependencies = [ "cfg-if", "once_cell", @@ -3826,22 +3829,19 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.56" +version = "0.4.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" +checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" dependencies = [ - "cfg-if", "js-sys", - "once_cell", "wasm-bindgen", - "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.106" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" +checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3849,22 +3849,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.106" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" +checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.106" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" +checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" dependencies = [ "unicode-ident", ] @@ -3918,9 +3918,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.83" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" +checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" dependencies = [ "js-sys", "wasm-bindgen", @@ -4152,7 +4152,7 @@ dependencies = [ "heck", "indexmap", "prettyplease", - "syn 2.0.113", + "syn 2.0.117", "wasm-metadata", "wit-bindgen-core", "wit-component", @@ -4168,7 +4168,7 @@ dependencies = [ "prettyplease", 
"proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "wit-bindgen-core", "wit-bindgen-rust", ] @@ -4235,7 +4235,7 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "synstructure", ] @@ -4256,7 +4256,7 @@ checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] @@ -4276,7 +4276,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", "synstructure", ] @@ -4316,7 +4316,7 @@ checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.113", + "syn 2.0.117", ] [[package]] diff --git a/diskann-garnet/Cargo.toml b/diskann-garnet/Cargo.toml index 339539531..e0547d11c 100644 --- a/diskann-garnet/Cargo.toml +++ b/diskann-garnet/Cargo.toml @@ -14,10 +14,11 @@ bytemuck.workspace = true crossbeam = "0.8.4" dashmap = { workspace = true, features = ["inline"] } diskann.workspace = true -diskann-quantization.workspace = true +diskann-quantization = { workspace = true, features = ["flatbuffers"] } diskann-providers.workspace = true +diskann-utils.workspace = true diskann-vector.workspace = true foldhash = "0.2.0" +rand.workspace = true thiserror.workspace = true -tokio.workspace = true -diskann-utils.workspace = true +tokio = { workspace = true, features = ["sync"] } diff --git a/diskann-garnet/diskann-garnet.nuspec b/diskann-garnet/diskann-garnet.nuspec index 08ce3d764..e77e73275 100644 --- a/diskann-garnet/diskann-garnet.nuspec +++ b/diskann-garnet/diskann-garnet.nuspec @@ -2,7 +2,7 @@ diskann-garnet - 1.0.27 + 1.0.28 docs/README.md Microsoft https://github.com/microsoft/DiskANN diff --git a/diskann-garnet/src/dyn_index.rs b/diskann-garnet/src/dyn_index.rs index d58e3daaa..f92a5d350 100644 --- 
a/diskann-garnet/src/dyn_index.rs +++ b/diskann-garnet/src/dyn_index.rs @@ -7,7 +7,7 @@ use crate::{ SearchResults, garnet::{Context, GarnetId}, labels::GarnetQueryLabelProvider, - provider::GarnetProvider, + provider::{DynamicQuantization, GarnetProvider}, }; use diskann::{ ANNResult, @@ -16,14 +16,13 @@ use diskann::{ utils::VectorRepr, }; use diskann_providers::{ - index::wrapped_async::DiskANNIndex, - model::graph::provider::{async_::common::FullPrecision, layers::BetaFilter}, + index::wrapped_async::DiskANNIndex, model::graph::provider::layers::BetaFilter, }; use std::sync::Arc; /// Type-erased version of `DiskANNIndex`. /// All vector data is passed as untyped byte slices. -pub trait DynIndex: Send + Sync { +pub(crate) trait DynIndex: Send + Sync { fn insert(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()>; fn set_attributes(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()>; @@ -55,6 +54,10 @@ pub trait DynIndex: Send + Sync { fn internal_id_exists(&self, context: &Context, id: u32) -> bool; fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool; + + fn train_quantizer(&self, context: &Context) -> bool; + + fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize); } impl DynIndex for DiskANNIndex> { @@ -63,7 +66,7 @@ impl DynIndex for DiskANNIndex> { /// The data slice here must be aligned to `T` or this will panic. 
fn insert(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()> { self.insert( - FullPrecision, + DynamicQuantization, context, id, bytemuck::cast_slice::(data), @@ -87,10 +90,10 @@ impl DynIndex for DiskANNIndex> { ) -> ANNResult { let query = bytemuck::cast_slice::(data); if let Some((labels, beta)) = filter { - let beta_filter = BetaFilter::new(FullPrecision, Arc::new(labels.clone()), beta); + let beta_filter = BetaFilter::new(DynamicQuantization, Arc::new(labels.clone()), beta); self.search(*params, &beta_filter, context, query, output) } else { - self.search(*params, &FullPrecision, context, query, output) + self.search(*params, &DynamicQuantization, context, query, output) } } @@ -104,14 +107,14 @@ impl DynIndex for DiskANNIndex> { ) -> ANNResult { // Look up internal ID let iid = self.inner.provider().to_internal_id(context, id)?; - let data = self.inner.provider().get_vector(context, iid)?; + let data = self.inner.provider().get_full_vector(context, iid)?; let data_bytes = bytemuck::cast_slice::(&data); self.search_vector(context, data_bytes, params, filter, output) } fn remove(&self, context: &Context, id: &GarnetId) -> ANNResult<()> { self.inplace_delete( - FullPrecision, + DynamicQuantization, context, id, 3, @@ -137,4 +140,14 @@ impl DynIndex for DiskANNIndex> { fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool { self.inner.provider().vector_id_exists(context, id) } + + fn train_quantizer(&self, context: &Context) -> bool { + self.inner.provider().train_quantizer(context) + } + + fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) { + self.inner + .provider() + .backfill_quant_vectors(context, task_idx, task_count); + } } diff --git a/diskann-garnet/src/ffi_recall_tests.rs b/diskann-garnet/src/ffi_recall_tests.rs index fdfaac349..b15178346 100644 --- a/diskann-garnet/src/ffi_recall_tests.rs +++ b/diskann-garnet/src/ffi_recall_tests.rs @@ -10,8 +10,8 @@ mod tests { use 
diskann_vector::distance::{Cosine, Metric, SquaredL2}; use crate::{ - VectorQuantType, VectorValueType, create_index, drop_index, garnet::Context, insert, - search_vector, test_utils::Store, + VectorQuantType, create_index, drop_index, garnet::Context, insert, search_vector, + test_utils::Store, }; /// Helper to insert a vector with a string external ID and FP32 data. @@ -25,16 +25,15 @@ mod tests { let vector_bytes: &[u8] = bytemuck::cast_slice(vector); unsafe { insert( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector.len(), b"".as_ptr(), 0, - ) + ) > 0 } } @@ -156,14 +155,14 @@ mod tests { ) -> f64 { store.clear(); let callbacks = store.callbacks(); - let ctx = Context(0); + let ctx = Context::new(0); let reduce_dimensions = 0; let l_build = 100; let max_degree = 32; let index_ptr = unsafe { create_index( - ctx.0, + ctx.get(), dimensions, reduce_dimensions, VectorQuantType::NoQuant, @@ -204,9 +203,8 @@ mod tests { let count = unsafe { search_vector( - ctx.0, + ctx.get(), index_ptr, - VectorValueType::FP32, query_bytes.as_ptr(), vec.len(), delta, @@ -237,7 +235,7 @@ mod tests { total_expected += expected_ids.len(); } - unsafe { drop_index(ctx.0, index_ptr) }; + unsafe { drop_index(ctx.get(), index_ptr) }; total_matches as f64 / total_expected as f64 } diff --git a/diskann-garnet/src/ffi_tests.rs b/diskann-garnet/src/ffi_tests.rs index 6fc2022a4..2aadb92d6 100644 --- a/diskann-garnet/src/ffi_tests.rs +++ b/diskann-garnet/src/ffi_tests.rs @@ -9,8 +9,8 @@ mod tests { use diskann_vector::distance::Metric; use crate::{ - VectorQuantType, VectorValueType, card, check_external_id_valid, create_index, drop_index, - garnet::Context, insert, remove, search_vector, set_attribute, test_utils::Store, + VectorQuantType, card, check_external_id_valid, create_index, drop_index, garnet::Context, + insert, remove, search_vector, set_attribute, test_utils::Store, }; /// Creates an index with default test 
values and returns (index_ptr, Context). @@ -30,7 +30,7 @@ mod tests { store.clear(); let callbacks = store.callbacks(); - let ctx = Context(0); + let ctx = Context::new(0); let dim: u32 = 2; let reduce_dim = 0; @@ -40,7 +40,7 @@ mod tests { let index_ptr = unsafe { create_index( - ctx.0, + ctx.get(), dim, reduce_dim, quant_type, @@ -64,7 +64,7 @@ mod tests { assert!(!index_ptr.is_null()); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -105,7 +105,7 @@ mod tests { valid_metric ); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } } @@ -126,13 +126,12 @@ mod tests { let attributes_bytes = b"wololo"; let attributes_len = attributes_bytes.len(); - let result: bool = unsafe { + let result = unsafe { insert( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector_len, attributes_bytes.as_ptr(), @@ -140,25 +139,26 @@ mod tests { ) }; - assert!(result); + assert!(result > 0); // Confirm vector exists using FFI function - let exists = - unsafe { check_external_id_valid(ctx.0, index_ptr, id_bytes.as_ptr(), id_bytes.len()) }; + let exists = unsafe { + check_external_id_valid(ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len()) + }; assert!(exists); - let mut cardinality = unsafe { card(ctx.0, index_ptr) }; + let mut cardinality = unsafe { card(ctx.get(), index_ptr) }; assert_eq!(cardinality, 1); - let removed = unsafe { remove(ctx.0, index_ptr, id_bytes.as_ptr(), id_bytes.len()) }; + let removed = unsafe { remove(ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len()) }; assert!(removed); // Currently we're not tracking deletions for cardinality, so it should still be 1 - cardinality = unsafe { card(ctx.0, index_ptr) }; + cardinality = unsafe { card(ctx.get(), index_ptr) }; assert_eq!(cardinality, 1); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -175,7 +175,7 @@ mod tests { let attributes1 = 
b"wololo"; let set_attribute_result1 = unsafe { set_attribute( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -188,26 +188,25 @@ mod tests { let vector: [f32; 2] = [1.0, 2.0]; let vector_bytes = bytemuck::cast_slice(&vector); - let result: bool = unsafe { + let result = unsafe { insert( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector_bytes.len() / 4, attributes1.as_ptr(), attributes1.len(), ) }; - assert!(result); + assert!(result > 0); // Set attributes after insertion let attributes2 = b"new_attributes"; let set_attribute_result2 = unsafe { set_attribute( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -220,7 +219,7 @@ mod tests { // Set attributes to empty using null ptr let set_attribute_result3 = unsafe { set_attribute( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -233,7 +232,7 @@ mod tests { let empty_attribute = b""; let set_attribute_result4 = unsafe { set_attribute( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), @@ -244,7 +243,7 @@ mod tests { assert!(set_attribute_result4); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -269,72 +268,71 @@ mod tests { // Check external_id exists with EID 1 (should not exist initially) let exists1 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(!exists1, "EID 1 should not exist initially"); // Add vector with EID 1 let insert_result1 = unsafe { insert( - ctx.0, + ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector_len, attributes_bytes.as_ptr(), attributes_bytes.len(), ) }; - assert!(insert_result1, "Insert with EID 1 should succeed"); + assert!(insert_result1 > 0, "Insert with EID 1 should succeed"); // Check 
external_id exists with EID 1 (should exist after insert) let exists2 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(exists2, "EID 1 should exist after insert"); // Remove vector with EID 1 - let removed = unsafe { remove(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; + let removed = + unsafe { remove(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(removed, "Remove with EID 1 should succeed"); // Check external_id exists with EID 1 (should not exist after removal) let exists3 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(!exists3, "EID 1 should not exist after removal"); // Add vector with EID 2 let insert_result2 = unsafe { insert( - ctx.0, + ctx.get(), index_ptr, eid2_bytes.as_ptr(), eid2_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector_len, attributes_bytes.as_ptr(), attributes_bytes.len(), ) }; - assert!(insert_result2, "Insert with EID 2 should succeed"); + assert!(insert_result2 > 0, "Insert with EID 2 should succeed"); // Check external_id exists with EID 2 (should exist after insert) let exists4 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid2_bytes.as_ptr(), eid2_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid2_bytes.as_ptr(), eid2_bytes.len()) }; assert!(exists4, "EID 2 should exist after insert"); // Check external_id exists with EID 1 (should still not exist) let exists5 = unsafe { - check_external_id_valid(ctx.0, index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) + check_external_id_valid(ctx.get(), index_ptr, eid1_bytes.as_ptr(), eid1_bytes.len()) }; assert!(!exists5, "EID 1 should still not exist"); unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ 
-345,51 +343,52 @@ mod tests { let (index_ptr, ctx) = create_test_index(&store); let id1 = 1234u64; - let v1 = &[1u8, 0u8]; + let v1 = &[1.0f32, 0.0]; let id2 = 5678u64; - let v2 = &[0u8, 1u8]; + let v2 = &[0.0f32, 1.0]; let id1_bytes = bytemuck::bytes_of(&id1); let id2_bytes = bytemuck::bytes_of(&id2); - assert!(unsafe { - insert( - ctx.0, - index_ptr, - id1_bytes.as_ptr(), - id1_bytes.len(), - VectorValueType::XB8, - v1.as_ptr(), - v1.len(), - b"".as_ptr(), - 0, - ) - }); + assert!( + unsafe { + insert( + ctx.get(), + index_ptr, + id1_bytes.as_ptr(), + id1_bytes.len(), + bytemuck::cast_slice::(v1).as_ptr(), + v1.len(), + b"".as_ptr(), + 0, + ) + } > 0 + ); - assert!(unsafe { - insert( - ctx.0, - index_ptr, - id2_bytes.as_ptr(), - id2_bytes.len(), - VectorValueType::XB8, - v2.as_ptr(), - v2.len(), - b"".as_ptr(), - 0, - ) - }); + assert!( + unsafe { + insert( + ctx.get(), + index_ptr, + id2_bytes.as_ptr(), + id2_bytes.len(), + bytemuck::cast_slice::(v2).as_ptr(), + v2.len(), + b"".as_ptr(), + 0, + ) + } > 0 + ); - let qv = &[0u8, 0u8]; + let qv = &[0.0f32, 0.0]; let mut output_id_buffer = vec![0u8; 2 * (mem::size_of::() + mem::size_of::())]; let mut output_dists = vec![0f32; 2]; let count = unsafe { search_vector( - ctx.0, + ctx.get(), index_ptr, - VectorValueType::XB8, - qv.as_ptr(), + bytemuck::cast_slice::(qv).as_ptr(), qv.len(), 2.0, 10, @@ -442,7 +441,7 @@ mod tests { } unsafe { - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -457,16 +456,15 @@ mod tests { let vector_bytes = bytemuck::cast_slice(vector); unsafe { insert( - ctx.0, + ctx.get(), index_ptr, id_bytes.as_ptr(), id_bytes.len(), - VectorValueType::FP32, vector_bytes.as_ptr(), vector.len(), b"".as_ptr(), 0, - ) + ) > 0 } } @@ -489,9 +487,8 @@ mod tests { let count = unsafe { search_vector( - ctx.0, + ctx.get(), index_ptr, - VectorValueType::FP32, query_bytes.as_ptr(), query.len(), 0.0, @@ -543,7 +540,7 @@ mod tests { // Closest to [1,0] should be id=10 (exact match) 
assert_eq!(ids[0], 10); - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -566,7 +563,7 @@ mod tests { assert!(ids.len() >= 2, "should return at least 2 matching vectors"); assert_eq!(ids[0], 10, "closest should still be id=10"); - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -595,7 +592,7 @@ mod tests { "filtered vector EID 20 should be in results" ); - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } @@ -617,7 +614,7 @@ mod tests { assert_eq!(ids_null, ids_empty); assert_eq!(dists_null, dists_empty); - drop_index(ctx.0, index_ptr); + drop_index(ctx.get(), index_ptr); } } } diff --git a/diskann-garnet/src/fsm.rs b/diskann-garnet/src/fsm.rs index 37096278f..fd6d82834 100644 --- a/diskann-garnet/src/fsm.rs +++ b/diskann-garnet/src/fsm.rs @@ -26,7 +26,7 @@ use crossbeam::queue::ArrayQueue; use std::sync::{ RwLock, - atomic::{AtomicBool, AtomicU32, Ordering}, + atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}, }; use thiserror::Error; @@ -39,27 +39,40 @@ const FAST_SIZE: usize = 1024; const FSM_KEY_PREFIX: u32 = u32::from_be_bytes(*b"_fsm"); #[derive(Debug, Error, PartialEq)] -pub enum FsmError { +pub(crate) enum FsmError { #[error("Garnet operation failed")] Garnet(#[from] GarnetError), #[error("requested ID is out of range {0}")] IdOutOfRange(u32), } -pub struct FreeSpaceMap { +pub(crate) struct FreeSpaceMap { + /// Garnet callbacks for reading/writing FSM keys callbacks: Callbacks, + /// A flag to signal whether there are free IDs in the FSM. + /// This is set after a scan of the FSM, and is used to prevent extraneous reads + /// of FSM blocks. 
has_free_ids: AtomicBool, + /// A queue of previously deleted IDs to prevent excessive reads to the FSM fast_free_list: ArrayQueue, + /// The maximum block ID stored in the FSM max_block: RwLock, + /// The next ID that will be minted if reusing previously deleted IDs is unavailable next_id: AtomicU32, + /// The total number of IDs marked used in the FSM + total_used: AtomicUsize, + /// Lock that prevents reuse of previously used IDs. + /// This is used to disable ID reuse during quantization backfill. + reuse_lock: AtomicBool, } impl FreeSpaceMap { - pub fn new(ctx: Context, callbacks: Callbacks) -> Result { + pub(crate) fn new(ctx: &Context, callbacks: Callbacks) -> Result { let has_free_ids = AtomicBool::new(false); let fast_free_list = ArrayQueue::new(FAST_SIZE); let max_block = RwLock::new(u32::MAX); let next_id = AtomicU32::new(0); + let total_used = AtomicUsize::new(0); let mut this = Self { callbacks, @@ -67,13 +80,15 @@ impl FreeSpaceMap { fast_free_list, max_block, next_id, + total_used, + reuse_lock: AtomicBool::new(true), }; // Attempt to load state from Garnet. let block_key = Self::block_key(0); if this .callbacks - .exists_wid(ctx.term(Term::Metadata), block_key) + .exists_wid(&ctx.term(Term::Metadata), block_key) { this.load_state(ctx)?; } else { @@ -86,24 +101,25 @@ impl FreeSpaceMap { } /// Load all state from Garnet by scanning the FSM blocks.
- fn load_state(&mut self, ctx: Context) -> Result<(), FsmError> { + fn load_state(&mut self, ctx: &Context) -> Result<(), FsmError> { let mut max_block_id = 0; while self .callbacks - .exists_wid(ctx.term(Term::Metadata), Self::block_key(max_block_id)) + .exists_wid(&ctx.term(Term::Metadata), Self::block_key(max_block_id)) { max_block_id += 1; } let mut block = vec![0u8; BLOCK_SIZE_BYTES]; let mut last_used_id = -1i64; + let mut total_used = 0usize; for block_id in (0..max_block_id).rev() { let block_key = Self::block_key(block_id); if !self .callbacks - .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + .read_single_wid(&ctx.term(Term::Metadata), block_key, &mut block) { break; } @@ -115,8 +131,9 @@ impl FreeSpaceMap { let used = bit_used(byte, bidx); if used { last_used_id = last_used_id.max(id as i64); - } else if (id as i64) < last_used_id && self.fast_free_list.push(id).is_err() { - break; + total_used += 1; + } else if (id as i64) < last_used_id { + let _ = self.fast_free_list.push(id); } id = id.saturating_sub(1); @@ -130,6 +147,8 @@ impl FreeSpaceMap { self.next_id .store((last_used_id + 1) as u32, Ordering::Release); + self.total_used.store(total_used, Ordering::Release); + if !self.fast_free_list.is_empty() { self.has_free_ids.store(true, Ordering::Release); } @@ -138,20 +157,20 @@ impl FreeSpaceMap { } /// Mark an ID as free. - pub fn mark_free(&self, ctx: Context, id: u32) -> Result<(), FsmError> { + pub(crate) fn mark_free(&self, ctx: &Context, id: u32) -> Result<(), FsmError> { // We don't care about the changed status on free. self.mark_id(ctx, id, false).map(|_| ()) } /// Mark an ID as occupied. - fn mark_used(&self, ctx: Context, id: u32) -> Result { + fn mark_used(&self, ctx: &Context, id: u32) -> Result { self.mark_id(ctx, id, true) } /// Mark an ID according to value (true = used, false = free). /// Side effects only happen if the value actually changed. The return value reflects whether /// data changed or not. 
- fn mark_id(&self, ctx: Context, id: u32, used: bool) -> Result { + fn mark_id(&self, ctx: &Context, id: u32, used: bool) -> Result { if id >= self.next_id.load(Ordering::Acquire) { return Err(FsmError::IdOutOfRange(id)); } @@ -166,7 +185,7 @@ impl FreeSpaceMap { let mut changed = false; if !self.callbacks.rmw_wid( - ctx.term(Term::Metadata), + &ctx.term(Term::Metadata), block_key, BLOCK_SIZE_BYTES, |data: &mut [u8]| changed = update_status(used, &mut data[byte_idx], bit_idx), @@ -174,6 +193,14 @@ impl FreeSpaceMap { return Err(FsmError::Garnet(GarnetError::Write)); } + if changed { + if used { + self.total_used.fetch_add(1, Ordering::AcqRel); + } else { + self.total_used.fetch_sub(1, Ordering::AcqRel); + } + } + // NOTE: We don't modify the free list if the id was already free. if !used && changed { // Push the id onto the fast free list. If the queue is full, ignore it. @@ -184,7 +211,7 @@ impl FreeSpaceMap { Ok(changed) } - pub fn is_free(&self, ctx: Context, id: u32) -> Result { + pub(crate) fn is_free(&self, ctx: &Context, id: u32) -> Result { if id >= self.next_id.load(Ordering::Acquire) { return Err(FsmError::IdOutOfRange(id)); } @@ -200,7 +227,7 @@ impl FreeSpaceMap { if !self .callbacks - .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + .read_single_wid(&ctx.term(Term::Metadata), block_key, &mut block) { return Err(FsmError::Garnet(GarnetError::Read)); } @@ -213,8 +240,8 @@ impl FreeSpaceMap { /// Return a a new ID. /// This may be a a fresh ID larger than all the others, or it may be a reused ID that /// previously belonged to a deleted element. The returned ID is marked as used. - pub fn next_id(&self, ctx: Context) -> Result { - if self.has_free_ids.load(Ordering::Acquire) { + pub(crate) fn next_id(&self, ctx: &Context) -> Result { + if self.can_reuse() && self.has_free_ids.load(Ordering::Acquire) { // We retry reusing a freed ID until there are none or we get one and marking it used // succeeds in changing the value. 
loop { @@ -266,10 +293,14 @@ impl FreeSpaceMap { Ok(id) } - pub fn max_id(&self) -> u32 { + pub(crate) fn max_id(&self) -> u32 { self.next_id.load(Ordering::Acquire).saturating_sub(1) } + pub(crate) fn total_used(&self) -> usize { + self.total_used.load(Ordering::Acquire) + } + /// Return the FSM block number, byte index, and bit index for a given ID. /// The block number is the block which stores this ID, the byte index is byte offset /// within the block which contains the status bits, and the bit index is the bit index @@ -289,7 +320,7 @@ impl FreeSpaceMap { } /// Scan the FSM blocks to fill up fast_free_list. - fn refill_fast_free_list(&self, ctx: Context) -> Result { + fn refill_fast_free_list(&self, ctx: &Context) -> Result { // NOTE: We take a write lock to prevent multiple refills happening simultaneously. #[allow(clippy::readonly_write_lock)] let max_block = self.max_block.write().unwrap(); @@ -301,23 +332,38 @@ impl FreeSpaceMap { let mut has_free_ids = false; let mut id = 0u32; - 'scan: for block_id in 0..*max_block { + let mut block = vec![0u8; BLOCK_SIZE_BYTES]; + 'scan: for block_id in 0..=*max_block { + if id >= self.next_id.load(Ordering::Acquire) { + // Don't look at IDs outside the current range. + break; + } + let block_key = Self::block_key(block_id); - let mut block = vec![0u8; BLOCK_SIZE_BYTES]; if !self .callbacks - .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + .read_single_wid(&ctx.term(Term::Metadata), block_key, &mut block) { return Err(FsmError::Garnet(GarnetError::Read)); } for &byte in &block { + if id >= self.next_id.load(Ordering::Acquire) { + // Don't look at IDs outside the current range. + break 'scan; + } + if byte == 0xff { id += 8; continue; } for bidx in 0..8 { + if id >= self.next_id.load(Ordering::Acquire) { + // Don't look at IDs outside the current range. 
+ break 'scan; + } + if !bit_used(byte, bidx) { has_free_ids = true; self.has_free_ids.store(true, Ordering::Release); @@ -338,7 +384,7 @@ impl FreeSpaceMap { } /// Ensure enough blocks exist in the FSM to hold `id`. - fn expand_to(&self, ctx: Context, id: u32) -> Result<(), FsmError> { + fn expand_to(&self, ctx: &Context, id: u32) -> Result<(), FsmError> { let (block_id, _, _) = self.indexes_for_id(id); let mut max_block = self.max_block.write().unwrap(); if *max_block == u32::MAX || block_id == *max_block + 1 { @@ -347,7 +393,7 @@ impl FreeSpaceMap { if !self .callbacks - .write_wid(ctx.term(Term::Metadata), block_key, &block_bytes) + .write_wid(&ctx.term(Term::Metadata), block_key, &block_bytes) { return Err(FsmError::Garnet(GarnetError::Write)); } @@ -363,6 +409,60 @@ impl FreeSpaceMap { Ok(()) } + + /// Visit each used id in the FSM, invoking f on each id. + pub(crate) fn visit_used(&self, ctx: &Context, mut f: F) -> Result<(), FsmError> + where + F: FnMut(u32) -> bool, + { + let max_block = { *self.max_block.read().unwrap() }; + let mut block = vec![0u8; BLOCK_SIZE_BYTES]; + let mut id = 0u32; + + for block_id in 0..max_block + 1 { + let block_key = Self::block_key(block_id); + if !self + .callbacks + .read_single_wid(&ctx.term(Term::Metadata), block_key, &mut block) + { + return Err(FsmError::Garnet(GarnetError::Read)); + } + + for &byte in &block { + if byte == 0x00 { + id += 8; + continue; + } + + for bidx in 0..8 { + if bit_used(byte, bidx) { + let keep_going = f(id); + if !keep_going { + return Ok(()); + } + } + id += 1; + } + } + } + + Ok(()) + } + + /// Returns whether previously deleted IDs may be reused. + fn can_reuse(&self) -> bool { + self.reuse_lock.load(Ordering::Acquire) + } + + /// Prevent the reuse of previously deleted IDs. + pub(crate) fn lock_reuse(&self) { + self.reuse_lock.store(false, Ordering::Release); + } + + /// Resume reuse of previously deleted IDs. 
+ pub(crate) fn unlock_reuse(&self) { + self.reuse_lock.store(true, Ordering::Release); + } } /// Return whether the `bidx`th bit is set in byte, where bits are labeled from left to right. @@ -394,9 +494,10 @@ mod tests { where F: FnMut(&[u8]), { + let ctx = Context::new(0); let key = FreeSpaceMap::block_key(block_id); let block = store - .get(Context(0).term(Term::Metadata).0, bytemuck::bytes_of(&key)) + .get(ctx.term(Term::Metadata).get(), bytemuck::bytes_of(&key)) .unwrap(); f(&block); } @@ -404,9 +505,10 @@ mod tests { #[test] fn create_fresh() { let store = Store; + let ctx = Context::new(0); // A fresh FSM should result in full block of all zeroes. - let _fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let _fsm = FreeSpaceMap::new(&ctx, store.callbacks()).unwrap(); verify_block(&store, 0, |block| { assert_eq!(block.len(), BLOCK_SIZE_BYTES); assert_eq!(block.iter().filter(|b| **b != 0).count(), 0); @@ -416,30 +518,32 @@ mod tests { #[test] fn basic_next_id() { let store = Store; + let ctx = Context::new(0); - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&ctx, store.callbacks()).unwrap(); // A fresh FSM should not have free IDs. assert!(!fsm.has_free_ids.load(std::sync::atomic::Ordering::Acquire)); // It should return sequentials IDs starting from 0. - assert_eq!(fsm.next_id(Context(0)).unwrap(), 0); - assert_eq!(fsm.next_id(Context(0)).unwrap(), 1); + assert_eq!(fsm.next_id(&ctx).unwrap(), 0); + assert_eq!(fsm.next_id(&ctx).unwrap(), 1); } #[test] fn basic_delete() { let store = Store; + let ctx = Context::new(0); - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&ctx, store.callbacks()).unwrap(); // Deleting a ID out of range should return an error. 
- let res = fsm.mark_free(Context(0), 0); + let res = fsm.mark_free(&ctx, 0); assert!(res.is_err()); assert_eq!(res.err().unwrap(), FsmError::IdOutOfRange(0)); for _ in 0u32..64 { - let _ = fsm.next_id(Context(0)).unwrap(); + let _ = fsm.next_id(&ctx).unwrap(); } verify_block(&store, 0, |block| { @@ -451,8 +555,8 @@ mod tests { } }); - fsm.mark_free(Context(0), 37).unwrap(); - fsm.mark_free(Context(0), 9).unwrap(); + fsm.mark_free(&ctx, 37).unwrap(); + fsm.mark_free(&ctx, 9).unwrap(); verify_block(&store, 0, |block| { assert_eq!(block[0], 0b11111111); @@ -474,21 +578,22 @@ mod tests { #[test] fn basic_id_reuse() { let store = Store; + let ctx = Context::new(0); - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&ctx, store.callbacks()).unwrap(); for _ in 0u32..64 { - let _ = fsm.next_id(Context(0)).unwrap(); + let _ = fsm.next_id(&ctx).unwrap(); } // Next ID should be 64 since none are free before that. - assert_eq!(fsm.next_id(Context(0)).unwrap(), 64); + assert_eq!(fsm.next_id(&ctx).unwrap(), 64); // After deleting an ID, it should be returned from `next_id()`. - fsm.mark_free(Context(0), 37).unwrap(); - assert_eq!(fsm.next_id(Context(0)).unwrap(), 37); + fsm.mark_free(&ctx, 37).unwrap(); + assert_eq!(fsm.next_id(&ctx).unwrap(), 37); // Once all free IDs are used, fresh ones should be returned. - assert_eq!(fsm.next_id(Context(0)).unwrap(), 65); + assert_eq!(fsm.next_id(&ctx).unwrap(), 65); // And has_free_ids should now be false. 
assert!(!fsm.has_free_ids.load(Ordering::Acquire)); } @@ -496,33 +601,35 @@ mod tests { #[test] fn basic_recovery() { let store = Store; + let ctx = Context::new(0); - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&ctx, store.callbacks()).unwrap(); for _ in 0u32..64 { - let _ = fsm.next_id(Context(0)).unwrap(); + let _ = fsm.next_id(&ctx).unwrap(); } - fsm.mark_free(Context(0), 37).unwrap(); + fsm.mark_free(&ctx, 37).unwrap(); // Loading FSM from store should recover all the state. - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&ctx, store.callbacks()).unwrap(); assert_eq!(*fsm.max_block.read().unwrap(), 0); assert_eq!(fsm.next_id.load(Ordering::Acquire), 64); assert!(fsm.has_free_ids.load(Ordering::Acquire)); assert_eq!(fsm.fast_free_list.len(), 1); - assert_eq!(fsm.next_id(Context(0)).unwrap(), 37); + assert_eq!(fsm.next_id(&ctx).unwrap(), 37); } #[test] fn dynamic_expansion() { let store = Store; + let ctx = Context::new(0); - let fsm = FreeSpaceMap::new(Context(0), store.callbacks()).unwrap(); + let fsm = FreeSpaceMap::new(&ctx, store.callbacks()).unwrap(); // Asking for more than BLOCK_SIZE_IDS will force another FSM block to be allocated. for _ in 0u32..BLOCK_SIZE_IDS as u32 + 1 { - let _ = fsm.next_id(Context(0)).unwrap(); + let _ = fsm.next_id(&ctx).unwrap(); } } } diff --git a/diskann-garnet/src/garnet.rs b/diskann-garnet/src/garnet.rs index 85ba0cce5..526cee264 100644 --- a/diskann-garnet/src/garnet.rs +++ b/diskann-garnet/src/garnet.rs @@ -3,17 +3,26 @@ * Licensed under the MIT license. */ -use std::{ffi::c_void, fmt, mem, ops::Deref, slice}; +use std::{ + ffi::c_void, + fmt, mem, + ops::Deref, + slice, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, +}; use diskann::provider::ExecutionContext; use thiserror::Error; /// Bitmask for extracting the Term bits from a Context. 
/// Must have enough bits to represent all Term variants (max value is 6, needs 3 bits). -pub const TERM_BITMASK: u64 = (1 << 3) - 1; +pub(crate) const TERM_BITMASK: u64 = (1 << 3) - 1; #[derive(Debug)] -pub enum Term { +pub(crate) enum Term { Vector = 0, Neighbors = 1, Quantized = 2, @@ -23,28 +32,62 @@ pub enum Term { ExtMap = 6, } -#[derive(Copy, Clone)] -pub struct Context(pub u64); +#[derive(Clone, Debug)] +pub(crate) struct Context { + inner: u64, + quantizer_ready: Arc, +} impl Context { - pub fn term(&self, kind: Term) -> Self { - Context(self.0 | (kind as u64 & TERM_BITMASK)) + pub(crate) fn new(inner: u64) -> Self { + Self { + inner, + quantizer_ready: Arc::new(AtomicBool::new(false)), + } + } + + #[cfg(test)] + pub(crate) fn get(&self) -> u64 { + self.inner + } + + pub(crate) fn term(&self, kind: Term) -> Self { + let Context { + inner, + quantizer_ready, + } = self; + let inner = *inner | (kind as u64 & TERM_BITMASK); + let quantizer_ready = quantizer_ready.clone(); + + Self { + inner, + quantizer_ready, + } + } + + pub(crate) fn quantizer_ready(&self) -> bool { + self.quantizer_ready.load(Ordering::Acquire) + } + + pub(crate) fn set_quantizer_ready(&self) { + self.quantizer_ready.store(true, Ordering::Release); } } impl ExecutionContext for Context {} -pub type ReadCallback = +pub(crate) type ReadCallback = unsafe extern "C" fn(u64, u32, *const u8, usize, ReadDataCallback, *mut c_void); -pub type WriteCallback = unsafe extern "C" fn(u64, *const u8, usize, *const u8, usize) -> bool; -pub type DeleteCallback = unsafe extern "C" fn(u64, *const u8, usize) -> bool; -pub type ReadModifyWriteCallback = +pub(crate) type WriteCallback = + unsafe extern "C" fn(u64, *const u8, usize, *const u8, usize) -> bool; +pub(crate) type DeleteCallback = unsafe extern "C" fn(u64, *const u8, usize) -> bool; +pub(crate) type ReadModifyWriteCallback = unsafe extern "C" fn(u64, *const u8, usize, usize, RmwDataCallback, *mut c_void) -> bool; -pub type ReadDataCallback = unsafe 
extern "C" fn(u32, *mut c_void, *const u8, usize); -pub type RmwDataCallback = unsafe extern "C" fn(*mut c_void, *mut u8, usize); +pub(crate) type ReadDataCallback = unsafe extern "C" fn(u32, *mut c_void, *const u8, usize); +pub(crate) type RmwDataCallback = unsafe extern "C" fn(*mut c_void, *mut u8, usize); #[derive(Copy, Clone)] -pub struct Callbacks { +pub(crate) struct Callbacks { read_callback: ReadCallback, write_callback: WriteCallback, delete_callback: DeleteCallback, @@ -52,7 +95,7 @@ pub struct Callbacks { } impl Callbacks { - pub fn new( + pub(crate) fn new( read_callback: ReadCallback, write_callback: WriteCallback, delete_callback: DeleteCallback, @@ -67,32 +110,33 @@ impl Callbacks { } #[cfg(test)] - pub fn read_callback(&self) -> ReadCallback { + pub(crate) fn read_callback(&self) -> ReadCallback { self.read_callback } #[cfg(test)] - pub fn write_callback(&self) -> WriteCallback { + pub(crate) fn write_callback(&self) -> WriteCallback { self.write_callback } #[cfg(test)] - pub fn delete_callback(&self) -> DeleteCallback { + pub(crate) fn delete_callback(&self) -> DeleteCallback { self.delete_callback } #[cfg(test)] - pub fn rmw_callback(&self) -> ReadModifyWriteCallback { + pub(crate) fn rmw_callback(&self) -> ReadModifyWriteCallback { self.rmw_callback } - pub fn exists_iid(&self, ctx: Context, id: u32) -> bool { + #[cfg(test)] + pub(crate) fn exists_iid(&self, ctx: &Context, id: u32) -> bool { let key = [4, id]; // SAFETY: Key bytes are preceded by 4 bytes of space. unsafe { self.exists_raw(ctx, bytemuck::bytes_of(&key)) } } - pub fn exists_wid(&self, ctx: Context, key: u64) -> bool { + pub(crate) fn exists_wid(&self, ctx: &Context, key: u64) -> bool { // NOTE: the length is bit-shifted so that we have a u32 in the lower half of the u64. 
let mut key = [8 << 32, key]; let key_bytes = bytemuck::bytes_of_mut(&mut key); @@ -100,7 +144,11 @@ impl Callbacks { unsafe { self.exists_raw(ctx, &key_bytes[4..]) } } - pub fn exists_eid(&self, ctx: Context, id: &GarnetId) -> bool { + #[expect( + dead_code, + reason = "currently unused, but may be needed in the future" + )] + pub(crate) fn exists_eid(&self, ctx: &Context, id: &GarnetId) -> bool { // SAFETY: GarnetId ensures there are 4 bytes preceding the key bytes. unsafe { self.exists_raw(ctx, id) } } @@ -109,7 +157,7 @@ impl Callbacks { /// /// NOTE: The key bytes must be preceded by 4 valid bytes that Garnet can write into. /// This invariant must be checked by the caller. - unsafe fn exists_raw(&self, ctx: Context, key: &[u8]) -> bool { + unsafe fn exists_raw(&self, ctx: &Context, key: &[u8]) -> bool { let mut called = false; let mut cb = |_, _: &[u8]| { called = true; @@ -117,7 +165,7 @@ impl Callbacks { unsafe { (self.read_callback)( - ctx.0, + ctx.inner, 1, key.as_ptr(), key.len(), @@ -130,9 +178,9 @@ impl Callbacks { } #[must_use] - pub fn read_single_iid( + pub(crate) fn read_single_iid( &self, - ctx: Context, + ctx: &Context, id: u32, value: &mut [D], ) -> bool { @@ -148,9 +196,9 @@ impl Callbacks { } #[must_use] - pub fn read_single_wid( + pub(crate) fn read_single_wid( &self, - ctx: Context, + ctx: &Context, key: u64, value: &mut [D], ) -> bool { @@ -168,9 +216,9 @@ impl Callbacks { } #[must_use] - pub fn read_single_eid( + pub(crate) fn read_single_eid( &self, - ctx: Context, + ctx: &Context, id: &GarnetId, value: &mut [D], ) -> bool { @@ -189,7 +237,7 @@ impl Callbacks { /// NOTE: The key bytes must be preceded by 4 valid bytes that Garnet can write into. /// This invariant must be checked by the caller. 
#[must_use] - unsafe fn read_single_raw(&self, ctx: Context, key: &[u8], value: &mut [u8]) -> bool { + unsafe fn read_single_raw(&self, ctx: &Context, key: &[u8], value: &mut [u8]) -> bool { let mut found = false; let mut cb = |_, data: &[u8]| { found = true; @@ -198,7 +246,7 @@ impl Callbacks { unsafe { (self.read_callback)( - ctx.0, + ctx.inner, 1, key.as_ptr(), key.len(), @@ -211,8 +259,12 @@ impl Callbacks { } // ids must be passed as 4-byte length prefixed u32s. so [4, I1_u32, 4, I2_u32, ...] - pub fn read_multi_lpiid<'a, F, T: bytemuck::Pod>(&self, ctx: Context, ids: &[u32], mut f: F) - where + pub(crate) fn read_multi_lpiid<'a, F, T: bytemuck::Pod>( + &self, + ctx: &Context, + ids: &[u32], + mut f: F, + ) where F: FnMut(u32, &'a [T]), { if ids.is_empty() { @@ -221,7 +273,7 @@ impl Callbacks { unsafe { (self.read_callback)( - ctx.0, + ctx.inner, ids.len() as u32 / 2, bytemuck::must_cast_slice::<_, u8>(ids).as_ptr(), mem::size_of_val(ids), @@ -236,7 +288,11 @@ impl Callbacks { /// This function allocations inside the read callback since it can't know the size /// of the value up front. #[must_use] - pub fn read_varsize_iid(&self, ctx: Context, id: u32) -> Option> { + pub(crate) fn read_varsize_iid( + &self, + ctx: &Context, + id: u32, + ) -> Option> { let key = [4, id]; let mut result = None; let mut cb = |_, data: &[u8]| { @@ -254,7 +310,7 @@ impl Callbacks { // SAFETY: Key bytes are preceded by 4 bytes of extra space. unsafe { (self.read_callback)( - ctx.0, + ctx.inner, 1, bytemuck::bytes_of(&key).as_ptr(), mem::size_of_val(&key), @@ -267,7 +323,7 @@ impl Callbacks { } #[must_use] - pub fn write_iid(&self, ctx: Context, id: u32, value: &[D]) -> bool { + pub(crate) fn write_iid(&self, ctx: &Context, id: u32, value: &[D]) -> bool { let key = [0, id]; // SAFETY: Key bytes are preceded by 4 bytes of extra space. 
unsafe { @@ -280,7 +336,7 @@ impl Callbacks { } #[must_use] - pub fn write_wid(&self, ctx: Context, key: u64, value: &[D]) -> bool { + pub(crate) fn write_wid(&self, ctx: &Context, key: u64, value: &[D]) -> bool { let key = [0, key]; // SAFETY: Key bytes are preceded by 8 bytes of extra space. unsafe { @@ -293,7 +349,12 @@ impl Callbacks { } #[must_use] - pub fn write_eid(&self, ctx: Context, id: &GarnetId, value: &[D]) -> bool { + pub(crate) fn write_eid( + &self, + ctx: &Context, + id: &GarnetId, + value: &[D], + ) -> bool { // SAFETY: GarnetId ensures there are 4 bytes preceding the key bytes. unsafe { self.write_raw(ctx, id, bytemuck::must_cast_slice::(value)) } } @@ -303,22 +364,22 @@ impl Callbacks { /// NOTE: The key bytes must be preceded by 4 valid bytes that Garnet can write into. /// This invariant must be checked by the caller. #[must_use] - unsafe fn write_raw(&self, ctx: Context, key: &[u8], value: &[u8]) -> bool { + unsafe fn write_raw(&self, ctx: &Context, key: &[u8], value: &[u8]) -> bool { let value_ptr = value.as_ptr(); let value_len = value.len(); - unsafe { (self.write_callback)(ctx.0, key.as_ptr(), key.len(), value_ptr, value_len) } + unsafe { (self.write_callback)(ctx.inner, key.as_ptr(), key.len(), value_ptr, value_len) } } #[must_use] - pub fn delete_iid(&self, ctx: Context, id: u32) -> bool { + pub(crate) fn delete_iid(&self, ctx: &Context, id: u32) -> bool { let key = [0, id]; - unsafe { (self.delete_callback)(ctx.0, bytemuck::bytes_of(&key[1]).as_ptr(), 4) } + unsafe { (self.delete_callback)(ctx.inner, bytemuck::bytes_of(&key[1]).as_ptr(), 4) } } #[must_use] - pub fn delete_eid(&self, ctx: Context, id: &GarnetId) -> bool { + pub(crate) fn delete_eid(&self, ctx: &Context, id: &GarnetId) -> bool { let id: &[u8] = id; - unsafe { (self.delete_callback)(ctx.0, id.as_ptr(), id.len()) } + unsafe { (self.delete_callback)(ctx.inner, id.as_ptr(), id.len()) } } /// Modify a value in Garnet by internal ID. 
@@ -328,7 +389,13 @@ impl Callbacks { /// /// `f` should not panic. #[must_use] - pub fn rmw_iid<'a, F, T>(&self, ctx: Context, id: u32, write_len: usize, mut f: F) -> bool + pub(crate) fn rmw_iid<'a, F, T>( + &self, + ctx: &Context, + id: u32, + write_len: usize, + mut f: F, + ) -> bool where F: FnMut(&'a mut [T]), T: bytemuck::Pod, @@ -357,7 +424,13 @@ impl Callbacks { /// /// `f` should not panic. #[must_use] - pub fn rmw_wid<'a, F, T>(&self, ctx: Context, key: u64, write_len: usize, mut f: F) -> bool + pub(crate) fn rmw_wid<'a, F, T>( + &self, + ctx: &Context, + key: u64, + write_len: usize, + mut f: F, + ) -> bool where F: FnMut(&'a mut [T]), T: bytemuck::Pod, @@ -389,13 +462,13 @@ impl Callbacks { /// /// `f` should not panic. #[must_use] - unsafe fn rmw_raw<'a, F>(&self, ctx: Context, key: &[u8], write_len: usize, mut f: F) -> bool + unsafe fn rmw_raw<'a, F>(&self, ctx: &Context, key: &[u8], write_len: usize, mut f: F) -> bool where F: FnMut(&'a mut [u8]), { unsafe { (self.rmw_callback)( - ctx.0, + ctx.inner, key.as_ptr(), key.len(), write_len, @@ -459,7 +532,7 @@ where } #[derive(Debug, Error, PartialEq)] -pub enum GarnetError { +pub(crate) enum GarnetError { #[error("garnet read failed")] Read, #[error("garnet write failed")] @@ -477,12 +550,12 @@ pub enum GarnetError { /// Dereferencing this type will return a slice to the actual ID bytes, without the padding, /// which makes this interchangeable in most respects with using a raw `Box<[u8]>`. #[derive(Clone, PartialEq)] -pub struct GarnetId { +pub(crate) struct GarnetId { inner: Box<[u8]>, } impl GarnetId { - pub fn as_prefixed_key_bytes(&self) -> &[u8] { + pub(crate) fn as_prefixed_key_bytes(&self) -> &[u8] { &self.inner } } diff --git a/diskann-garnet/src/labels.rs b/diskann-garnet/src/labels.rs index f569cb30d..8e0a0b482 100644 --- a/diskann-garnet/src/labels.rs +++ b/diskann-garnet/src/labels.rs @@ -36,7 +36,7 @@ use diskann::graph::index::QueryLabelProvider; /// this struct. 
In practice, the struct is created and dropped within a /// single FFI search call. #[derive(Debug)] -pub struct GarnetQueryLabelProvider { +pub(crate) struct GarnetQueryLabelProvider { data: *const u8, len: usize, } @@ -68,7 +68,7 @@ impl GarnetQueryLabelProvider { /// The caller must ensure `data` is valid for reads of `len` bytes /// and that the memory remains valid and unmodified for the lifetime /// of this struct. - pub unsafe fn from_raw(data: *const u8, len: usize) -> Self { + pub(crate) unsafe fn from_raw(data: *const u8, len: usize) -> Self { if data.is_null() || len == 0 { return Self { data: std::ptr::null(), @@ -80,7 +80,7 @@ impl GarnetQueryLabelProvider { /// Construct a `GarnetQueryLabelProvider` from a byte slice. #[allow(dead_code)] - pub fn from_bytes(bytes: &[u8]) -> Self { + pub(crate) fn from_bytes(bytes: &[u8]) -> Self { if bytes.is_empty() { Self { data: std::ptr::null(), diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index a66cd7ab0..1eea0d4f2 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -51,6 +51,7 @@ mod fsm; mod garnet; mod labels; mod provider; +mod quantization; #[cfg(test)] mod test_utils; @@ -71,7 +72,7 @@ impl From for IndexState { } } -pub struct Index { +pub(crate) struct Index { inner: Box, quant_type: VectorQuantType, state: AtomicUsize, @@ -92,7 +93,10 @@ pub enum VectorQuantType { NoQuant, Bin, Q8, - XPreQ8, + XNoQuantU8, + XNoQuantI8, + XBinI8, + XBinU8, } struct SearchResults<'a> { @@ -175,7 +179,14 @@ fn create_index_impl( callbacks: Callbacks, context: Context, ) -> Result, GarnetProviderError> { - let provider = GarnetProvider::::new(dim, metric_type, max_degree, callbacks, context)?; + let provider = GarnetProvider::::new( + dim, + quant_type, + metric_type, + max_degree, + callbacks, + &context, + )?; let state = if provider.start_points_exist() { AtomicUsize::new(IndexState::Ready as usize) } else { @@ -213,11 +224,12 @@ pub unsafe extern "C" fn create_index( }; let 
target_degree = (max_degree as f32 / GRAPH_SLACK_FACTOR) as usize; + let config = if let Ok(config) = config::Builder::new( target_degree, config::MaxDegree::Value(max_degree as usize), l_build as usize, - config::PruneKind::TriangleInequality, + metric_type.into(), ) .build() { @@ -226,11 +238,12 @@ pub unsafe extern "C" fn create_index( return ptr::null(); }; - let context = Context(ctx); + let context = Context::new(ctx); let callbacks = Callbacks::new(read_callback, write_callback, delete_callback, rmw_callback); match quant_type { - VectorQuantType::XPreQ8 => { + VectorQuantType::Invalid => ptr::null(), + VectorQuantType::XNoQuantU8 | VectorQuantType::XBinU8 => { if let Ok(index) = create_index_impl::( quant_type, config, @@ -245,7 +258,22 @@ pub unsafe extern "C" fn create_index( ptr::null() } } - VectorQuantType::NoQuant => { + VectorQuantType::XNoQuantI8 | VectorQuantType::XBinI8 => { + if let Ok(index) = create_index_impl::( + quant_type, + config, + dim as usize, + metric_type, + max_degree as usize, + callbacks, + context, + ) { + Arc::into_raw(index).cast::() + } else { + ptr::null() + } + } + VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => { if let Ok(index) = create_index_impl::( quant_type, config, @@ -260,7 +288,6 @@ pub unsafe extern "C" fn create_index( ptr::null() } } - _ => ptr::null(), } } @@ -303,40 +330,27 @@ impl<'a> From> for PolyCow<'a> { fn interpret_vector<'a>( quant_type: VectorQuantType, - vector_value_type: VectorValueType, vector_data: &'a *const u8, vector_len: usize, ) -> Option> { - let vector_len_bytes = match vector_value_type { - VectorValueType::FP32 => vector_len * 4, - VectorValueType::XB8 => vector_len, - VectorValueType::Invalid => return None, + let vector_len_bytes = match quant_type { + VectorQuantType::Invalid => return None, + VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => vector_len * 4, + VectorQuantType::XNoQuantU8 + | VectorQuantType::XNoQuantI8 + | 
VectorQuantType::XBinU8 + | VectorQuantType::XBinI8 => vector_len, }; let v = unsafe { slice::from_raw_parts(*vector_data, vector_len_bytes) }; - let v = match vector_value_type { - VectorValueType::Invalid => return None, - VectorValueType::FP32 => match quant_type { - VectorQuantType::XPreQ8 => { - let mut bp = if let Ok(bp) = Poly::broadcast(0u8, vector_len, AlignToEight) { - bp - } else { - return None; - }; - for (idx, e) in bp.iter_mut().enumerate() { - let el_size = mem::size_of::(); - *e = f32::from_le_bytes( - v[idx * el_size..(idx + 1) * el_size].try_into().unwrap(), - ) as u8; - } - PolyCow::from(bp) - } - VectorQuantType::NoQuant if v.as_ptr().align_offset(4) == 0 => { + let v = match quant_type { + VectorQuantType::Invalid => return None, + VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => { + if v.as_ptr().align_offset(mem::align_of::()) == 0 { // pointer is correctly aligned to interpret as f32 PolyCow::from(v) - } - VectorQuantType::NoQuant => { + } else { // need to copy f32 data as it is unaligned let mut fp = if let Ok(fp) = Poly::broadcast(0u8, vector_len_bytes, AlignToEight) { fp @@ -346,35 +360,32 @@ fn interpret_vector<'a>( fp.copy_from_slice(v); PolyCow::from(fp) } - _ => { - return None; - } - }, - VectorValueType::XB8 => match quant_type { - VectorQuantType::XPreQ8 => PolyCow::from(v), - VectorQuantType::NoQuant => { - let mut fp = if let Ok(p) = - Poly::broadcast(0u8, vector_len_bytes * mem::size_of::(), AlignToEight) - { - p - } else { - return None; - }; - for (fe, be) in bytemuck::cast_slice_mut::(&mut fp) - .iter_mut() - .zip(v) - { - *fe = *be as f32; - } - PolyCow::from(fp) - } - _ => return None, - }, + } + VectorQuantType::XNoQuantU8 + | VectorQuantType::XNoQuantI8 + | VectorQuantType::XBinU8 + | VectorQuantType::XBinI8 => PolyCow::from(v), }; Some(v) } +enum InsertResult { + Fail, + Success, + SuccessStartTraining, +} + +impl From for u8 { + fn from(value: InsertResult) -> Self { + match value { + 
InsertResult::Fail => 0, + InsertResult::Success => 1, + InsertResult::SuccessStartTraining => 2, + } + } +} + /// # Safety /// /// FFI @@ -384,33 +395,27 @@ pub unsafe extern "C" fn insert( index_ptr: *const c_void, id_data: *const u8, id_len: usize, - vector_value_type: VectorValueType, vector_data: *const u8, vector_len: usize, attribute_data: *const u8, attribute_len: usize, -) -> bool { +) -> u8 { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); - let v = if let Some(v) = interpret_vector( - index.quant_type, - vector_value_type, - &vector_data, - vector_len, - ) { + let v = if let Some(v) = interpret_vector(index.quant_type, &vector_data, vector_len) { v } else { - return false; + return InsertResult::Fail.into(); }; if let Some(_err) = ensure_index_ready_or_init(index, || index.inner.maybe_set_start_point(&ctx, &v).err()) { - return false; + return InsertResult::Fail.into(); }; // Write attributes to garnet @@ -420,11 +425,22 @@ pub unsafe extern "C" fn insert( &[] }; if index.inner.set_attributes(&ctx, &id, attr_data).is_err() { - return false; + return InsertResult::Fail.into(); } + let old_ready = ctx.quantizer_ready(); + // Insert the vector - index.inner.insert(&ctx, &id, &v).is_ok() + if index.inner.insert(&ctx, &id, &v).is_ok() { + let ready = ctx.quantizer_ready(); + if !old_ready && ready { + InsertResult::SuccessStartTraining.into() + } else { + InsertResult::Success.into() + } + } else { + InsertResult::Fail.into() + } } fn ensure_index_ready_or_init(index: &Index, init: F) -> Option @@ -466,6 +482,34 @@ where None } +/// # Safety +/// +/// FFI +#[unsafe(no_mangle)] +pub unsafe extern "C" fn build_quant_table(context: u64, index_ptr: *const c_void) -> bool { + let index = unsafe { &*index_ptr.cast::() }; + let ctx = Context::new(context); + + index.inner.train_quantizer(&ctx) +} + +/// # Safety 
+/// +/// FFI +#[unsafe(no_mangle)] +pub unsafe extern "C" fn backfill_quant_vectors( + context: u64, + index_ptr: *const c_void, + task_index: usize, + task_count: usize, +) { + let index = unsafe { &*index_ptr.cast::() }; + let ctx = Context::new(context); + index + .inner + .backfill_quant_vectors(&ctx, task_index, task_count); +} + /// # Safety /// /// FFI @@ -479,7 +523,7 @@ pub unsafe extern "C" fn set_attribute( attribute_len: usize, ) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(context); + let ctx = Context::new(context); let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); @@ -505,7 +549,6 @@ pub unsafe extern "C" fn set_attribute( pub unsafe extern "C" fn search_vector( ctx: u64, index_ptr: *const c_void, - vector_value_type: VectorValueType, vector_data: *const u8, vector_len: usize, _delta: f32, @@ -521,18 +564,13 @@ pub unsafe extern "C" fn search_vector( ) -> i32 { let index = unsafe { &*index_ptr.cast::() }; - let v = if let Some(v) = interpret_vector( - index.quant_type, - vector_value_type, - &vector_data, - vector_len, - ) { + let v = if let Some(v) = interpret_vector(index.quant_type, &vector_data, vector_len) { v } else { return -1; }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let mut output = SearchResults::new( output_ids, @@ -597,7 +635,7 @@ pub unsafe extern "C" fn search_element( let index = unsafe { &*index_ptr.cast::() }; let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); - let ctx = Context(ctx); + let ctx = Context::new(ctx); let mut output = SearchResults::new( output_ids, @@ -665,7 +703,7 @@ pub unsafe extern "C" fn remove( id_len: usize, ) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); @@ -697,7 +735,7 @@ pub unsafe extern "C" fn 
check_internal_id_valid( internal_id_len: usize, ) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let internal_id_bytes = unsafe { slice::from_raw_parts(internal_id_data, internal_id_len) }; if internal_id_bytes.len() != mem::size_of::() { return false; @@ -720,7 +758,7 @@ pub unsafe extern "C" fn check_external_id_valid( id_len: usize, ) -> bool { let index = unsafe { &*index_ptr.cast::() }; - let ctx = Context(ctx); + let ctx = Context::new(ctx); let id_bytes = unsafe { slice::from_raw_parts(id_data, id_len) }; let id = GarnetId::from(id_bytes); diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index a0dbd7084..4089a847f 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -11,7 +11,7 @@ use diskann::{ config::defaults::MAX_OCCLUSION_SIZE, glue::{ self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchPostProcessStep, SearchStrategy, }, workingset::{self, map::Entry}, }, @@ -22,19 +22,34 @@ use diskann::{ }, utils::VectorRepr, }; -use diskann_providers::model::graph::provider::async_::common::FullPrecision; -use diskann_utils::Reborrow; -use diskann_utils::object_pool::{AsPooled, ObjectPool, PooledRef, Undef}; -use diskann_vector::{PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric}; +use diskann_quantization::alloc::{AllocatorError, Poly}; +use diskann_utils::{Reborrow, views::Matrix}; +use diskann_utils::{ + object_pool::{AsPooled, ObjectPool, PooledRef, Undef}, + views::MatrixView, +}; +use diskann_vector::{ + DistanceFunction, PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric, +}; use std::{ - future, mem, + any::TypeId, + future, + marker::PhantomData, + mem, ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, AtomicU64, Ordering}, + time::SystemTime, }; use 
thiserror::Error; use crate::{ + VectorQuantType, + alloc::AlignToEight, fsm::{FreeSpaceMap, FsmError}, garnet::{Callbacks, Context, GarnetError, GarnetId, Term}, + quantization::{ + self, DynDistanceComputer, DynQueryComputer, GarnetQuantizer, GarnetQuantizerError, + }, }; #[derive(Clone)] @@ -64,14 +79,60 @@ impl AsPooled for AdjList { } } +/// A type erased vector, properly aligned so that it can hold any size element +/// (up to 8 bytes long). +pub(crate) struct DynVector { + inner: Poly<[u8], AlignToEight>, + ty: PhantomData, +} + +impl DynVector { + fn new(inner: Poly<[u8], AlignToEight>) -> Self { + Self { + inner, + ty: PhantomData, + } + } +} + +impl Deref for DynVector { + type Target = Poly<[u8], AlignToEight>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for DynVector { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl<'a, T: VectorRepr> Reborrow<'a> for DynVector { + type Target = &'a [u8]; + + fn reborrow(&'a self) -> Self::Target { + &self.inner + } +} + #[derive(Debug, Error)] -pub enum GarnetProviderError { +pub(crate) enum GarnetProviderError { #[error("Garnet operation failed")] Garnet(#[from] GarnetError), #[error("FSM error")] Fsm(#[from] FsmError), #[error("Start point invalid")] StartPoint, + #[error("Allocation failed")] + AllocFailed(#[from] AllocatorError), + #[error("Invalid quantizer for vector data")] + InvalidQuantizer, + #[error("Quantizer error: {0}")] + Quantizer(#[from] GarnetQuantizerError), + #[error("Post processing error: {0}")] + PostProcessing(Box), } impl From for ANNError { @@ -83,25 +144,50 @@ impl From for ANNError { diskann::always_escalate!(GarnetProviderError); -pub struct GarnetProvider { +/// The Garnet DataProvider implementation. +pub(crate) struct GarnetProvider { + /// Dimension of the full precision vectors dim: usize, + /// Metric to use for comparing distances metric_type: Metric, + /// Maximum degree of the graph. 
+ /// Note: Unlike DiskANN, this is the true maximum. Neighbors can never exceed this degree. max_degree: usize, + /// Garnet storage engine callbacks callbacks: Callbacks, + /// The quantizer the index will use, or None if NOQUANT is used. + quantizer: Option>, + /// Tracks whether quantization backfill is complete. + all_quantized: AtomicBool, + /// Per job tracker for quantization backfill completion + backfills_completed: AtomicU64, + /// Tracks whether training has already started. + training_started: AtomicU64, + /// Pool of pre-allocated buffers to use for neighbor lists id_buffer_pool: ObjectPool, + /// Pool of pre-allocated buffers to use for filtering IDs filtered_ids_pool: ObjectPool>, + /// Pool of pre-allocated buffers to use for quantizing vectors + quant_buffer_pool: ObjectPool>, + /// Small cache for the start points' neighbors neighbor_cache: DashMap, foldhash::fast::RandomState>, - start_point_cache: DashMap, foldhash::fast::RandomState>, + /// Small cache for the start points' full precision vector data + start_point_cache: DashMap, foldhash::fast::RandomState>, + /// Small cache for the start points' quantized vector data + start_point_quant_cache: DashMap, foldhash::fast::RandomState>, + /// Free space map to track internal IDs fsm: FreeSpaceMap, + _phantom: PhantomData, } impl GarnetProvider { - pub fn new( + pub(crate) fn new( dim: usize, + quant_type: VectorQuantType, metric_type: Metric, max_degree: usize, callbacks: Callbacks, - context: Context, + context: &Context, ) -> Result { let parallelism = std::thread::available_parallelism().unwrap().get() * 2; let id_buffer_pool = @@ -114,14 +200,16 @@ impl GarnetProvider { let start_point_cache = DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); + let start_point_quant_cache = + DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); let neighbor_cache = DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); // Try to read 
the start point from Garnet - let mut v = vec![T::default(); dim]; - if callbacks.read_single_iid(context.term(Term::Vector), 0, &mut v) { + let mut v = Poly::broadcast(0u8, dim * mem::size_of::(), AlignToEight)?; + if callbacks.read_single_iid(&context.term(Term::Vector), 0, &mut v) { let mut neighbors = vec![0u32; max_degree + 1]; - if !callbacks.read_single_iid(context.term(Term::Neighbors), 0, &mut neighbors) { + if !callbacks.read_single_iid(&context.term(Term::Neighbors), 0, &mut neighbors) { return Err(GarnetError::Read.into()); } @@ -134,34 +222,67 @@ impl GarnetProvider { let fsm = FreeSpaceMap::new(context, callbacks)?; + let (quantizer, canonical_bytes, all_quantized) = match quant_type { + VectorQuantType::NoQuant + | VectorQuantType::XNoQuantU8 + | VectorQuantType::XNoQuantI8 => (None, 0, false), + VectorQuantType::Invalid => return Err(GarnetProviderError::InvalidQuantizer), + VectorQuantType::Q8 => { + if TypeId::of::() != TypeId::of::() { + return Err(GarnetProviderError::InvalidQuantizer); + } + + let quantizer = Box::new(quantization::MinMax8Bit::new(dim, metric_type)?) + as Box; + let canonical_bytes = quantizer.canonical_bytes(); + // NOTE: Q8 needs no training, so it always starts with backfill complete. 
+ (Some(quantizer), canonical_bytes, true) + } + VectorQuantType::Bin | VectorQuantType::XBinU8 | VectorQuantType::XBinI8 => { + let quantizer = + Box::new(quantization::Spherical1Bit::new(dim)) as Box; + let canonical_bytes = quantizer.canonical_bytes(); + (Some(quantizer), canonical_bytes, false) + } + }; + let quant_buffer_pool = + ObjectPool::new(Undef::new(canonical_bytes), parallelism, Some(parallelism)); + Ok(Self { dim, metric_type, max_degree, callbacks, + quantizer, + all_quantized: AtomicBool::new(all_quantized), + backfills_completed: AtomicU64::new(0), + training_started: AtomicU64::new(0), id_buffer_pool, filtered_ids_pool, + quant_buffer_pool, start_point_cache, + start_point_quant_cache, neighbor_cache, fsm, + _phantom: PhantomData, }) } - pub fn maybe_set_start_point( + pub(crate) fn maybe_set_start_point( &self, context: &Context, point: &[T], ) -> Result<(), GarnetProviderError> { - let mut v = vec![T::default(); self.dim]; + let mut v = Poly::broadcast(0u8, self.dim * mem::size_of::(), AlignToEight)?; if self .callbacks - .read_single_iid(context.term(Term::Vector), 0, &mut v) + .read_single_iid(&context.term(Term::Vector), 0, &mut v) { // Garnet already has a start point, so use that instead of `point` let mut neighbors = vec![0u32; self.max_degree + 1]; if !self .callbacks - .read_single_iid(context.term(Term::Neighbors), 0, &mut neighbors) + .read_single_iid(&context.term(Term::Neighbors), 0, &mut neighbors) { return Err(GarnetError::Read.into()); } @@ -174,27 +295,54 @@ impl GarnetProvider { let neighbors = vec![0u32; self.max_degree + 1]; // Grab the start point id, which must be zero. 
- let id = self.fsm.next_id(*context)?; + let id = self.fsm.next_id(context)?; if id != 0 { - self.fsm.mark_free(*context, id)?; + self.fsm.mark_free(context, id)?; return Err(GarnetProviderError::StartPoint); } if !self .callbacks - .write_iid(context.term(Term::Vector), 0, point) + .write_iid(&context.term(Term::Vector), 0, point) { return Err(GarnetError::Write.into()); } + if self.is_quantized() + && let Some(quantizer) = self.quantizer() + { + // We are already able to quantize, so store the quantized start point + + let mut qpoint = vec![0u8; quantizer.canonical_bytes()]; + let point_f32 = + T::as_f32(point).map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + quantizer.compress(&point_f32, &mut qpoint)?; + + if !self + .callbacks + .write_iid(&context.term(Term::Quantized), 0, &qpoint) + { + return Err(GarnetError::Write.into()); + } + + self.start_point_quant_cache + .insert(0, Poly::from_iter(qpoint.iter().copied(), AlignToEight)?); + } + if !self .callbacks - .write_iid(context.term(Term::Neighbors), 0, &neighbors) + .write_iid(&context.term(Term::Neighbors), 0, &neighbors) { return Err(GarnetError::Write.into()); } - self.start_point_cache.insert(0, point.to_vec()); + self.start_point_cache.insert( + 0, + Poly::from_iter( + bytemuck::cast_slice::(point).iter().copied(), + AlignToEight, + )?, + ); self.neighbor_cache .insert(0, Vec::with_capacity(self.max_degree + 1)); } @@ -202,15 +350,11 @@ impl GarnetProvider { Ok(()) } - pub fn start_points_exist(&self) -> bool { + pub(crate) fn start_points_exist(&self) -> bool { self.start_point_cache.get(&0).is_some() && self.neighbor_cache.get(&0).is_some() } - pub fn callbacks(&self) -> &Callbacks { - &self.callbacks - } - - pub fn set_attributes( + pub(crate) fn set_attributes( &self, context: &Context, id: &GarnetId, @@ -218,30 +362,234 @@ impl GarnetProvider { ) -> Result<(), GarnetProviderError> { if self .callbacks - .write_eid(context.term(Term::Attributes), id, data) + 
.write_eid(&context.term(Term::Attributes), id, data) { Ok(()) } else { Err(GarnetError::Write.into()) } } - pub fn vector_id_exists(&self, context: &Context, id: &GarnetId) -> bool { + pub(crate) fn vector_id_exists(&self, context: &Context, id: &GarnetId) -> bool { let iid = match self.to_internal_id(context, id) { Ok(iid) => iid, Err(_) => return false, }; - !self.fsm.is_free(*context, iid).unwrap_or(true) + !self.fsm.is_free(context, iid).unwrap_or(true) } - pub fn vector_iid_exists(&self, context: &Context, id: u32) -> bool { - !self.fsm.is_free(*context, id).unwrap_or(true) + pub(crate) fn vector_iid_exists(&self, context: &Context, id: u32) -> bool { + !self.fsm.is_free(context, id).unwrap_or(true) } - pub fn max_internal_id(&self) -> u32 { + pub(crate) fn max_internal_id(&self) -> u32 { self.fsm.max_id() } - pub(crate) fn get_vector( + /// Train the quantizer. + /// + /// This should only be called when at least `quantizer.required_vectors()` vectors exist in + /// the provider. This will build quantization tables, but does not quantize any vectors. + pub(crate) fn train_quantizer(&self, context: &Context) -> bool { + // Collect up to `self.quantizer.required_vectors()` vectors and use them for quantizer training. + + let current_time: u64 = + match SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH) { + Ok(t) => { + let t = t.as_millis().try_into().unwrap_or(0); + if t == 0 { + return false; + } + t + } + Err(_) => return false, + }; + + // Ensure we don't kick off training twice. 
+ match self.training_started.compare_exchange( + 0, + current_time, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => (), + Err(_) => return false, + } + + let quantizer = match &self.quantizer { + Some(q) => q, + None => { + self.training_started.store(0, Ordering::Release); + return false; + } + }; + + let rows = quantizer.required_vectors(); + let mut data = Matrix::::new(T::default(), rows, self.dim); + let mut row_idx = 0usize; + + if self + .fsm + .visit_used(context, |id| { + // Skip the start point. + if id == 0 { + return true; + } + + if row_idx >= rows { + return false; + } + + // Read the vector into the data matrix. + let row = data.row_mut(row_idx); + if !self + .callbacks + .read_single_iid(&context.term(Term::Vector), id, row) + { + return false; + } + + row_idx += 1; + + true + }) + .is_err() + { + // Training failed + self.training_started.store(0, Ordering::Release); + return false; + } + + if row_idx < quantizer.required_vectors() { + self.training_started.store(0, Ordering::Release); + return false; + } + + let view = if let Some(view) = data.subview(0..row_idx) { + view + } else { + return false; + }; + + // Train the quantizer. + let converted = match T::as_f32(view.as_slice()) { + Ok(v) => v, + Err(_) => { + self.training_started.store(0, Ordering::Release); + return false; + } + }; + let view = match MatrixView::try_from(&*converted, view.nrows(), view.ncols()) { + Ok(v) => v, + Err(_) => { + self.training_started.store(0, Ordering::Release); + return false; + } + }; + match quantizer.train(self.metric_type, view) { + Ok(()) => true, + Err(_e) => { + self.training_started.store(0, Ordering::Release); + false + } + } + } + + /// Bulk quantize previously inserted vectors. + /// + /// This function will be invoked on multiple threads. The total number of tasks and the ID of + /// the current task are given as inputs. 
+ pub(crate) fn backfill_quant_vectors( + &self, + context: &Context, + task_idx: usize, + task_count: usize, + ) { + let quantizer = match &self.quantizer { + Some(q) => q, + None => return, + }; + + let max_id = self.fsm.max_id() as usize; + + let task_count = task_count.min(max_id + 1); + if task_idx >= task_count { + return; + } + + let work_count = (max_id + 1).div_ceil(task_count); // will be >= 1 + let start_id = (work_count * task_idx) as u32; + let end_id = (work_count * (task_idx + 1)).min(max_id + 1) as u32; + + self.fsm.lock_reuse(); + + let mut v = vec![T::default(); self.dim]; + let mut f = vec![0f32; self.dim]; + let mut q = vec![0u8; quantizer.canonical_bytes()]; + for id in start_id..end_id { + if !self + .callbacks + .read_single_iid(&context.term(Term::Vector), id, &mut v) + { + continue; + } + + if T::as_f32_into(&v, &mut f).is_err() { + continue; + } + if quantizer.compress(&f, &mut q).is_err() { + continue; + }; + + if !self + .callbacks + .write_iid(&context.term(Term::Quantized), id, &q) + { + continue; + } + } + + let backfill_finished = + self.backfills_completed.fetch_add(1, Ordering::AcqRel) + 1 == task_count as u64; + + if backfill_finished { + // Finish by quantizing the start points + if let Some(v) = self.start_point_cache.get(&0) + && let Ok(v_f32) = T::as_f32(bytemuck::cast_slice::(&v)) + && quantizer.compress(&v_f32, &mut q).is_ok() + { + let _ = self + .callbacks + .write_iid(&context.term(Term::Quantized), 0, &q); + + // set the cache + let point = if let Ok(p) = Poly::from_iter(q.iter().copied(), AlignToEight) { + p + } else { + return; + }; + self.start_point_quant_cache.insert(0, point); + } + + self.fsm.unlock_reuse(); + self.all_quantized.store(true, Ordering::Release); + } + } + + /// Returns the quantizer associated with the index. 
+ fn quantizer(&self) -> Option<&dyn GarnetQuantizer> { + if let Some(quantizer) = &self.quantizer { + return Some(&**quantizer as &dyn GarnetQuantizer); + } + + None + } + + /// Returns quantization status. If this is true, the index is operating fully quantized. + pub(crate) fn is_quantized(&self) -> bool { + self.quantizer.is_some() && self.all_quantized.load(Ordering::Acquire) + } + + pub(crate) fn get_full_vector( &self, context: &Context, iid: u32, @@ -254,13 +602,13 @@ impl GarnetProvider { } else { return Err(GarnetError::Read.into()); }; - v.copy_from_slice(&guard); + v.copy_from_slice(bytemuck::cast_slice::(&guard)); return Ok(v); } if !self .callbacks - .read_single_iid(context.term(Term::Vector), iid, &mut v) + .read_single_iid(&context.term(Term::Vector), iid, &mut v) { return Err(GarnetError::Read.into()); } @@ -283,7 +631,7 @@ impl DataProvider for GarnetProvider { ) -> Result { let mut id = 0u32; if !self.callbacks.read_single_eid( - context.term(Term::IntMap), + &context.term(Term::IntMap), gid, bytemuck::bytes_of_mut(&mut id), ) { @@ -295,7 +643,7 @@ impl DataProvider for GarnetProvider { fn to_external_id(&self, context: &Context, id: u32) -> Result { match self .callbacks - .read_varsize_iid(context.term(Term::ExtMap), id) + .read_varsize_iid(&context.term(Term::ExtMap), id) { Some(eid) => Ok(eid.into()), None => Err(GarnetProviderError::Garnet(GarnetError::Read)), @@ -312,20 +660,43 @@ impl SetElement<&[T]> for GarnetProvider { id: &Self::ExternalId, element: &[T], ) -> Result { - let internal_id = self.fsm.next_id(*context)?; + let internal_id = self.fsm.next_id(context)?; + + // Set quantization readiness + if let Some(quantizer) = &self.quantizer + && !quantizer.is_prepared() + && self.fsm.total_used() > quantizer.required_vectors() + { + context.set_quantizer_ready(); + } let insert = || -> Result<(), Self::SetError> { self.callbacks - .write_iid(context.term(Term::Vector), internal_id, element) + .write_iid(&context.term(Term::Vector), 
internal_id, element) .then_some(()) .ok_or(GarnetError::Write)?; + if let Some(quantizer) = &self.quantizer + && quantizer.is_prepared() + { + let mut quant = self + .quant_buffer_pool + .get_ref(Undef::new(quantizer.canonical_bytes())); + let element_f32 = T::as_f32(element).map_err(|e| { + GarnetProviderError::Quantizer(GarnetQuantizerError::Compression(Box::new(e))) + })?; + quantizer.compress(&element_f32, &mut quant)?; + self.callbacks + .write_iid(&context.term(Term::Quantized), internal_id, &quant) + .then_some(()) + .ok_or(GarnetError::Write)?; + } self.callbacks - .write_iid(context.term(Term::ExtMap), internal_id, id) + .write_iid(&context.term(Term::ExtMap), internal_id, id) .then_some(()) .ok_or(GarnetError::Write)?; self.callbacks .write_eid( - context.term(Term::IntMap), + &context.term(Term::IntMap), id, bytemuck::bytes_of(&internal_id), ) @@ -337,7 +708,7 @@ impl SetElement<&[T]> for GarnetProvider { match insert() { Ok(()) => (), Err(e) => { - self.fsm.mark_free(*context, internal_id)?; + self.fsm.mark_free(context, internal_id)?; return Err(e); } } @@ -358,20 +729,23 @@ impl Delete for GarnetProvider { }; // Mark the ID free in the FSM. - if let Err(e) = self.fsm.mark_free(*context, id) { + if let Err(e) = self.fsm.mark_free(context, id) { return future::ready(Err(e.into())); }; // Delete all the data associated with the vector. let mut ok = true; - ok &= self.callbacks.delete_iid(context.term(Term::ExtMap), id); - ok &= self.callbacks.delete_eid(context.term(Term::IntMap), gid); + ok &= self.callbacks.delete_iid(&context.term(Term::ExtMap), id); + ok &= self.callbacks.delete_eid(&context.term(Term::IntMap), gid); ok &= self .callbacks - .delete_eid(context.term(Term::Attributes), gid); + .delete_eid(&context.term(Term::Attributes), gid); // NOTE: Commented out until DiskANN fixes accessing neighbor data post-delete. 
//ok &= self.callbacks.delete_iid(context.term(Term::Neighbors), id); - ok &= self.callbacks.delete_iid(context.term(Term::Vector), id); + ok &= self.callbacks.delete_iid(&context.term(Term::Vector), id); + ok &= self + .callbacks + .delete_iid(&context.term(Term::Quantized), id); if !ok { return future::ready(Err(GarnetError::Delete.into())); @@ -394,7 +768,7 @@ impl Delete for GarnetProvider { context: &Self::Context, id: Self::InternalId, ) -> impl Future> + Send { - let status = match self.fsm.is_free(*context, id) { + let status = match self.fsm.is_free(context, id) { Ok(true) => ElementStatus::Deleted, Ok(false) => ElementStatus::Valid, Err(e) => return future::ready(Err(e.into())), @@ -413,20 +787,25 @@ impl Delete for GarnetProvider { } } -#[allow(dead_code)] -pub struct FullAccessor<'a, T: VectorRepr> { +/// Dynamic accessor that seamlessly transitions from full precision vector based operation to +/// quantized-only operation. +#[derive(Copy, Clone, Debug)] +pub(crate) struct DynamicQuantization; + +pub(crate) struct DynamicAccessor<'a, T: VectorRepr> { provider: &'a GarnetProvider, context: &'a Context, - is_search: bool, + /// Whether this accessor should use quantized vectors + quantized: bool, id_buffer: PooledRef<'a, AdjList>, filtered_ids: PooledRef<'a, Vec>, } -impl<'a, T: VectorRepr> FullAccessor<'a, T> { +impl<'a, T: VectorRepr> DynamicAccessor<'a, T> { pub(crate) fn new( provider: &'a GarnetProvider, context: &'a Context, - is_search: bool, + quantized: bool, ) -> Self { let id_buffer = provider .id_buffer_pool @@ -434,10 +813,10 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { let filtered_ids = provider .filtered_ids_pool .get_ref(Undef::new(MAX_OCCLUSION_SIZE.get() as usize * 2)); // x2 to allow for the length prefixes for garnet - FullAccessor { + DynamicAccessor { provider, context, - is_search, + quantized, id_buffer, filtered_ids, } @@ -457,10 +836,11 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { } if 
!self.provider.callbacks.read_single_iid( - self.context.term(Term::Neighbors), + &self.context.term(Term::Neighbors), id, &mut guard, ) { + guard.finish(0); return false; } @@ -471,11 +851,11 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { } } -impl HasId for FullAccessor<'_, T> { +impl HasId for DynamicAccessor<'_, T> { type Id = u32; } -impl SearchExt for FullAccessor<'_, T> { +impl SearchExt for DynamicAccessor<'_, T> { fn starting_points(&self) -> impl Future>> + Send { let points = if self.provider.start_points_exist() { vec![0] @@ -493,7 +873,7 @@ impl SearchExt for FullAccessor<'_, T> { } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T> { +impl ExpandBeam<&[T]> for DynamicAccessor<'_, T> { fn expand_beam( &mut self, ids: Itr, @@ -517,14 +897,31 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { .filter(|id| pred.eval_mut(id)) { if id == 0 { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r + let dist = if self.quantized + && let Some(_quantizer) = self.provider.quantizer() + { + let guard = if let Some(r) = self.provider.start_point_quant_cache.get(&id) + { + r + } else { + return future::ready(Err(GarnetProviderError::Garnet( + GarnetError::Read, + ) + .into())); + }; + computer.evaluate_similarity(&*guard) } else { - return future::ready(Err( - GarnetProviderError::Garnet(GarnetError::Read).into() - )); + let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetProviderError::Garnet( + GarnetError::Read, + ) + .into())); + }; + computer.evaluate_similarity(&*guard) }; - let dist = computer.evaluate_similarity(&*guard); + on_neighbors(dist, id); } else { self.filtered_ids.push(4); @@ -532,59 +929,184 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { } } - self.provider.callbacks.read_multi_lpiid( - self.context.term(Term::Vector), - &self.filtered_ids, - |i, v| { - let dist = computer.evaluate_similarity(v); - on_neighbors(dist, self.filtered_ids[i as usize * 2 + 
1]); - }, - ); + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + + if !self.filtered_ids.is_empty() { + self.provider + .callbacks + .read_multi_lpiid(&ctx, &self.filtered_ids, |i, v| { + let dist = computer.evaluate_similarity(v); + on_neighbors(dist, self.filtered_ids[i as usize * 2 + 1]); + }); + } } future::ready(Ok(())) } } -impl Accessor for FullAccessor<'_, T> { +impl Accessor for DynamicAccessor<'_, T> { type Element<'a> - = Vec + = DynVector where Self: 'a; - type ElementRef<'a> = &'a [T]; + type ElementRef<'a> = &'a [u8]; type GetError = GarnetProviderError; fn get_element( &mut self, id: Self::Id, ) -> impl Future, Self::GetError>> + Send { - std::future::ready(self.provider.get_vector(self.context, id)) + let v_len = if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + quantizer.canonical_bytes() + } else { + self.provider.dim * mem::size_of::() + }; + + let mut v = match Poly::broadcast(0u8, v_len, AlignToEight) { + Ok(v) => DynVector::new(v), + Err(e) => return future::ready(Err(GarnetProviderError::AllocFailed(e))), + }; + + if id == 0 { + if self.quantized { + let guard = if let Some(r) = self.provider.start_point_quant_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetError::Read.into())); + }; + v.copy_from_slice(&guard); + return future::ready(Ok(v)); + } else { + let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetError::Read.into())); + }; + v.copy_from_slice(&guard); + return future::ready(Ok(v)); + } + } + + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + + if !self.provider.callbacks.read_single_iid(&ctx, id, &mut v) { + return future::ready(Err(GarnetError::Read.into())); + } + + future::ready(Ok(v)) + } +} + +/// Wrapper for full precision distance computer. 
+pub(crate) struct FullPrecisionDistance(T::Distance); + +impl DynDistanceComputer for FullPrecisionDistance { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + self.0.evaluate_similarity( + bytemuck::cast_slice::(a), + bytemuck::cast_slice::(b), + ) + } +} + +/// Wrapper for full precision query computer. +pub(crate) struct FullPrecisionQueryDistance(T::QueryDistance); + +impl DynQueryComputer for FullPrecisionQueryDistance { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + self.0.evaluate_similarity(bytemuck::cast_slice::(a)) + } +} + +/// Type-erased distance computer. +pub(crate) struct GarnetDistanceComputer { + inner: Box, +} + +impl GarnetDistanceComputer { + pub(crate) fn new(computer: T) -> Self { + Self { + inner: Box::new(computer), + } + } +} +impl DistanceFunction<&[u8], &[u8]> for GarnetDistanceComputer { + fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { + self.inner.evaluate_similarity(x, y) + } +} + +/// Type-erased query computer. +pub(crate) struct GarnetQueryComputer { + inner: Box, +} + +impl GarnetQueryComputer { + pub(crate) fn new(computer: T) -> Self { + Self { + inner: Box::new(computer), + } } } -impl BuildDistanceComputer for FullAccessor<'_, T> { - type DistanceComputer = T::Distance; +impl PreprocessedDistanceFunction<&[u8]> for GarnetQueryComputer { + fn evaluate_similarity(&self, changing: &[u8]) -> f32 { + self.inner.evaluate_similarity(changing) + } +} + +impl BuildDistanceComputer for DynamicAccessor<'_, T> { + type DistanceComputer = GarnetDistanceComputer; type DistanceComputerError = GarnetProviderError; fn build_distance_computer( &self, ) -> Result { - Ok(T::distance( - self.provider.metric_type, - Some(self.provider.dim), - )) + if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + Ok(quantizer.distance_computer()?) 
+ } else { + Ok(GarnetDistanceComputer::new(FullPrecisionDistance::( + T::distance(self.provider.metric_type, Some(self.provider.dim)), + ))) + } } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { - type QueryComputer = T::QueryDistance; +impl BuildQueryComputer<&[T]> for DynamicAccessor<'_, T> { + type QueryComputer = GarnetQueryComputer; type QueryComputerError = GarnetProviderError; fn build_query_computer( &self, from: &[T], ) -> Result { - Ok(T::query_distance(from, self.provider.metric_type)) + if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + let from_f32 = T::as_f32(from).map_err(|e| { + GarnetProviderError::Quantizer(GarnetQuantizerError::Compression(Box::new(e))) + })?; + Ok(quantizer + .query_computer(&from_f32) + .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?) + } else { + Ok(GarnetQueryComputer::new(FullPrecisionQueryDistance::( + T::query_distance(from, self.provider.metric_type), + ))) + } } } @@ -592,7 +1114,7 @@ impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { /// /// Without an `&[T]: Into>`, the blanket implementation for `workingset::Map` /// is not applicable, allowing customization of `Fill`. 
-pub struct Escape(Box<[T]>); +pub(crate) struct Escape(Box<[T]>); impl<'a, T> Reborrow<'a> for Escape { type Target = &'a [T]; @@ -601,61 +1123,110 @@ impl<'a, T> Reborrow<'a> for Escape { } } -type WorkingSet = workingset::Map>; -type WorkingSetView<'a, T> = workingset::map::View<'a, u32, Escape>; +pub(crate) struct WorkingSet { + map: workingset::Map>, + contains_unquantized: bool, +} + +impl WorkingSet { + pub(crate) fn new(capacity_type: workingset::map::Capacity, capacity: usize) -> Self { + Self { + map: workingset::map::Builder::new(capacity_type).build(capacity), + contains_unquantized: false, + } + } +} + +impl Deref for WorkingSet { + type Target = workingset::Map>; -impl workingset::Fill> for FullAccessor<'_, T> { + fn deref(&self) -> &Self::Target { + &self.map + } +} + +impl DerefMut for WorkingSet { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.map + } +} + +type WorkingSetView<'a> = workingset::map::View<'a, u32, Escape>; + +impl workingset::Fill for DynamicAccessor<'_, T> { type Error = GarnetProviderError; type View<'a> - = WorkingSetView<'a, T> + = WorkingSetView<'a> where Self: 'a; async fn fill<'a, Itr>( &'a mut self, - set: &'a mut WorkingSet, + set: &'a mut WorkingSet, itr: Itr, ) -> Result, Self::Error> where Itr: ExactSizeIterator + Clone + Send + Sync, Self: 'a, { + if !self.quantized { + // Mark this working set as having full vectors if we're not quantizing yet + set.contains_unquantized = true; + } else if set.contains_unquantized { + // Working set is polluted by full vectors, it must be cleared + set.clear(); + set.contains_unquantized = false; + } + // Evict items from the working set to make room if needed. 
set.prepare(itr.clone()); self.filtered_ids.clear(); for id in itr { if id == 0 { - if let Entry::Vacant(e) = set.entry(id) { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r + if self.quantized + && let Entry::Vacant(e) = set.entry(id) + { + if let Some(guard) = self.provider.start_point_quant_cache.get(&id) { + e.insert(Escape((&**guard).into())); } else { - return Err(GarnetError::Read.into()); - }; - e.insert(Escape((&**guard).into())); - } + return Err(GarnetProviderError::StartPoint); + } + } else if let Entry::Vacant(e) = set.entry(id) { + if let Some(guard) = self.provider.start_point_cache.get(&id) { + e.insert(Escape((&**guard).into())); + } else { + return Err(GarnetProviderError::StartPoint); + } + } else { + continue; + }; } else if !set.contains_key(&id) { self.filtered_ids.push(4); self.filtered_ids.push(id); } } + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + if !self.filtered_ids.is_empty() { - self.provider.callbacks.read_multi_lpiid( - self.context.term(Term::Vector), - &self.filtered_ids, - |id, v| { + self.provider + .callbacks + .read_multi_lpiid(&ctx, &self.filtered_ids, |id, v| { set.insert(self.filtered_ids[id as usize * 2 + 1], Escape(v.into())); - }, - ); + }); } Ok(set.view()) } } -pub struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut FullAccessor<'p, T>); +pub(crate) struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut DynamicAccessor<'p, T>); impl HasId for DelegateNeighborAccessor<'_, '_, T> { type Id = u32; @@ -675,7 +1246,7 @@ impl NeighborAccessor for DelegateNeighborAccessor<'_, '_, T> { } } -impl<'p, 'a, T: VectorRepr> DelegateNeighbor<'a> for FullAccessor<'p, T> { +impl<'p, 'a, T: VectorRepr> DelegateNeighbor<'a> for DynamicAccessor<'p, T> { type Delegate = DelegateNeighborAccessor<'p, 'a, T>; fn delegate_neighbor(&'a mut self) -> Self::Delegate { @@ -697,7 +1268,7 @@ impl NeighborAccessorMut for 
DelegateNeighborAccessor<'_, '_, T> .0 .provider .callbacks - .write_iid(self.0.context.term(Term::Neighbors), id, &guard) + .write_iid(&self.0.context.term(Term::Neighbors), id, &guard) { return future::ready(Err(GarnetProviderError::Garnet(GarnetError::Write).into())); } @@ -721,7 +1292,7 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> ) -> impl Future> + Send { let max_degree = self.0.provider.max_degree; if !self.0.provider.callbacks.rmw_iid( - self.0.context.term(Term::Neighbors), + &self.0.context.term(Term::Neighbors), id, (self.0.provider.max_degree + 1) * mem::size_of::(), |data: &mut [u32]| { @@ -759,23 +1330,23 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> /// A [`SearchPostProcess`] base object that copies each `Neighbor` to a `(ExternalId, f32)` pair /// and writes as many as possible to the output buffer. #[derive(Debug, Default, Clone, Copy)] -pub struct CopyExternalIds; +pub(crate) struct CopyExternalIds; -impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> +impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> for CopyExternalIds { type Error = GarnetProviderError; fn post_process( &self, - accessor: &mut FullAccessor<'a, T>, + accessor: &mut DynamicAccessor<'a, T>, _query: &[T], - _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send where - I: Iterator as HasId>::Id>> + Send, + I: Iterator as HasId>::Id>> + Send, B: SearchOutputBuffer + Send + ?Sized, { let initial = output.current_len(); @@ -789,41 +1360,124 @@ impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], Garn break; } } + let count = output.current_len() - initial; future::ready(Ok(count)) } } -impl SearchStrategy, &[T]> for FullPrecision { - type SearchAccessor<'a> = FullAccessor<'a, T>; +/// A [`SearchPostProcess`] base object that reranks quantized vectors by full precision distance. 
+#[derive(Debug, Default, Clone, Copy)] +pub(crate) struct Rerank; + +impl<'a, 'b, T: VectorRepr> SearchPostProcessStep, &'b [T], GarnetId> + for Rerank +{ + type Error + = GarnetProviderError + where + NextError: diskann::error::StandardError; + + type NextAccessor = DynamicAccessor<'a, T>; + + async fn post_process_step( + &self, + next: &Next, + accessor: &mut DynamicAccessor<'a, T>, + query: &'b [T], + computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + candidates: I, + output: &mut B, + ) -> Result> + where + I: Iterator as HasId>::Id>> + Send, + B: SearchOutputBuffer + Send + ?Sized, + Next: SearchPostProcess + Sync, + { + if !accessor.quantized { + // Skip reranking if the accessor is working with full precision + return next + .post_process(accessor, query, computer, candidates, output) + .await + .map_err(|e| GarnetProviderError::PostProcessing(Box::new(e))); + } + + let provider = &accessor.provider; + let f = T::distance(provider.metric_type, Some(provider.dim)); + let mut v = Poly::broadcast(0u8, provider.dim * mem::size_of::(), AlignToEight)?; + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if !provider.vector_iid_exists(accessor.context, n.id) { + None + } else if provider.callbacks.read_single_iid( + &accessor.context.term(Term::Vector), + n.id, + &mut v, + ) { + Some(( + n.id, + f.evaluate_similarity(query, bytemuck::cast_slice::(&v)), + )) + } else { + None + } + }) + .collect(); + + // Sort the full precision distances. 
+ reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + + next.post_process( + accessor, + query, + computer, + reranked.into_iter().map(|(id, d)| Neighbor::new(id, d)), + output, + ) + .await + .map_err(|e| GarnetProviderError::PostProcessing(Box::new(e))) + } +} + +impl SearchStrategy, &[T]> for DynamicQuantization { + type SearchAccessor<'a> = DynamicAccessor<'a, T>; type SearchAccessorError = GarnetProviderError; - type QueryComputer = T::QueryDistance; + type QueryComputer = GarnetQueryComputer; fn search_accessor<'a>( &'a self, provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, true)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } } -impl DefaultPostProcessor, &[T], GarnetId> for FullPrecision { - default_post_processor!(glue::Pipeline); +impl DefaultPostProcessor, &[T], GarnetId> + for DynamicQuantization +{ + default_post_processor!( + glue::Pipeline> + ); } -impl PruneStrategy> for FullPrecision { - type PruneAccessor<'a> = FullAccessor<'a, T>; +impl PruneStrategy> for DynamicQuantization { + type PruneAccessor<'a> = DynamicAccessor<'a, T>; type PruneAccessorError = GarnetProviderError; - type DistanceComputer<'a> = T::Distance; - type WorkingSet = WorkingSet; + type DistanceComputer<'a> = GarnetDistanceComputer; + type WorkingSet = WorkingSet; fn prune_accessor<'a>( &'a self, provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider, context, false)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { @@ -831,11 +1485,11 @@ impl PruneStrategy> for FullPrecision { // cache and persist up to `capacity` items across uses of the working set. 
// // This reuse is limited to a single collection of backedges for an insert or multi-insert. - workingset::map::Builder::new(workingset::map::Capacity::Default).build(capacity) + WorkingSet::new(workingset::map::Capacity::Default, capacity) } } -impl InsertStrategy, &[T]> for FullPrecision { +impl InsertStrategy, &[T]> for DynamicQuantization { type PruneStrategy = Self; fn insert_search_accessor<'a>( @@ -843,7 +1497,8 @@ impl InsertStrategy, &[T]> for FullPrecision { provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, false)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } fn prune_strategy(&self) -> Self::PruneStrategy { @@ -851,13 +1506,13 @@ impl InsertStrategy, &[T]> for FullPrecision { } } -impl InplaceDeleteStrategy> for FullPrecision { +impl InplaceDeleteStrategy> for DynamicQuantization { type DeleteElement<'a> = &'a [T]; type DeleteElementGuard = Box<[T]>; type DeleteElementError = GarnetProviderError; type PruneStrategy = Self; - type DeleteSearchAccessor<'a> = FullAccessor<'a, T>; + type DeleteSearchAccessor<'a> = DynamicAccessor<'a, T>; type SearchPostProcessor = glue::CopyIds; type SearchStrategy = Self; @@ -881,7 +1536,7 @@ impl InplaceDeleteStrategy> for FullPrecision { ) -> impl Future> + Send { let mut v = vec![T::default(); provider.dim]; - if !provider.callbacks.read_single_iid(*context, id, &mut v) { + if !provider.callbacks.read_single_iid(context, id, &mut v) { return future::ready(Err(GarnetError::Read.into())); } future::ready(Ok(v.into())) diff --git a/diskann-garnet/src/quantization.rs b/diskann-garnet/src/quantization.rs new file mode 100644 index 000000000..7891704a3 --- /dev/null +++ b/diskann-garnet/src/quantization.rs @@ -0,0 +1,317 @@ +use std::{num::NonZero, sync::RwLock}; + +use diskann::utils::VectorRepr; +use diskann_quantization::{ + CompressInto, + algorithms::{Transform, 
TransformKind, transforms::NewTransformError}, + alloc::{GlobalAllocator, ScopedAllocator}, + minmax, + num::Positive, + spherical::{ + self, Data, PreScale, SphericalQuantizer, SupportedMetric, + iface::{self, Opaque, OpaqueMut, Quantizer}, + }, +}; +use diskann_utils::views::MatrixView; +use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; +use thiserror::Error; + +use crate::provider::{GarnetDistanceComputer, GarnetQueryComputer}; + +#[derive(Debug, Error)] +pub(crate) enum GarnetQuantizerError { + #[error("Quantization training error: {0}")] + Training(Box), + #[error("Quantization alloc error: {0}")] + Alloc(Box), + #[error("Query computer error: {0}")] + QueryComputer(Box), + #[error("Binary quantization error: {0}")] + Compression(Box), + #[error("No quantizer found")] + NoQuantizer, + #[error("Got zero dimension")] + ZeroDim, + #[error("Transform error: {0}")] + BadTransform(#[from] NewTransformError), +} + +/// Quantizer trait that all diskann-garnet quantizers must implement +pub(crate) trait GarnetQuantizer: Send + Sync { + /// Check whether the quantizer is ready to be used + fn is_prepared(&self) -> bool; + /// Returns the number of vectors needed before the quantizer can be trained + fn required_vectors(&self) -> usize; + /// Returns the size of a quantized vector + fn canonical_bytes(&self) -> usize; + /// Train the quantizer. 
+ /// Each row of the matrix will be a vector + fn train(&self, metric: Metric, data: MatrixView) -> Result<(), GarnetQuantizerError>; + /// Quantize a vector + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError>; + /// Returns a distance computer for comparing quantized vectors + fn distance_computer(&self) -> Result; + /// Returns a query computer for comparing distances to a particular query + fn query_computer(&self, query: &[f32]) -> Result; +} + +/// Type-erased distance computer +pub(crate) trait DynDistanceComputer: Send + Sync { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32; +} + +/// Type-erased query computer +pub(crate) trait DynQueryComputer: Send + Sync { + fn evaluate_similarity(&self, a: &[u8]) -> f32; +} + +/// Spherical 1-bit quantization. +/// +/// This quantizer corresponds to `BIN` quantizer in the Redis protocol. It requires hundreds of +/// vectors (but not thousands) for training. Quantized vectors have 1 bit per dimension plus up +/// to 6 bytes of overhead. 
+pub(crate) struct Spherical1Bit { + dim: usize, + inner: RwLock>>, +} + +impl Spherical1Bit { + pub(crate) fn new(dim: usize) -> Self { + Self { + dim, + inner: RwLock::new(None), + } + } +} + +impl GarnetQuantizer for Spherical1Bit { + fn is_prepared(&self) -> bool { + self.inner.read().unwrap().is_some() + } + + fn required_vectors(&self) -> usize { + 1000 + } + + fn canonical_bytes(&self) -> usize { + Data::<1, GlobalAllocator>::canonical_bytes(self.dim) + } + + fn train( + &self, + metric_type: Metric, + data: MatrixView, + ) -> Result<(), GarnetQuantizerError> { + let mut rng = rand::rng(); + let quantizer = SphericalQuantizer::train( + data.as_view(), + TransformKind::DoubleHadamard { + target_dim: diskann_quantization::algorithms::transforms::TargetDim::Same, + }, + SupportedMetric::try_from(metric_type) + .map_err(|e| GarnetQuantizerError::Training(Box::new(e)))?, + PreScale::ReciprocalMeanNorm, + &mut rng, + GlobalAllocator, + ) + .map_err(|e| GarnetQuantizerError::Training(Box::new(e)))?; + + let mut inner = self.inner.write().unwrap(); + *inner = Some( + spherical::iface::Impl::<1>::new(quantizer) + .map_err(|e| GarnetQuantizerError::Alloc(Box::new(e)))?, + ); + + Ok(()) + } + + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError> { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + spherical::iface::Quantizer::::compress( + quantizer, + v, + OpaqueMut::new(into), + ScopedAllocator::global(), + ) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + Ok(()) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } + + fn distance_computer(&self) -> Result { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + let computer = quantizer + .distance_computer(GlobalAllocator) + .map_err(|e| GarnetQuantizerError::Alloc(Box::new(e)))?; + Ok(GarnetDistanceComputer::new(computer)) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } + + fn 
query_computer(&self, query: &[f32]) -> Result { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + let computer = quantizer + .fused_query_computer( + query, + iface::QueryLayout::FullPrecision, + true, + GlobalAllocator, + ScopedAllocator::global(), + ) + .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?; + Ok(GarnetQueryComputer::new(computer)) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } +} + +impl DynDistanceComputer for iface::DistanceComputer { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + , Opaque<'_>, _>>::evaluate_similarity( + self, + Opaque::new(a), + Opaque::new(b), + ) + .unwrap() + } +} + +impl DynQueryComputer for iface::QueryComputer { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + , _>>::evaluate_similarity( + self, + Opaque::new(a), + ) + .unwrap() + } +} + +/// 8-bit scalar quantizer using MinMax +/// +/// This quantizer requires no training at all and is usable immediately on the first vector. Each +/// quantized vector has 8 bits per dimension and 20 bytes of overhead. 
+pub(crate) struct MinMax8Bit { + dim: usize, + metric: Metric, + inner: minmax::MinMaxQuantizer, +} + +impl MinMax8Bit { + pub(crate) fn new(dim: usize, metric: Metric) -> Result { + let dim = match NonZero::new(dim) { + Some(d) => d, + None => return Err(GarnetQuantizerError::ZeroDim), + }; + let mut rng = rand::rng(); + let transform = Transform::new( + TransformKind::DoubleHadamard { + target_dim: diskann_quantization::algorithms::transforms::TargetDim::Same, + }, + dim, + Some(&mut rng), + GlobalAllocator, + )?; + let grid_scale = Positive::new(1.0).unwrap(); + + Ok(Self { + dim: dim.get(), + metric, + inner: minmax::MinMaxQuantizer::new(transform, grid_scale), + }) + } +} + +impl GarnetQuantizer for MinMax8Bit { + fn is_prepared(&self) -> bool { + true + } + + fn required_vectors(&self) -> usize { + 0 + } + + fn canonical_bytes(&self) -> usize { + minmax::Data::<8>::canonical_bytes(self.dim) + } + + fn train(&self, _metric: Metric, _data: MatrixView) -> Result<(), GarnetQuantizerError> { + Ok(()) + } + + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError> { + let into = minmax::DataMutRef::<8>::from_canonical_front_mut(into, self.dim) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + self.inner + .compress_into(v, into) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + Ok(()) + } + + fn distance_computer(&self) -> Result { + let computer = GarnetDistanceComputer::new( + ::distance( + self.metric, + Some(self.dim), + ), + ); + Ok(computer) + } + + fn query_computer(&self, query: &[f32]) -> Result { + let computer = GarnetQueryComputer::new(MinMax8BitQueryComputer::new( + &self.inner, + query, + self.dim, + self.metric, + )?); + Ok(computer) + } +} + +impl DynDistanceComputer for diskann_providers::common::FnPtr { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + let a = diskann_providers::common::MinMax8::from_bytes(a); + let b = diskann_providers::common::MinMax8::from_bytes(b); + 
>::evaluate_similarity(self, a, b) + } +} +struct MinMax8BitQueryComputer( + diskann_providers::common::BufferedFnPtr, +); + +impl MinMax8BitQueryComputer { + fn new( + quantizer: &minmax::MinMaxQuantizer, + query: &[f32], + dim: usize, + metric: Metric, + ) -> Result { + let mut v = vec![Default::default(); minmax::Data::<8>::canonical_bytes(dim)]; + quantizer + .compress_into( + query, + minmax::DataMutRef::<8>::from_canonical_front_mut(&mut v, dim) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?, + ) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + let inner = diskann_providers::common::MinMax8::query_distance( + diskann_providers::common::MinMax8::from_bytes(&v), + metric, + ); + Ok(Self(inner)) + } +} + +impl DynQueryComputer for MinMax8BitQueryComputer { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + let a = diskann_providers::common::MinMax8::from_bytes(a); + self.0.evaluate_similarity(a) + } +} diff --git a/diskann-garnet/src/test_utils.rs b/diskann-garnet/src/test_utils.rs index 26b72f774..cdc59897f 100644 --- a/diskann-garnet/src/test_utils.rs +++ b/diskann-garnet/src/test_utils.rs @@ -131,6 +131,7 @@ mod tests { use std::collections::HashMap; use crate::{ + VectorQuantType, garnet::{Context, Term}, test_utils::Store, }; @@ -140,42 +141,42 @@ mod tests { let store = Store; store.clear(); let callbacks = store.callbacks(); - let ctx = Context(0); + let ctx = Context::new(0); // Reading a non-existant key should fail. - assert!(!callbacks.exists_iid(ctx, 0)); + assert!(!callbacks.exists_iid(&ctx, 0)); // Round tripping a write should work. - assert!(callbacks.write_iid(ctx, 0, b"test")); + assert!(callbacks.write_iid(&ctx, 0, b"test")); let mut val = vec![0u8; 4]; - assert!(callbacks.read_single_iid(ctx, 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx, 0, &mut val)); assert_eq!(val, b"test"); // Overwriting a key should work. 
- assert!(callbacks.write_iid(ctx, 0, b"again")); + assert!(callbacks.write_iid(&ctx, 0, b"again")); let mut val = vec![0u8; 5]; - assert!(callbacks.read_single_iid(ctx, 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx, 0, &mut val)); assert_eq!(val, b"again"); // Exists and delete should work. - assert!(callbacks.exists_iid(ctx, 0)); - assert!(callbacks.delete_iid(ctx, 0)); - assert!(!callbacks.exists_iid(ctx, 0)); + assert!(callbacks.exists_iid(&ctx, 0)); + assert!(callbacks.delete_iid(&ctx, 0)); + assert!(!callbacks.exists_iid(&ctx, 0)); // Different contexts should stay separate. - assert!(callbacks.write_iid(ctx.term(Term::Vector), 0, b"0000")); - assert!(callbacks.write_iid(ctx.term(Term::Neighbors), 0, b"nnnn")); + assert!(callbacks.write_iid(&ctx.term(Term::Vector), 0, b"0000")); + assert!(callbacks.write_iid(&ctx.term(Term::Neighbors), 0, b"nnnn")); let mut val = vec![0u8; 4]; - assert!(callbacks.read_single_iid(ctx.term(Term::Vector), 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx.term(Term::Vector), 0, &mut val)); assert_eq!(val, b"0000"); - assert!(callbacks.read_single_iid(ctx.term(Term::Neighbors), 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx.term(Term::Neighbors), 0, &mut val)); assert_eq!(val, b"nnnn"); // Multi-read should work. - assert!(callbacks.write_iid(ctx.term(Term::Vector), 1, b"2222")); + assert!(callbacks.write_iid(&ctx.term(Term::Vector), 1, b"2222")); let ids = [4u32, 0, 4, 1, 4, 2]; let mut results = HashMap::new(); - callbacks.read_multi_lpiid(ctx.term(Term::Vector), &ids, |i, v| { + callbacks.read_multi_lpiid(&ctx.term(Term::Vector), &ids, |i, v| { results.insert(i, v.to_owned()); }); assert_eq!(results.get(&0), Some(b"0000".to_vec()).as_ref()); @@ -183,10 +184,10 @@ mod tests { assert_eq!(results.get(&2), None); // RMW should work. 
- assert!(callbacks.rmw_iid(ctx.term(Term::Vector), 0, 4, |data| { + assert!(callbacks.rmw_iid(&ctx.term(Term::Vector), 0, 4, |data| { data.copy_from_slice(b"0rmw"); })); - assert!(callbacks.read_single_iid(ctx.term(Term::Vector), 0, &mut val)); + assert!(callbacks.read_single_iid(&ctx.term(Term::Vector), 0, &mut val)); assert_eq!(val, b"0rmw"); } @@ -198,12 +199,19 @@ mod tests { let store: Store = Store; store.clear(); let callbacks = store.callbacks(); - let ctx: Context = Context(0); + let ctx: Context = Context::new(0); // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; let max_degree = 32; - let provider = GarnetProvider::::new(dim, Metric::L2, max_degree, callbacks, ctx); + let provider = GarnetProvider::::new( + dim, + VectorQuantType::NoQuant, + Metric::L2, + max_degree, + callbacks, + &ctx, + ); // Provider should be created successfully assert!(provider.is_ok()); @@ -215,7 +223,14 @@ mod tests { // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; let max_degree = 32; - let provider = GarnetProvider::::new(dim, Metric::L2, max_degree, callbacks, ctx); + let provider = GarnetProvider::::new( + dim, + VectorQuantType::NoQuant, + Metric::L2, + max_degree, + callbacks, + &ctx, + ); // Provider should be created successfully assert!(provider.is_ok()); diff --git a/diskann-providers/src/common/minmax_repr.rs b/diskann-providers/src/common/minmax_repr.rs index 73c19b10d..19dc8ec96 100644 --- a/diskann-providers/src/common/minmax_repr.rs +++ b/diskann-providers/src/common/minmax_repr.rs @@ -117,6 +117,12 @@ impl num_traits::FromPrimitive for MinMaxElement { impl_from_primitive!(NBITS, from_u32, u32); } +impl MinMaxElement { + pub fn from_bytes(x: &[u8]) -> &[Self] { + bytemuck::must_cast_slice(x) + } +} + //////////////////////// /// DistanceProvider /// //////////////////////// diff --git a/diskann-providers/src/common/mod.rs b/diskann-providers/src/common/mod.rs index b5c4a42a4..ec39946f7 100644 --- 
a/diskann-providers/src/common/mod.rs +++ b/diskann-providers/src/common/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ mod minmax_repr; -pub use minmax_repr::{MinMax4, MinMax8, MinMaxElement}; +pub use minmax_repr::{BufferedFnPtr, FnPtr, MinMax4, MinMax8, MinMaxElement}; mod ignore_lock_poison; pub use ignore_lock_poison::IgnoreLockPoison; diff --git a/rfcs/00000-quantizer-bootstrap.md b/rfcs/00000-quantizer-bootstrap.md new file mode 100644 index 000000000..6577c3c06 --- /dev/null +++ b/rfcs/00000-quantizer-bootstrap.md @@ -0,0 +1,198 @@ +# Quantization Bootstrapping + +| | | +|---|---| +| **Authors** | Jack Moffitt | +| **Contributors** | Mark Hildebrand | +| **Created** | 2026-03-10 | +| **Updated** | 2026-03-10 | + +## Summary + +Indexes that use quantization must have a minimum number of vectors before they +can build quantization tables and start inserting vectors. Bootstrapping is the +process of incrementally building an index that starts empty, operates on +non-quantized vectors until enough vectors are present to build quantization +tables, and then transitions to normal operation. + +## Motivation + +### Background + +DiskANN's quantizers require some statistical information in order to build +quantization tables. For PQ, 10,000 vectors are generally required to build good +tables; for spherical, 100 are needed. In order to create an index, these +vectors must be provided at creation time in order to build the quantization +tables, at which point each vector is quantized as it is inserted. + +This requirement is easy to fulfill when building indexes from existing +datasets, but when starting from scratch, there is no ability for DiskANN to +build a quantized index since the quantization tables are a required part of the +constructor. 
+ +Current deployments of DiskANN work around this issue by not allowing index +creation until a dataset is sufficiently large (pg_diskann), or operating a +separate flat index until sufficient vectors are collected at which point the +quantization tables are calculated and a graph index is built with DiskANN. + +### Problem Statement + +This RFC proposes changing DiskANN to operate in a quantization bootstrap mode +where it operates on full precision vectors until sufficient vectors exist to +create quantization tables, and then seamlessly transitions to a quantized +index. + +This means the index will operate in three different phases. In Phase One, the +index operates in full precision mode only until sufficient vectors exist to +build quantization tables. During Phase Two, quantization tables will be built and +vectors will be quantized on insert; pre-existing vectors will be quantized in +the background. Once all vectors are quantized, Phase Three begins the normal +operation of the quantized index. + +### Goals + +1. Allow quantized indices to start empty and use full precision data only to + operate until sufficient vectors are inserted. +2. Aside from allowing construction of `DiskANNIndex` without providing + quantization tables, there should be no user-visible changes to using the index. +3. Performance should remain as high as possible during the three phases. +4. The quantization of previously inserted full vectors during Phase Two should + be controllable by the data provider. + +## Proposal + +Bootstrapping needs two changes to a DiskANN data provider implementation. + +1. **Switching Strategies**: DiskANN needs to start by using full precision only + strategies during Phase One, and switching to a quantized-only or hybrid + strategy for Phase Two. +2. **Quantization Backfill**: During Phase Two, previously inserted vectors will + need to be quantized. 
As background jobs are a performance concern, how + exactly this is accomplished must be customizable by the data provider. + +### Switching Strategies + +DiskANN already has the ability to run multiple strategies including hybrid full +precision and quantized ones. These strategies represent the high level intent, +but the data provider can choose alternate implementations depending on the data +available during the current phase. + +#### Insertion and Deletion + +Insertion and deletion can remain largely the same in the data provider +implementation. Inserts will need to write vectors, mappings, attributes, et al +into storage, and can track the current phase to gate writes to quantized +vectors. The search portion of these operations will return different objects +depending on the current phase. + +For example, consider a `DataProvider::set_element()` implementation: + +```rust +struct ExampleProvider { + // other fields omitted + quantizer: Option, +} + +impl SetElement<[f32]> for ExampleProvider { + // associated types omitted + + async fn set_element( + &self, + context: &Self::Context, + id: &Self::ExternalId, + element: &[T], + ) -> Result { + let internal_id = self.new_id()?; + self.write_vector(context, internal_id, element)?; + self.set_internal_map(internal_id, id)?; + self.set_external_map(id, internal_id)?; + + // Quantize and store quant vector if we have a quantizer. + if let Some(quantizer) = self.quantizer { + let qv = quantizer.quantize(element)?; + self.write_quant_vector(context, internal_id, element)?; + } else { + // This function will check if we are ready for Phase Two, and if so, do or schedule the quantizer initialization. + self.maybe_initialize_quantizer()?; + } + + Ok(NoopGuard::new(internal_id)) + } +} +``` + +Delete can similarly check the status of the quantizer, and delete quantized +vectors if they exist. 
+ +#### Searching + +To avoid complexity of hybrid distance calculations, either full precision +distances will be used (Phase One and Two) or quantized distances will be used +(Phase Three). If a hybrid strategy is in use, then the hybrid distances will +not be used until Phase Three. + +Since vector data may be in one of two representations, the `Accessor::Element` +type should be `Poly` (this should be over-aligned to the correct alignment +for the primitive element type), and the data provider should interpret based on +data size. The distance and query computers will also need modifications to +accept both vector representations, and in the case of the query computer the +representation must match that of the query. + + +### Quantization Backfill + +After quantization tables are built, newly inserted vectors will be quantized +before insertion, but previously inserted vectors won't have quantized +representations yet. During Phase Two, these previously inserted full precision +vectors will need to be quantized before the index enters Phase Three. + +Since integrators of DiskANN are sensitive to background jobs, how the index +manages backfilling quantized vectors is controlled by the data provider +implementation. The data provider must have some way to track which vectors have +missing quantized representations so that it can generate them. + +Once Phase Two is reached, the data provider can either pause during insertion +of the phase changing vector, or schedule the work to happen asynchronously +however it likes. + +One possibility is to piggy-back on deletion tracking to track quantization +status of vectors. For example, in diskann-garnet, a free space map is kept that +tracks deletes. This could be expanded from 1-bit to 2-bits, and the second bit +used to track whether the vector is quantized. Alternatively, metadata about the +allocated range can be kept and used to iterate over the unquantized set. 
+ + +## Trade-offs + +Currently the workarounds in use are either no index at all until sufficient +vectors exist or operating a side index until sufficient vectors exist and then +building a quantized graph. + +pg_diskann uses the former method, which means users are confused when they try +to create indexes on empty or insufficiently populated tables and get an +error. Cosmos DB uses the latter strategy and operates a flat index until an +asynchronous graph build is complete enough to use the graph index. This +requires the Cosmos DB team to maintain all their own infrastructure for the +flat index and the code around transitioning to the graph index. + +This proposal mitigates the downsides while still allowing the integrator to +retain control over key performance details. + +This proposal also entirely encapsulates this inside the `DataProvider` +implementation. Alternatively, one could attempt to solve this with some kind of +index or strategy layering, but the complexity this would introduce seems not +worth the cost. + +## Benchmark Results + +Since there is no way to build an index currently until quantization tables are +built, there is no way to benchmark the first two phases. There should be no +impact during Phase Three to performance. + +## Future Work + +None. + +## References + +None. \ No newline at end of file diff --git a/vectorset/src/loader.rs b/vectorset/src/loader.rs index f1e62444e..9e0f7a856 100644 --- a/vectorset/src/loader.rs +++ b/vectorset/src/loader.rs @@ -4,22 +4,14 @@ */ use anyhow::{Result, anyhow}; -use std::{ - collections::HashMap, - marker::PhantomData, - mem, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{collections::HashMap, marker::PhantomData, mem, path::Path, sync::Arc}; use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; const BATCH_SIZE: usize = 1024; /// Loads base or query vectors from a given path and allow iteration over them. 
-#[allow(dead_code)] pub struct DatasetLoader { file: Mutex<(File, usize)>, - path: PathBuf, num_vectors: usize, dim: usize, type_: PhantomData, @@ -36,7 +28,6 @@ impl DatasetLoader { Ok(Arc::new(Self { file: Mutex::new((file, 0)), - path, num_vectors, dim, type_: PhantomData, @@ -60,35 +51,33 @@ impl DatasetLoader { /// Returns (count, first_id) where `count` is the number of vectors loaded /// and `first_id` is the id of the first vector. pub async fn next(&self, buffer: &mut Vec) -> Result<(usize, usize)> { + let mut f = self.file.lock().await; + let mut count; let mut first_id; loop { - { - let mut f = self.file.lock().await; - - first_id = f.1; - if f.1 >= self.num_vectors { - buffer.clear(); - return Ok((0, first_id)); - } - - buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); + first_id = f.1; + if f.1 >= self.num_vectors { + buffer.clear(); + return Ok((0, first_id)); + } - let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); - while let bytes_read = f.0.read(buf).await? - && bytes_read > 0 - { - buf = &mut buf[bytes_read..]; - } + buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); - let elements_left = buf.len() / mem::size_of::(); - if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { - return Err(anyhow!("unexpected EOF")); - } + let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); + while let bytes_read = f.0.read(buf).await? 
+ && bytes_read > 0 + { + buf = &mut buf[bytes_read..]; + } - count = BATCH_SIZE - elements_left / self.dim; + let elements_left = buf.len() / mem::size_of::(); + if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { + return Err(anyhow!("unexpected EOF")); } + count = BATCH_SIZE - elements_left / self.dim; + if count == 0 { continue; } @@ -96,7 +85,6 @@ impl DatasetLoader { break; } - let mut f = self.file.lock().await; f.1 += count; Ok((count, first_id)) diff --git a/vectorset/src/main.rs b/vectorset/src/main.rs index 4ba864d56..a16e7cfc9 100644 --- a/vectorset/src/main.rs +++ b/vectorset/src/main.rs @@ -32,6 +32,7 @@ const DEFAULT_PORT: u16 = 6379; trait Element: bytemuck::Pod + std::fmt::Debug + Send + Sync + 'static {} impl Element for u8 {} +impl Element for i8 {} impl Element for f32 {} #[derive(Deserialize)] @@ -53,9 +54,9 @@ struct Options { #[arg(short, long)] threads: Option, - /// Data type for vector elements - #[arg(long)] - data_type: DataType, + /// Quantizer + #[arg(long, default_value_t)] + quantizer: Quantizer, #[command(subcommand)] command: Commands, @@ -107,6 +108,10 @@ struct IngestArgs { #[arg(long)] no_header_with_dim: Option, + /// Metric + #[arg(long)] + metric: Option, + /// Paths to base vectors base_path: PathBuf, } @@ -158,9 +163,96 @@ struct QueryArgs { #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] enum DataType { Uint8, + Int8, Float32, } +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Default)] +enum Quantizer { + /// f32 vectors; no quantization + #[default] + NoQuant, + /// f32 vectors; spherical 1-bit quantization + Bin, + /// f32 vectors; minmax 8-bit scalar quantization + Q8, + /// u8 vectors; no quantization + XNoQuant_U8, + /// i8 vectors; no quantization + XNoQuant_I8, + /// u8 vectors; spherical 1-bit quantization + XBin_U8, + /// i8 vectors; spherical 1-bit quantization + XBin_I8, +} + +impl Quantizer { + pub fn data_type(&self) -> DataType { 
+ match self { + Quantizer::NoQuant | Quantizer::Bin | Quantizer::Q8 => DataType::Float32, + Quantizer::XNoQuant_I8 | Quantizer::XBin_I8 => DataType::Int8, + Quantizer::XNoQuant_U8 | Quantizer::XBin_U8 => DataType::Uint8, + } + } +} + +impl ToRedisArgs for Quantizer { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + redis::RedisWrite, + { + let q = match self { + Quantizer::NoQuant => b"NOQUANT".as_slice(), + Quantizer::Bin => b"BIN".as_slice(), + Quantizer::Q8 => b"Q8".as_slice(), + Quantizer::XNoQuant_U8 => b"XNOQUANT_U8", + Quantizer::XNoQuant_I8 => b"XNOQUANT_I8", + Quantizer::XBin_U8 => b"XBIN_U8", + Quantizer::XBin_I8 => b"XBIN_I8", + }; + out.write_arg(q); + } +} + +impl std::fmt::Display for Quantizer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + Quantizer::NoQuant => "NOQUANT", + Quantizer::Bin => "BIN", + Quantizer::Q8 => "Q8", + Quantizer::XNoQuant_U8 => "XNOQUANT_U8", + Quantizer::XNoQuant_I8 => "XNOQUANT_I8", + Quantizer::XBin_U8 => "XBIN_U8", + Quantizer::XBin_I8 => "XBIN_I8", + }; + f.write_str(s) + } +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum DistanceMetric { + L2, + Cosine, + CosineNormalized, + InnerProduct, +} + +impl ToRedisArgs for DistanceMetric { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + redis::RedisWrite, + { + let q = match self { + DistanceMetric::L2 => b"L2".as_slice(), + DistanceMetric::Cosine => b"COSINE".as_slice(), + DistanceMetric::CosineNormalized => b"COSINE_NORMALIZED".as_slice(), + DistanceMetric::InnerProduct => b"IP".as_slice(), + }; + out.write_arg(q); + } +} + struct VectorId(u32); impl ToRedisArgs for VectorId { @@ -285,8 +377,9 @@ async fn async_main(opts: Options) -> Result<()> { None }; - match opts.data_type { + match opts.quantizer.data_type() { DataType::Uint8 => dispatch::(&opts.command, &opts, infos, cred).await, + DataType::Int8 => dispatch::(&opts.command, &opts, infos, cred).await, 
DataType::Float32 => dispatch::(&opts.command, &opts, infos, cred).await, } } @@ -361,7 +454,9 @@ async fn ingest( let l_build = args.l_build; let degree = args.degree; let mut cred = cred.clone(); - let data_type = opts.data_type; + let data_type = opts.quantizer.data_type(); + let quantizer = opts.quantizer; + let metric = args.metric; tasks.spawn(async move { let mut buf = vec![T::zeroed(); ds.batch_size() * ds.dim()]; @@ -387,11 +482,15 @@ async fn ingest( let element = VectorId((first_id + i) as u32); let buf_start = i * ds.dim(); let buf_end = buf_start + ds.dim(); + pipeline.cmd("VADD").arg(&vset); match data_type { DataType::Uint8 => { - pipeline.arg(b"XB8"); + pipeline.arg(b"XU8"); + } + DataType::Int8 => { + pipeline.arg(b"XI8"); } DataType::Float32 => { pipeline.arg(b"FP32"); @@ -402,13 +501,10 @@ async fn ingest( .arg(bytemuck::cast_slice::<_, u8>(&buf[buf_start..buf_end])) .arg(element); - match data_type { - DataType::Uint8 => { - pipeline.arg(b"XPREQ8"); - } - DataType::Float32 => { - pipeline.arg(b"NOQUANT"); - } + pipeline.arg(quantizer); + + if let Some(metric) = metric { + pipeline.arg(b"XDISTANCE_METRIC").arg(metric); } pipeline @@ -508,7 +604,7 @@ async fn query( let n = args.n; let l_search = args.l_search; let mut cred = cred.clone(); - let data_type = opts.data_type; + let data_type = opts.quantizer.data_type(); tasks.spawn(async move { let mut pipeline = Pipeline::with_capacity(pipeline_size); @@ -532,7 +628,8 @@ async fn query( pipeline.cmd("VSIM").arg(&vset); match data_type { - DataType::Uint8 => pipeline.arg(b"XB8"), + DataType::Uint8 => pipeline.arg(b"XU8"), + DataType::Int8 => pipeline.arg(b"XI8"), DataType::Float32 => pipeline.arg(b"FP32"), };