diff --git a/bindings/cpp/CMakeLists.txt b/bindings/cpp/CMakeLists.txt index f63143dc9..65e9384f2 100644 --- a/bindings/cpp/CMakeLists.txt +++ b/bindings/cpp/CMakeLists.txt @@ -102,7 +102,9 @@ if (SVS_RUNTIME_ENABLE_LVQ_LEANVEC) svs::svs svs_compile_options ) - link_mkl_static(${TARGET_NAME}) + if(SVS_EXPERIMENTAL_LINK_STATIC_MKL) + link_mkl_static(${TARGET_NAME}) + endif() if(SVS_LVQ_HEADER) target_compile_definitions(${TARGET_NAME} PRIVATE SVS_LVQ_HEADER="${SVS_LVQ_HEADER}" diff --git a/bindings/cpp/include/svs/runtime/dynamic_vamana_index.h b/bindings/cpp/include/svs/runtime/dynamic_vamana_index.h index 97596b3af..1d487baf9 100644 --- a/bindings/cpp/include/svs/runtime/dynamic_vamana_index.h +++ b/bindings/cpp/include/svs/runtime/dynamic_vamana_index.h @@ -74,6 +74,14 @@ struct SVS_RUNTIME_API DynamicVamanaIndex : public VamanaIndex { ) noexcept; virtual size_t blocksize_bytes() const noexcept = 0; + + // Override for VamanaIndex interface + Status add(size_t, const float*) noexcept override { + return Status( + ErrorCode::NOT_IMPLEMENTED, + "Use add(size_t n, const size_t* labels, const float* x) for DynamicVamanaIndex" + ); + } }; struct SVS_RUNTIME_API DynamicVamanaIndexLeanVec : public DynamicVamanaIndex { diff --git a/bindings/cpp/include/svs/runtime/vamana_index.h b/bindings/cpp/include/svs/runtime/vamana_index.h index ba9739fb4..988319528 100644 --- a/bindings/cpp/include/svs/runtime/vamana_index.h +++ b/bindings/cpp/include/svs/runtime/vamana_index.h @@ -16,38 +16,47 @@ #pragma once #include +#include #include +#include namespace svs { namespace runtime { namespace v0 { +namespace detail { +struct VamanaBuildParameters { + size_t graph_max_degree = Unspecify(); + size_t prune_to = Unspecify(); + float alpha = Unspecify(); + size_t construction_window_size = Unspecify(); + size_t max_candidate_pool_size = Unspecify(); + OptionalBool use_full_search_history = Unspecify(); +}; + +struct VamanaSearchParameters { + size_t search_window_size = Unspecify(); + size_t search_buffer_capacity = Unspecify(); + size_t prefetch_lookahead = Unspecify(); + size_t prefetch_step = Unspecify(); +}; +} // namespace detail + // Abstract interface for Vamana-based indices. -// NOTE VamanaIndex is not implemented directly, only DynamicVamanaIndex is implemented. struct SVS_RUNTIME_API VamanaIndex { virtual ~VamanaIndex(); - struct BuildParams { - size_t graph_max_degree = Unspecify(); - size_t prune_to = Unspecify(); - float alpha = Unspecify(); - size_t construction_window_size = Unspecify(); - size_t max_candidate_pool_size = Unspecify(); - OptionalBool use_full_search_history = Unspecify(); - }; - - struct SearchParams { - size_t search_window_size = Unspecify(); - size_t search_buffer_capacity = Unspecify(); - size_t prefetch_lookahead = Unspecify(); - size_t prefetch_step = Unspecify(); - }; + using BuildParams = detail::VamanaBuildParameters; + using SearchParams = detail::VamanaSearchParameters; struct DynamicIndexParams { size_t blocksize_exp = 30; }; + virtual Status add(size_t n, const float* x) noexcept = 0; + virtual Status reset() noexcept = 0; + virtual Status search( size_t n, const float* x, @@ -66,6 +75,50 @@ struct SVS_RUNTIME_API VamanaIndex { const SearchParams* params = nullptr, IDFilter* filter = nullptr ) const noexcept = 0; + + // Utility function to check storage kind support + static Status check_storage_kind(StorageKind storage_kind) noexcept; + + // Static constructors and destructors + static Status build( + VamanaIndex** index, + size_t dim, + MetricType metric, + StorageKind storage_kind, + const VamanaIndex::BuildParams& params = VamanaIndex::BuildParams{}, + const VamanaIndex::SearchParams& default_search_params = VamanaIndex::SearchParams{} + ) noexcept; + + static Status destroy(VamanaIndex* index) noexcept; + + virtual Status save(std::ostream& out) const noexcept = 0; + static Status load( + VamanaIndex** index, std::istream& in, MetricType metric, StorageKind storage_kind + ) noexcept; +}; + +struct SVS_RUNTIME_API VamanaIndexLeanVec : public VamanaIndex { + // Specialization to build LeanVec-based Vamana index with specified leanvec dims + static Status build( + VamanaIndex** index, + size_t dim, + MetricType metric, + StorageKind storage_kind, + size_t leanvec_dims, + const VamanaIndex::BuildParams& params = {}, + const VamanaIndex::SearchParams& default_search_params = {} + ) noexcept; + + // Specialization to build LeanVec-based Vamana index with provided training data + static Status build( + VamanaIndex** index, + size_t dim, + MetricType metric, + StorageKind storage_kind, + const LeanVecTrainingData* training_data, + const VamanaIndex::BuildParams& params = {}, + const VamanaIndex::SearchParams& default_search_params = {} + ) noexcept; }; } // namespace v0 diff --git a/bindings/cpp/src/dynamic_ivf_index_impl.h b/bindings/cpp/src/dynamic_ivf_index_impl.h index 1eb5016db..c014d8f81 100644 --- a/bindings/cpp/src/dynamic_ivf_index_impl.h +++ b/bindings/cpp/src/dynamic_ivf_index_impl.h @@ -23,6 +23,8 @@ namespace runtime { // Dynamic IVF index implementation (non-LeanVec storage kinds) class DynamicIVFIndexImpl { + using allocator_type = svs::data::Blocked>; + public: DynamicIVFIndexImpl( size_t dim, @@ -184,25 +186,32 @@ class DynamicIVFIndexImpl { } // Dispatch on storage kind to load with correct data type - return ivf_storage::dispatch_ivf_storage_kind(storage_kind, [&](auto tag) { - using Tag = decltype(tag); - using DataType = ivf_storage::IVFBlockedStorageType_t; - - svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); - return distance_dispatcher([&](auto&& distance) { - auto impl = std::make_unique( - svs::DynamicIVF::assemble( - in, - std::forward(distance), + return ivf_storage::dispatch_ivf_storage_kind( + storage_kind, + [&](auto tag) { + using Tag = decltype(tag); + using DataType = typename Tag::type; + + svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); + return distance_dispatcher([&](auto&& distance) { + auto impl = std::make_unique( + svs::DynamicIVF::assemble( + in, + std::forward(distance), + num_threads, + intra_query_threads + ) + ); + return new DynamicIVFIndexImpl( + std::move(impl), + metric, + storage_kind, num_threads, intra_query_threads - ) - ); - return new DynamicIVFIndexImpl( - std::move(impl), metric, storage_kind, num_threads, intra_query_threads - ); - }); - }); + ); + }); + } + ); } protected: @@ -297,20 +306,25 @@ class DynamicIVFIndexImpl { ); // Dispatch on storage kind to compress and assemble - return ivf_storage::dispatch_ivf_storage_kind(storage_kind_, [&](auto tag) { - // Compress data to target storage type using the factory - auto compressed_data = - ivf_storage::make_ivf_blocked_storage(tag, data, threadpool); - - return new svs::DynamicIVF(svs::DynamicIVF::assemble_from_clustering( - std::move(clustering), - std::move(compressed_data), - ids, - std::forward(distance), - num_threads_, - intra_query_threads_ - )); - }); + return ivf_storage::dispatch_ivf_storage_kind( + storage_kind_, + [&](auto tag) { + // Compress data to target storage type using the factory + auto compressed_data = + ivf_storage::make_ivf_blocked_storage(tag, data, threadpool); + + return new svs::DynamicIVF( + svs::DynamicIVF::assemble_from_clustering( + std::move(clustering), + std::move(compressed_data), + ids, + std::forward(distance), + num_threads_, + intra_query_threads_ + ) + ); + } + ); })); } @@ -328,6 +342,8 @@ class DynamicIVFIndexImpl { #ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC // Dynamic IVF index implementation for LeanVec storage kinds class DynamicIVFIndexLeanVecImpl : public DynamicIVFIndexImpl { + using allocator_type = svs::data::Blocked>; + public: using LeanVecMatricesType = LeanVecTrainingDataImpl::LeanVecMatricesType; @@ -397,25 +413,32 @@ class DynamicIVFIndexLeanVecImpl : public DynamicIVFIndexImpl { num_threads = static_cast(omp_get_max_threads()); } - return ivf_storage::dispatch_ivf_leanvec_storage_kind(storage_kind, [&](auto tag) { - using Tag = decltype(tag); - using DataType = ivf_storage::IVFBlockedStorageType_t; - - svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); - return distance_dispatcher([&](auto&& distance) { - auto impl = std::make_unique( - svs::DynamicIVF::assemble( - in, - std::forward(distance), + return ivf_storage::dispatch_ivf_leanvec_storage_kind( + storage_kind, + [&](auto tag) { + using Tag = decltype(tag); + using DataType = typename Tag::type; + + svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); + return distance_dispatcher([&](auto&& distance) { + auto impl = std::make_unique( + svs::DynamicIVF::assemble( + in, + std::forward(distance), + num_threads, + intra_query_threads + ) + ); + return new DynamicIVFIndexLeanVecImpl( + std::move(impl), + metric, + storage_kind, num_threads, intra_query_threads - ) - ); - return new DynamicIVFIndexLeanVecImpl( - std::move(impl), metric, storage_kind, num_threads, intra_query_threads - ); - }); - }); + ); + }); + } + ); } protected: @@ -451,7 +474,7 @@ class DynamicIVFIndexLeanVecImpl : public DynamicIVFIndexImpl { ); // Dispatch on LeanVec storage kind to compress and assemble - return ivf_storage::dispatch_ivf_leanvec_storage_kind( + return ivf_storage::dispatch_ivf_leanvec_storage_kind( storage_kind_, [&](auto tag) { // Compress data to LeanVec storage type using the factory with matrices diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index 9e4656288..4b16cf4bc 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -38,8 +38,10 @@ namespace svs { namespace runtime { -// Vamana index implementation +// Dynamic Vamana index implementation class DynamicVamanaIndexImpl { + using allocator_type = svs::data::Blocked>; + public: DynamicVamanaIndexImpl( size_t dim, @@ -138,10 +140,9 @@ class DynamicVamanaIndexImpl { // Pad results if not enough neighbors found if (found < k) { - auto& dists = result.distances(); - std::fill(dists.begin() + found, dists.end(), Unspecify()); - auto& inds = result.indices(); - std::fill(inds.begin() + found, inds.end(), Unspecify()); + for (size_t j = found; j < k; ++j) { + result.set(Neighbor{Unspecify(), Unspecify()}, i, j); + } } } }; @@ -394,12 +395,14 @@ class DynamicVamanaIndexImpl { StorageArgs&&... storage_args ) { auto threadpool = default_threadpool(); + using storage_alloc_t = typename Tag::allocator_type; + auto allocator = storage::make_allocator(blocksize_bytes); auto storage = make_storage( std::forward(tag), data, threadpool, - blocksize_bytes, + allocator, std::forward(storage_args)... ); @@ -420,7 +423,7 @@ class DynamicVamanaIndexImpl { std::span labels, lib::PowerOfTwo blocksize_bytes ) { - impl_.reset(storage::dispatch_storage_kind( + impl_.reset(storage::dispatch_storage_kind( get_storage_kind(), [this]( auto&& tag, @@ -466,7 +469,7 @@ class DynamicVamanaIndexImpl { impl_->get_full_search_history()}; } - template + template static svs::DynamicVamana* load_impl_t(Tag&& tag, std::istream& stream, MetricType metric) { namespace fs = std::filesystem; @@ -514,7 +517,7 @@ class DynamicVamanaIndexImpl { public: static DynamicVamanaIndexImpl* load(std::istream& stream, MetricType metric, StorageKind storage_kind) { - return storage::dispatch_storage_kind( + return storage::dispatch_storage_kind( storage_kind, [&](auto&& tag, std::istream& stream, MetricType metric) { using Tag = std::decay_t; diff --git a/bindings/cpp/src/dynamic_vamana_index_leanvec_impl.h b/bindings/cpp/src/dynamic_vamana_index_leanvec_impl.h index 89b2bf926..4d59281d5 100644 --- a/bindings/cpp/src/dynamic_vamana_index_leanvec_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_leanvec_impl.h @@ -34,6 +34,7 @@ namespace runtime { // Vamana index implementation for LeanVec storage kinds struct DynamicVamanaIndexLeanVecImpl : public DynamicVamanaIndexImpl { using LeanVecMatricesType = LeanVecTrainingDataImpl::LeanVecMatricesType; + using allocator_type = svs::data::Blocked>; DynamicVamanaIndexLeanVecImpl( std::unique_ptr&& impl, @@ -80,11 +81,20 @@ struct DynamicVamanaIndexLeanVecImpl : public DynamicVamanaIndexImpl { static auto dispatch_leanvec_storage_kind(StorageKind kind, F&& f, Args&&... args) { switch (kind) { case StorageKind::LeanVec4x4: - return f(storage::LeanVec4x4Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::LeanVec4x8: - return f(storage::LeanVec4x8Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::LeanVec8x8: - return f(storage::LeanVec8x8Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); default: throw StatusException{ ErrorCode::INVALID_ARGUMENT, "SVS LeanVec storage kind required"}; diff --git a/bindings/cpp/src/flat_index_impl.h b/bindings/cpp/src/flat_index_impl.h index 1557ab9be..02e07f54d 100644 --- a/bindings/cpp/src/flat_index_impl.h +++ b/bindings/cpp/src/flat_index_impl.h @@ -35,6 +35,8 @@ namespace runtime { // Vamana index implementation class FlatIndexImpl { + using allocator_type = svs::lib::Allocator; + public: FlatIndexImpl(size_t dim, MetricType metric) : dim_{dim} @@ -96,7 +98,8 @@ class FlatIndexImpl { static FlatIndexImpl* load(std::istream& in, MetricType metric) { auto threadpool = default_threadpool(); - using storage_type = svs::runtime::storage::StorageType_t; + using storage_type = svs::runtime::storage:: + StorageType_t>; svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); return distance_dispatcher([&](auto&& distance) { @@ -119,7 +122,7 @@ class FlatIndexImpl { auto threadpool = default_threadpool(); auto storage = svs::runtime::storage::make_storage( - svs::runtime::storage::FP32Tag{}, data, threadpool + storage::StorageType{}, data, threadpool ); svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric_type_)); diff --git a/bindings/cpp/src/ivf_index_impl.h b/bindings/cpp/src/ivf_index_impl.h index c0220e15f..ace923290 100644 --- a/bindings/cpp/src/ivf_index_impl.h +++ b/bindings/cpp/src/ivf_index_impl.h @@ -81,118 +81,14 @@ inline bool is_supported_storage_kind(StorageKind kind) { return is_supported_non_leanvec_storage_kind(kind) || is_leanvec_storage_kind(kind); } -///// IVF Data Types ///// - -// Simple uncompressed data types -template -using IVFSimpleDataType = svs::data::SimpleData>; - -template -using IVFBlockedSimpleDataType = - svs::data::SimpleData>>; - -// Scalar Quantization data types -template -using IVFSQDataType = - svs::quantization::scalar::SQDataset>; - -template -using IVFBlockedSQDataType = svs::quantization::scalar:: - SQDataset>>; - -#ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC -// LVQ data types -template -using IVFLVQDataType = svs::quantization::lvq:: - LVQDataset>; - -template -using IVFBlockedLVQDataType = svs::quantization::lvq::LVQDataset< - Primary, - Residual, - svs::Dynamic, - Strategy, - svs::data::Blocked>>; - -using Sequential = svs::quantization::lvq::Sequential; -using Turbo16x8 = svs::quantization::lvq::Turbo<16, 8>; - -// LeanVec data types -template -using IVFLeanVecDataType = svs::leanvec::LeanDataset< - svs::leanvec::UsingLVQ, - svs::leanvec::UsingLVQ, - svs::Dynamic, - svs::Dynamic, - svs::lib::Allocator>; - -template -using IVFBlockedLeanVecDataType = svs::leanvec::LeanDataset< - svs::leanvec::UsingLVQ, - svs::leanvec::UsingLVQ, - svs::Dynamic, - svs::Dynamic, - svs::data::Blocked>>; -#endif // SVS_RUNTIME_HAVE_LVQ_LEANVEC - -///// Storage Type Mapping ///// - -// Map StorageKind to data type using storage tags -template struct IVFStorageType { - using type = storage::UnsupportedStorageType; -}; - -template struct IVFBlockedStorageType { - using type = storage::UnsupportedStorageType; -}; - -template -using IVFStorageType_t = typename IVFStorageType::type; - -template -using IVFBlockedStorageType_t = typename IVFBlockedStorageType::type; - -// clang-format off -template <> struct IVFStorageType { using type = IVFSimpleDataType; }; -template <> struct IVFStorageType { using type = IVFSimpleDataType; }; -template <> struct IVFStorageType { using type = IVFSQDataType; }; - -template <> struct IVFBlockedStorageType { using type = IVFBlockedSimpleDataType; }; -template <> struct IVFBlockedStorageType { using type = IVFBlockedSimpleDataType; }; -template <> struct IVFBlockedStorageType { using type = IVFBlockedSQDataType; }; -// clang-format on - -#ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC -// clang-format off -template <> struct IVFStorageType { using type = IVFLVQDataType<4, 0, Turbo16x8>; }; -template <> struct IVFStorageType { using type = IVFLVQDataType<8, 0, Sequential>; }; -template <> struct IVFStorageType { using type = IVFLVQDataType<4, 4, Turbo16x8>; }; -template <> struct IVFStorageType { using type = IVFLVQDataType<4, 8, Turbo16x8>; }; - -template <> struct IVFBlockedStorageType { using type = IVFBlockedLVQDataType<4, 0, Turbo16x8>; }; -template <> struct IVFBlockedStorageType { using type = IVFBlockedLVQDataType<8, 0, Sequential>; }; -template <> struct IVFBlockedStorageType { using type = IVFBlockedLVQDataType<4, 4, Turbo16x8>; }; -template <> struct IVFBlockedStorageType { using type = IVFBlockedLVQDataType<4, 8, Turbo16x8>; }; -// clang-format on - -// clang-format off -template <> struct IVFStorageType { using type = IVFLeanVecDataType<4, 4>; }; -template <> struct IVFStorageType { using type = IVFLeanVecDataType<4, 8>; }; -template <> struct IVFStorageType { using type = IVFLeanVecDataType<8, 8>; }; - -template <> struct IVFBlockedStorageType { using type = IVFBlockedLeanVecDataType<4, 4>; }; -template <> struct IVFBlockedStorageType { using type = IVFBlockedLeanVecDataType<4, 8>; }; -template <> struct IVFBlockedStorageType { using type = IVFBlockedLeanVecDataType<8, 8>; }; -// clang-format on -#endif // SVS_RUNTIME_HAVE_LVQ_LEANVEC - ///// Storage Factory ///// template struct IVFStorageFactory; // Unsupported storage factory -template <> struct IVFStorageFactory { - using DataType = IVFSimpleDataType; +template struct IVFStorageFactory> { + using DataType = storage:: + SimpleDatasetType>; template static DataType @@ -276,52 +172,66 @@ struct IVFStorageFactory - requires storage::StorageTag> +template auto make_ivf_storage( - Tag&&, const svs::data::ConstSimpleDataView& data, Pool& pool, size_t arg = 0 + storage::StorageType SVS_UNUSED(tag), + const svs::data::ConstSimpleDataView& data, + Pool& pool, + size_t arg = 0 ) { - using TagDecay = std::decay_t; - return IVFStorageFactory>::compress(data, pool, arg); + static_assert( + !svs::data::is_blocked_v, "Allocator must not be blocked for IVF storage" + ); + return IVFStorageFactory>::compress( + data, pool, arg + ); } -template - requires storage::StorageTag> +template auto make_ivf_blocked_storage( - Tag&&, const svs::data::ConstSimpleDataView& data, Pool& pool, size_t arg = 0 + storage::StorageType SVS_UNUSED(tag), + const svs::data::ConstSimpleDataView& data, + Pool& pool, + size_t arg = 0 ) { - using TagDecay = std::decay_t; - return IVFStorageFactory>::compress(data, pool, arg); + static_assert( + svs::data::is_blocked_v, "Allocator must be blocked for IVF storage" + ); + return IVFStorageFactory>::compress( + data, pool, arg + ); } #ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC // LeanVec-specific make functions with matrices parameter -template - requires storage::StorageTag> +template auto make_ivf_leanvec_storage( - Tag&&, + storage::StorageType SVS_UNUSED(tag), const svs::data::ConstSimpleDataView& data, Pool& pool, size_t leanvec_d, std::optional> matrices ) { - using TagDecay = std::decay_t; - return IVFStorageFactory>::compress( + static_assert( + !svs::data::is_blocked_v, "Allocator must not be blocked for IVF storage" + ); + return IVFStorageFactory>::compress( data, pool, leanvec_d, std::move(matrices) ); } -template - requires storage::StorageTag> +template auto make_ivf_blocked_leanvec_storage( - Tag&&, + storage::StorageType SVS_UNUSED(tag), const svs::data::ConstSimpleDataView& data, Pool& pool, size_t leanvec_d, std::optional> matrices ) { - using TagDecay = std::decay_t; - return IVFStorageFactory>::compress( + static_assert( + svs::data::is_blocked_v, "Allocator must be blocked for IVF storage" + ); + return IVFStorageFactory>::compress( data, pool, leanvec_d, std::move(matrices) ); } @@ -330,24 +240,45 @@ auto make_ivf_blocked_leanvec_storage( ///// Dispatch Functions ///// // Dispatch on storage kind for IVF operations (excludes LeanVec - handled separately) -template +template auto dispatch_ivf_storage_kind(StorageKind kind, F&& f, Args&&... args) { switch (kind) { case StorageKind::FP32: - return f(storage::FP32Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::FP16: - return f(storage::FP16Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::SQI8: - return f(storage::SQI8Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); #ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC case StorageKind::LVQ4x0: - return f(storage::LVQ4x0Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::LVQ8x0: - return f(storage::LVQ8x0Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::LVQ4x4: - return f(storage::LVQ4x4Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::LVQ4x8: - return f(storage::LVQ4x8Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); #endif default: throw StatusException{ @@ -358,15 +289,24 @@ auto dispatch_ivf_storage_kind(StorageKind kind, F&& f, Args&&... args) { #ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC // Dispatch on LeanVec storage kinds only -template +template auto dispatch_ivf_leanvec_storage_kind(StorageKind kind, F&& f, Args&&... args) { switch (kind) { case StorageKind::LeanVec4x4: - return f(storage::LeanVec4x4Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::LeanVec4x8: - return f(storage::LeanVec4x8Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); case StorageKind::LeanVec8x8: - return f(storage::LeanVec8x8Tag{}, std::forward(args)...); + return f( + storage::StorageType{}, + std::forward(args)... + ); default: throw StatusException{ ErrorCode::INVALID_ARGUMENT, "LeanVec storage kind required"}; @@ -378,6 +318,8 @@ auto dispatch_ivf_leanvec_storage_kind(StorageKind kind, F&& f, Args&&... args) // Static IVF index implementation (non-LeanVec storage kinds) class IVFIndexImpl { + using allocator_type = svs::lib::Allocator; + public: IVFIndexImpl( size_t dim, @@ -481,25 +423,32 @@ class IVFIndexImpl { } // Dispatch on storage kind to load with correct data type - return ivf_storage::dispatch_ivf_storage_kind(storage_kind, [&](auto tag) { - using Tag = decltype(tag); - using DataType = ivf_storage::IVFStorageType_t; - - svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); - return distance_dispatcher([&](auto&& distance) { - auto impl = std::make_unique( - svs::IVF::assemble( - in, - std::forward(distance), + return ivf_storage::dispatch_ivf_storage_kind( + storage_kind, + [&](auto tag) { + using Tag = decltype(tag); + using DataType = typename Tag::type; + + svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); + return distance_dispatcher([&](auto&& distance) { + auto impl = std::make_unique( + svs::IVF::assemble( + in, + std::forward(distance), + num_threads, + intra_query_threads + ) + ); + return new IVFIndexImpl( + std::move(impl), + metric, + storage_kind, num_threads, intra_query_threads - ) - ); - return new IVFIndexImpl( - std::move(impl), metric, storage_kind, num_threads, intra_query_threads - ); - }); - }); + ); + }); + } + ); } protected: @@ -594,18 +543,22 @@ class IVFIndexImpl { ); // Dispatch on storage kind to compress and assemble - return ivf_storage::dispatch_ivf_storage_kind(storage_kind_, [&](auto tag) { - // Compress data to target storage type using the factory - auto compressed_data = ivf_storage::make_ivf_storage(tag, data, threadpool); - - return new svs::IVF(svs::IVF::assemble_from_clustering( - std::move(clustering), - std::move(compressed_data), - std::forward(distance), - num_threads_, - intra_query_threads_ - )); - }); + return ivf_storage::dispatch_ivf_storage_kind( + storage_kind_, + [&](auto tag) { + // Compress data to target storage type using the factory + auto compressed_data = + ivf_storage::make_ivf_storage(tag, data, threadpool); + + return new svs::IVF(svs::IVF::assemble_from_clustering( + std::move(clustering), + std::move(compressed_data), + std::forward(distance), + num_threads_, + intra_query_threads_ + )); + } + ); })); } @@ -623,6 +576,8 @@ class IVFIndexImpl { #ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC // Static IVF index implementation for LeanVec storage kinds class IVFIndexLeanVecImpl : public IVFIndexImpl { + using allocator_type = svs::lib::Allocator; + public: using LeanVecMatricesType = LeanVecTrainingDataImpl::LeanVecMatricesType; @@ -692,25 +647,32 @@ class IVFIndexLeanVecImpl : public IVFIndexImpl { num_threads = static_cast(omp_get_max_threads()); } - return ivf_storage::dispatch_ivf_leanvec_storage_kind(storage_kind, [&](auto tag) { - using Tag = decltype(tag); - using DataType = ivf_storage::IVFStorageType_t; - - svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); - return distance_dispatcher([&](auto&& distance) { - auto impl = std::make_unique( - svs::IVF::assemble( - in, - std::forward(distance), + return ivf_storage::dispatch_ivf_leanvec_storage_kind( + storage_kind, + [&](auto tag) { + using Tag = decltype(tag); + using DataType = typename Tag::type; + + svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); + return distance_dispatcher([&](auto&& distance) { + auto impl = std::make_unique( + svs::IVF::assemble( + in, + std::forward(distance), + num_threads, + intra_query_threads + ) + ); + return new IVFIndexLeanVecImpl( + std::move(impl), + metric, + storage_kind, num_threads, intra_query_threads - ) - ); - return new IVFIndexLeanVecImpl( - std::move(impl), metric, storage_kind, num_threads, intra_query_threads - ); - }); - }); + ); + }); + } + ); } protected: @@ -745,7 +707,7 @@ class IVFIndexLeanVecImpl : public IVFIndexImpl { ); // Dispatch on LeanVec storage kind to compress and assemble - return ivf_storage::dispatch_ivf_leanvec_storage_kind( + return ivf_storage::dispatch_ivf_leanvec_storage_kind( storage_kind_, [&](auto tag) { // Compress data to LeanVec storage type using the factory with matrices diff --git a/bindings/cpp/src/svs_runtime_utils.h b/bindings/cpp/src/svs_runtime_utils.h index 2270e4e8b..e0d7c68af 100644 --- a/bindings/cpp/src/svs_runtime_utils.h +++ b/bindings/cpp/src/svs_runtime_utils.h @@ -123,71 +123,100 @@ inline bool is_supported_storage_kind(StorageKind kind) { return true; } -// Storage kind processing -// Most kinds map to std::byte storage, but some have specific element types. -// Storage kind tag types for function argument deduction -template struct StorageKindTag { - static constexpr StorageKind value = K; +template struct AllocatorTypeExtractor { + using type = A; }; -#define SVS_DEFINE_STORAGE_KIND_TAG(Kind) \ - using Kind##Tag = StorageKindTag +template +concept AllocatorAwareType = requires { typename A::allocator_type; }; -SVS_DEFINE_STORAGE_KIND_TAG(FP32); -SVS_DEFINE_STORAGE_KIND_TAG(FP16); -SVS_DEFINE_STORAGE_KIND_TAG(SQI8); -SVS_DEFINE_STORAGE_KIND_TAG(LVQ4x0); -SVS_DEFINE_STORAGE_KIND_TAG(LVQ8x0); -SVS_DEFINE_STORAGE_KIND_TAG(LVQ4x4); -SVS_DEFINE_STORAGE_KIND_TAG(LVQ4x8); -SVS_DEFINE_STORAGE_KIND_TAG(LeanVec4x4); -SVS_DEFINE_STORAGE_KIND_TAG(LeanVec4x8); -SVS_DEFINE_STORAGE_KIND_TAG(LeanVec8x8); +template struct AllocatorTypeExtractor { + using type = typename A::allocator_type; +}; + +template using extract_allocator_t = typename AllocatorTypeExtractor::type; -#undef SVS_DEFINE_STORAGE_KIND_TAG +template struct ExtractedAllocatorRebinder { + using type = lib::rebind_allocator_t>; +}; -template inline constexpr bool is_storage_tag = false; -template inline constexpr bool is_storage_tag> = true; +template +struct ExtractedAllocatorRebinder> { + using type = svs::data::Blocked>>; +}; + +template +using rebind_extracted_allocator_t = typename ExtractedAllocatorRebinder::type; + +template +Alloc make_allocator() + requires(!svs::data::is_blocked_v) +{ + return Alloc{}; +} -template -concept StorageTag = is_storage_tag; +template +Alloc make_allocator(svs::lib::PowerOfTwo blocksize_bytes) + requires(svs::data::is_blocked_v) +{ + if (blocksize_bytes.raw() == 0) { + throw StatusException( + ErrorCode::INVALID_ARGUMENT, + "Blocked storage types require a non-zero blocksize" + ); + } + auto parameters = svs::data::BlockingParameters{.blocksize_bytes = blocksize_bytes}; + return Alloc(parameters); +} // Storage types -template -using SimpleDatasetType = - svs::data::SimpleData>>; +template +using SimpleDatasetType = svs::data::SimpleData; -template -using SQDatasetType = svs::quantization::scalar:: - SQDataset>>; +template +using SQDatasetType = svs::quantization::scalar::SQDataset; // Storage type mapping // Unsupported storage type defined as unique type to cause runtime error if used -struct UnsupportedStorageType {}; +template struct UnsupportedStorageType { + using allocator_type = Alloc; // Dummy allocator type to satisfy template requirements +}; -// clang-format off -template struct StorageType { using type = UnsupportedStorageType; }; -template using StorageType_t = typename StorageType::type; +template struct StorageType { + using allocator_type = Alloc; + using type = UnsupportedStorageType; +}; +template +using StorageType_t = typename StorageType::type; -template <> struct StorageType { using type = SimpleDatasetType; }; -template <> struct StorageType { using type = SimpleDatasetType; }; -template <> struct StorageType { using type = SQDatasetType; }; -// clang-format on +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = SimpleDatasetType; +}; +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = SimpleDatasetType; +}; +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = SQDatasetType; +}; // Storage factory template struct StorageFactory; // Unsupported storage type factory returning runtime error when attempted to be used. // Return type defined to simple to allow substitution in templates. -template <> struct StorageFactory { - using StorageType = SimpleDatasetType; +template struct StorageFactory> { + using StorageType = + SimpleDatasetType>; template static StorageType init( const svs::data::ConstSimpleDataView& SVS_UNUSED(data), Pool& SVS_UNUSED(pool), - svs::lib::PowerOfTwo SVS_UNUSED(blocksize_bytes) + const typename StorageType::allocator_type& SVS_UNUSED(alloc) = {} ) { throw StatusException( ErrorCode::NOT_IMPLEMENTED, "Requested storage kind is not supported" @@ -203,18 +232,16 @@ template <> struct StorageFactory { } }; -template struct StorageFactory> { - using StorageType = SimpleDatasetType; +template +struct StorageFactory> { + using StorageType = svs::data::SimpleData; template static StorageType init( const svs::data::ConstSimpleDataView& data, Pool& pool, - svs::lib::PowerOfTwo blocksize_bytes = - svs::data::BlockingParameters::default_blocksize_bytes + const Alloc& alloc = {} ) { - auto parameters = svs::data::BlockingParameters{.blocksize_bytes = blocksize_bytes}; - typename StorageType::allocator_type alloc(parameters); StorageType result(data.size(), data.dimensions(), alloc); svs::threads::parallel_for( pool, @@ -235,19 +262,17 @@ template struct StorageFactory -struct StorageFactory { - using StorageType = SQStorageType; +template +struct StorageFactory> { + using StorageType = svs::quantization::scalar::SQDataset; template static StorageType init( const svs::data::ConstSimpleDataView& data, Pool& pool, - svs::lib::PowerOfTwo blocksize_bytes + const Alloc& alloc = {} ) { - auto parameters = svs::data::BlockingParameters{.blocksize_bytes = blocksize_bytes}; - typename StorageType::allocator_type alloc(parameters); - return SQStorageType::compress(data, pool, alloc); + return StorageType::compress(data, pool, alloc); } template @@ -258,37 +283,49 @@ struct StorageFactory { // LVQ Storage support #ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC -template -using LVQDatasetType = svs::quantization::lvq::LVQDataset< - Primary, - Residual, - svs::Dynamic, - Strategy, - svs::data::Blocked>>; - using Sequential = svs::quantization::lvq::Sequential; using Turbo16x8 = svs::quantization::lvq::Turbo<16, 8>; - -// clang-format off -template <> struct StorageType { using type = LVQDatasetType<4, 0, Turbo16x8>; }; -template <> struct StorageType { using type = LVQDatasetType<8, 0, Sequential>; }; -template <> struct StorageType { using type = LVQDatasetType<4, 4, Turbo16x8>; }; -template <> struct StorageType { using type = LVQDatasetType<4, 8, Turbo16x8>; }; -// clang-format on +template +using AutoStrategy = std::conditional_t<(Primary == 4), Turbo16x8, Sequential>; + +template < + size_t Primary, + size_t Residual, + typename Alloc, + size_t Extent = svs::Dynamic, + svs::quantization::lvq::LVQPackingStrategy Strategy = AutoStrategy> +using LVQDatasetType = + svs::quantization::lvq::LVQDataset; + +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = LVQDatasetType<4, 0, allocator_type>; +}; +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = LVQDatasetType<8, 0, allocator_type>; +}; +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = LVQDatasetType<4, 4, allocator_type>; +}; +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = LVQDatasetType<4, 8, allocator_type>; +}; template struct StorageFactory { using StorageType = LVQStorageType; + using Alloc = typename StorageType::allocator_type; template static StorageType init( const svs::data::ConstSimpleDataView& data, Pool& pool, - svs::lib::PowerOfTwo blocksize_bytes + const Alloc& alloc = {} ) { - auto parameters = svs::data::BlockingParameters{.blocksize_bytes = blocksize_bytes}; - typename LVQStorageType::allocator_type alloc(parameters); - return LVQStorageType::compress(data, pool, 0, alloc); + return StorageType::compress(data, pool, 0, alloc); } template @@ -298,37 +335,48 @@ struct StorageFactory { }; // LeanVec Storage support -template +template < + size_t I1, + size_t I2, + typename Alloc, + size_t LeanVecDims = svs::Dynamic, + size_t Extent = svs::Dynamic> using LeanDatasetType = svs::leanvec::LeanDataset< svs::leanvec::UsingLVQ, svs::leanvec::UsingLVQ, - svs::Dynamic, - svs::Dynamic, - svs::data::Blocked>>; + LeanVecDims, + Extent, + Alloc>; -// clang-format off -template <> struct StorageType { using type = LeanDatasetType<4, 4>; }; -template <> struct StorageType { using type = LeanDatasetType<4, 8>; }; -template <> struct StorageType { using type = LeanDatasetType<8, 8>; }; -// clang-format on +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = LeanDatasetType<4, 4, allocator_type>; +}; +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = LeanDatasetType<4, 8, allocator_type>; +}; +template struct StorageType { + using allocator_type = rebind_extracted_allocator_t; + using type = LeanDatasetType<8, 8, allocator_type>; +}; template struct StorageFactory { using StorageType = LeanVecStorageType; + using Alloc = typename StorageType::allocator_type; template static StorageType init( const svs::data::ConstSimpleDataView& data, Pool& pool, - svs::lib::PowerOfTwo blocksize_bytes, + const Alloc& alloc = {}, size_t leanvec_d = 0, std::optional> matrices = std::nullopt ) { if (leanvec_d == 0) { leanvec_d = (data.dimensions() + 1) / 2; } - auto parameters = svs::data::BlockingParameters{.blocksize_bytes = blocksize_bytes}; - typename LeanVecStorageType::allocator_type alloc(parameters); return LeanVecStorageType::reduce( data, std::move(matrices), pool, 0, svs::lib::MaybeStatic{leanvec_d}, alloc ); @@ -341,47 +389,45 @@ struct StorageFactory { }; #endif // SVS_RUNTIME_HAVE_LVQ_LEANVEC -template -auto make_storage(Tag&& SVS_UNUSED(tag), Args&&... args) { - return StorageFactory>::init(std::forward(args)...); +template +auto make_storage(StorageType SVS_UNUSED(tag), Args&&... args) { + return StorageFactory>::init(std::forward(args)...); } -template -auto load_storage(Tag&& SVS_UNUSED(tag), Args&&... args) { - return StorageFactory>::load(std::forward(args)...); +template +auto load_storage(StorageType SVS_UNUSED(tag), Args&&... args) { + return StorageFactory>::load(std::forward(args)...); } -template +template auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) { if (!is_supported_storage_kind(kind)) { throw StatusException( ErrorCode::NOT_IMPLEMENTED, "Requested storage kind is not supported by CPU" ); } +#define SVS_DISPATCH_STORAGE_KIND(Kind) \ + case StorageKind::Kind: \ + return f(StorageType{}, std::forward(args)...) + switch (kind) { - case StorageKind::FP32: - return f(FP32Tag{}, std::forward(args)...); - case StorageKind::FP16: - return f(FP16Tag{}, std::forward(args)...); - case StorageKind::SQI8: - return f(SQI8Tag{}, std::forward(args)...); - case StorageKind::LVQ4x0: - return f(LVQ4x0Tag{}, std::forward(args)...); - case StorageKind::LVQ8x0: - return f(LVQ8x0Tag{}, std::forward(args)...); - case StorageKind::LVQ4x4: - return f(LVQ4x4Tag{}, std::forward(args)...); - case StorageKind::LVQ4x8: - return f(LVQ4x8Tag{}, std::forward(args)...); - case StorageKind::LeanVec4x4: - return f(LeanVec4x4Tag{}, std::forward(args)...); - case StorageKind::LeanVec4x8: - return f(LeanVec4x8Tag{}, std::forward(args)...); - case StorageKind::LeanVec8x8: - return f(LeanVec8x8Tag{}, std::forward(args)...); + SVS_DISPATCH_STORAGE_KIND(FP32); + SVS_DISPATCH_STORAGE_KIND(FP16); + SVS_DISPATCH_STORAGE_KIND(SQI8); + SVS_DISPATCH_STORAGE_KIND(LVQ4x0); + SVS_DISPATCH_STORAGE_KIND(LVQ8x0); + SVS_DISPATCH_STORAGE_KIND(LVQ4x4); + SVS_DISPATCH_STORAGE_KIND(LVQ4x8); + SVS_DISPATCH_STORAGE_KIND(LeanVec4x4); + SVS_DISPATCH_STORAGE_KIND(LeanVec4x8); + SVS_DISPATCH_STORAGE_KIND(LeanVec8x8); default: - throw ANNEXCEPTION("not supported SVS storage kind"); + throw StatusException( + ErrorCode::INVALID_ARGUMENT, "Unknown or unsupported SVS storage kind" + ); } + +#undef SVS_DISPATCH_STORAGE_KIND } } // namespace storage diff --git a/bindings/cpp/src/training_impl.h b/bindings/cpp/src/training_impl.h index adc9e80d7..9f97220b9 100644 --- a/bindings/cpp/src/training_impl.h +++ b/bindings/cpp/src/training_impl.h @@ -41,21 +41,26 @@ struct LeanVecTrainingDataImpl { LeanVecTrainingDataImpl(LeanVecMatricesType&& matrices) : leanvec_dims_{matrices.view_data_matrix().dimensions()} - , leanvec_matrices_{std::move(matrices)} {} + , leanvec_matrices_{std::move(matrices)} { + if (!svs::detail::lvq_leanvec_enabled()) { + throw StatusException( + ErrorCode::NOT_IMPLEMENTED, "LeanVec is not supported by CPU." + ); + } + } LeanVecTrainingDataImpl( const svs::data::ConstSimpleDataView& data, size_t leanvec_dims ) - : leanvec_dims_{leanvec_dims} - , leanvec_matrices_{compute_leanvec_matrices(data, leanvec_dims)} {} + : LeanVecTrainingDataImpl(compute_leanvec_matrices(data, leanvec_dims)) {} LeanVecTrainingDataImpl( const svs::data::ConstSimpleDataView& data, const svs::data::ConstSimpleDataView& queries, size_t leanvec_dims ) - : leanvec_dims_{leanvec_dims} - , leanvec_matrices_{compute_leanvec_matrices_ood(data, queries, leanvec_dims)} {} + : LeanVecTrainingDataImpl(compute_leanvec_matrices_ood(data, queries, leanvec_dims) + ) {} size_t get_leanvec_dims() const { return leanvec_dims_; } const LeanVecMatricesType& get_leanvec_matrices() const { return leanvec_matrices_; } diff --git a/bindings/cpp/src/vamana_index.cpp b/bindings/cpp/src/vamana_index.cpp index 39809b195..c015dd21e 100644 --- a/bindings/cpp/src/vamana_index.cpp +++ b/bindings/cpp/src/vamana_index.cpp @@ -16,11 +16,200 @@ #include "svs/runtime/vamana_index.h" +#include "svs_runtime_utils.h" +#include "vamana_index_impl.h" + namespace svs { namespace runtime { +namespace { +template +struct VamanaIndexManagerBase : public VamanaIndex { + std::unique_ptr impl_; + + VamanaIndexManagerBase(std::unique_ptr impl) + : impl_{std::move(impl)} { + assert(impl_ != nullptr); + } + + VamanaIndexManagerBase(const VamanaIndexManagerBase&) = delete; + VamanaIndexManagerBase& operator=(const VamanaIndexManagerBase&) = delete; + VamanaIndexManagerBase(VamanaIndexManagerBase&&) = default; + VamanaIndexManagerBase& operator=(VamanaIndexManagerBase&&) = default; + ~VamanaIndexManagerBase() override = default; + + Status add(size_t n, const float* x) noexcept override { + return runtime_error_wrapper([&] { + svs::data::ConstSimpleDataView data{x, n, impl_->dimensions()}; + impl_->add(data); + return Status_Ok; + }); + } + + Status reset() noexcept override { + return runtime_error_wrapper([&] { + impl_->reset(); + return Status_Ok; + }); + } + + Status search( + size_t n, + const float* x, + size_t k, + float* distances, + size_t* labels, + const SearchParams* params = nullptr, + IDFilter* filter = nullptr + ) const noexcept override { + return runtime_error_wrapper([&] { + auto result = svs::QueryResultView{ + svs::MatrixView{svs::make_dims(n, k), labels}, + svs::MatrixView{svs::make_dims(n, k), distances}}; + auto queries = svs::data::ConstSimpleDataView(x, n, impl_->dimensions()); + impl_->search(result, queries, params, filter); + }); + } + + Status range_search( + size_t n, + const float* x, + float radius, + const ResultsAllocator& results, + const SearchParams* params = nullptr, + IDFilter* filter = nullptr + ) const noexcept override { + return runtime_error_wrapper([&] { + auto queries = svs::data::ConstSimpleDataView(x, n, impl_->dimensions()); + impl_->range_search(queries, radius, results, params, filter); + }); + } + + Status save(std::ostream& out) const noexcept override { + return runtime_error_wrapper([&] { impl_->save(out); }); + } +}; +} // namespace + // VamanaIndex interface implementation VamanaIndex::~VamanaIndex() = default; +Status VamanaIndex::check_storage_kind(StorageKind storage_kind) noexcept { + bool supported = false; + auto status = runtime_error_wrapper([&] { + supported = storage::is_supported_storage_kind(storage_kind); + }); + if (!status.ok()) { + return status; + } + return supported + ? Status_Ok + : Status( + ErrorCode::INVALID_ARGUMENT, + "The specified storage kind is not compatible with the VamanaIndex" + ); +} + +Status VamanaIndex::build( + VamanaIndex** index, + size_t dim, + MetricType metric, + StorageKind storage_kind, + const VamanaIndex::BuildParams& params, + const VamanaIndex::SearchParams& default_search_params +) noexcept { + using Impl = VamanaIndexImpl; + *index = nullptr; + + return runtime_error_wrapper([&] { + auto impl = std::make_unique( + dim, metric, storage_kind, params, default_search_params + ); + *index = new VamanaIndexManagerBase{std::move(impl)}; + }); +} + +Status VamanaIndex::destroy(VamanaIndex* index) noexcept { + return runtime_error_wrapper([&] { delete index; }); +} + +Status VamanaIndex::load( + VamanaIndex** index, std::istream& in, MetricType metric, StorageKind storage_kind +) noexcept { + using Impl = VamanaIndexImpl; + *index = nullptr; + return runtime_error_wrapper([&] { + std::unique_ptr impl{Impl::load(in, metric, storage_kind)}; + *index = new VamanaIndexManagerBase{std::move(impl)}; + }); +} + +#ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC +// Specialization to build LeanVec-based Vamana index with specified leanvec dims +Status VamanaIndexLeanVec::build( + VamanaIndex** index, + size_t dim, + MetricType metric, + StorageKind storage_kind, + size_t leanvec_dims, + const VamanaIndex::BuildParams& params, + const VamanaIndex::SearchParams& default_search_params +) noexcept { + using Impl = VamanaIndexLeanVecImpl; + *index = nullptr; + + return runtime_error_wrapper([&] { + auto impl = std::make_unique( + dim, metric, storage_kind, leanvec_dims, params, default_search_params + ); + *index = new VamanaIndexManagerBase{std::move(impl)}; + }); +} + +// Specialization to build LeanVec-based Vamana index with provided training data +Status VamanaIndexLeanVec::build( + VamanaIndex** index, + size_t dim, + MetricType metric, + StorageKind storage_kind, + const LeanVecTrainingData* training_data, + const VamanaIndex::BuildParams& params, + const VamanaIndex::SearchParams& default_search_params +) noexcept { + using Impl = VamanaIndexLeanVecImpl; + *index = nullptr; + + return runtime_error_wrapper([&] { + if (training_data == nullptr) { + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, "Training data must not be null"}; + } + auto training_data_impl = + static_cast(training_data)->impl_; + auto impl = std::make_unique( + dim, metric, storage_kind, training_data_impl, params, default_search_params + ); + *index = new VamanaIndexManagerBase{std::move(impl)}; + }); +} + +#else // SVS_RUNTIME_HAVE_LVQ_LEANVEC +// LeanVec storage kind is not supported in this build configuration +Status VamanaIndexLeanVec:: + build(VamanaIndex**, size_t, MetricType, StorageKind, size_t, const VamanaIndex::BuildParams&, const VamanaIndex::SearchParams&) noexcept { + return Status( + ErrorCode::NOT_IMPLEMENTED, + "VamanaIndexLeanVec is not supported in this build configuration." + ); +} + +Status VamanaIndexLeanVec:: + build(VamanaIndex**, size_t, MetricType, StorageKind, const LeanVecTrainingData*, const VamanaIndex::BuildParams&, const VamanaIndex::SearchParams&) noexcept { + return Status( + ErrorCode::NOT_IMPLEMENTED, + "VamanaIndexLeanVec is not supported in this build configuration." + ); +} +#endif // SVS_RUNTIME_HAVE_LVQ_LEANVEC } // namespace runtime } // namespace svs diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h new file mode 100644 index 000000000..4cf58d7e0 --- /dev/null +++ b/bindings/cpp/src/vamana_index_impl.h @@ -0,0 +1,601 @@ +/* + * Copyright 2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "svs/runtime/vamana_index.h" + +#include "svs_runtime_utils.h" + +#ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC +#include "training_impl.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace svs { +namespace runtime { + +// Vamana index implementation +class VamanaIndexImpl { + using allocator_type = svs::lib::Allocator; + + public: + VamanaIndexImpl( + size_t dim, + MetricType metric, + StorageKind storage_kind, + const VamanaIndex::BuildParams& build_params, + const VamanaIndex::SearchParams& default_search_params + ) + : dim_{dim} + , metric_type_{metric} + , storage_kind_{storage_kind} + , build_params_{build_params} + , default_search_params_{default_search_params} { + if (!storage::is_supported_storage_kind(storage_kind)) { + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, + "The specified storage kind is not compatible with the " + "VamanaIndex"}; + } + } + + size_t size() const { return impl_ ? get_impl()->size() : 0; } + + size_t dimensions() const { return dim_; } + + MetricType metric_type() const { return metric_type_; } + + StorageKind get_storage_kind() const { return storage_kind_; } + + void add(const data::ConstSimpleDataView& data) { + if (!impl_) { + return init_impl(data); + } + + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, + "Vamana index does not support adding points after initialization"}; + } + + void search( + svs::QueryResultView result, + svs::data::ConstSimpleDataView queries, + const VamanaIndex::SearchParams* params = nullptr, + IDFilter* filter = nullptr + ) const { + if (!impl_) { + auto& dists = result.distances(); + std::fill(dists.begin(), dists.end(), Unspecify()); + auto& inds = result.indices(); + std::fill(inds.begin(), inds.end(), Unspecify()); + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + + if (queries.size() == 0) { + return; + } + + const size_t k = result.n_neighbors(); + if (k == 0) { + throw StatusException{ErrorCode::INVALID_ARGUMENT, "k must be greater than 0"}; + } + + auto sp = make_search_parameters(params); + + // Simple search + if (filter == nullptr) { + get_impl()->search(result, queries, sp); + return; + } + + // Selective search with IDSelector + auto old_sp = get_impl()->get_search_parameters(); + auto sp_restore = svs::lib::make_scope_guard([&]() noexcept { + get_impl()->set_search_parameters(old_sp); + }); + get_impl()->set_search_parameters(sp); + + auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) { + for (auto i : range) { + // For every query + auto query = queries.get_datum(i); + auto iterator = get_impl()->batch_iterator(query); + size_t found = 0; + do { + iterator.next(k); + for (auto& neighbor : iterator.results()) { + if (filter->is_member(neighbor.id())) { + result.set(neighbor, i, found); + found++; + if (found == k) { + break; + } + } + } + } while (found < k && !iterator.done()); + + // Pad results if not enough neighbors found + if (found < k) { + for (size_t j = found; j < k; ++j) { + result.set(Neighbor{Unspecify(), Unspecify()}, i, j); + } + } + } + }; + + auto threadpool = default_threadpool(); + + svs::threads::parallel_for( + threadpool, svs::threads::StaticPartition{queries.size()}, search_closure + ); + } + + void range_search( + svs::data::ConstSimpleDataView queries, + float radius, + const ResultsAllocator& results, + const VamanaIndex::SearchParams* params = nullptr, + IDFilter* filter = nullptr + ) const { + if (radius <= 0) { + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, "radius must be greater than 0"}; + } + + const size_t n = queries.size(); + if (n == 0) { + return; + } + + auto sp = make_search_parameters(params); + auto old_sp = get_impl()->get_search_parameters(); + auto sp_restore = svs::lib::make_scope_guard([&]() noexcept { + get_impl()->set_search_parameters(old_sp); + }); + get_impl()->set_search_parameters(sp); + + // Using ResultHandler makes no sense due to it's complexity, overhead and + // missed features; e.g. add_result() does not indicate whether result added + // or not - we have to manually manage threshold comparison and id + // selection. + + // Prepare output buffers + std::vector>> all_results(n); + // Reserve space for allocation to avoid multiple reallocations + // Use search_buffer_capacity as a heuristic + const auto result_capacity = sp.buffer_config_.get_total_capacity(); + for (auto& res : all_results) { + res.reserve(result_capacity); + } + + svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric_type_)); + + std::function compare = distance_dispatcher([](auto&& dist) { + return std::function{svs::distance::comparator(dist)}; + }); + + std::function select = [](size_t) { return true; }; + if (filter != nullptr) { + select = [&](size_t id) { return filter->is_member(id); }; + } + + // Set iterator batch size to search window size + auto batch_size = sp.buffer_config_.get_search_window_size(); + // Ensure batch size is at least 10 to avoid excessive overhead of small batches + batch_size = std::max(batch_size, size_t(10)); + + auto range_search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) { + for (auto i : range) { + // For every query + auto query = queries.get_datum(i); + + auto iterator = get_impl()->batch_iterator(query); + bool in_range = true; + + do { + iterator.next(batch_size); + for (auto& neighbor : iterator.results()) { + // SVS comparator functor returns true if the first distance + // is 'closer' than the second one + in_range = compare(neighbor.distance(), radius); + if (in_range) { + // Selective search with IDSelector + if (select(neighbor.id())) { + all_results[i].push_back(neighbor); + } + } else { + // Since iterator.results() are ordered by distance, we + // can stop processing + break; + } + } + } while (in_range && !iterator.done()); + } + }; + + auto threadpool = default_threadpool(); + + svs::threads::parallel_for( + threadpool, svs::threads::StaticPartition{n}, range_search_closure + ); + + // Allocate output + std::vector result_counts(n); + std::transform( + all_results.begin(), + all_results.end(), + result_counts.begin(), + [](const auto& res) { return res.size(); } + ); + auto results_storage = results(result_counts); + + // Fill in results + for (size_t q = 0, ofs = 0; q < n; ++q) { + for (const auto& [id, distance] : all_results[q]) { + results_storage.labels[ofs] = id; + results_storage.distances[ofs] = distance; + ofs++; + } + } + } + + void reset() { impl_.reset(); } + + void save(std::ostream& out) const { + lib::UniqueTempDirectory tempdir{"svs_vamana_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto graph_dir = tempdir.get() / "graph"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(graph_dir); + std::filesystem::create_directories(data_dir); + get_impl()->save(config_dir, graph_dir, data_dir); + lib::DirectoryArchiver::pack(tempdir, out); + } + + protected: + // Utility functions + svs::Vamana* get_impl() const { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + return impl_.get(); + } + + svs::index::vamana::VamanaBuildParameters vamana_build_parameters() const { + svs::index::vamana::VamanaBuildParameters result; + set_if_specified(result.alpha, build_params_.alpha); + set_if_specified(result.graph_max_degree, build_params_.graph_max_degree); + set_if_specified(result.window_size, build_params_.construction_window_size); + set_if_specified( + result.max_candidate_pool_size, build_params_.max_candidate_pool_size + ); + set_if_specified(result.prune_to, build_params_.prune_to); + if (is_specified(build_params_.use_full_search_history)) { + result.use_full_search_history = + build_params_.use_full_search_history.is_enabled(); + } + return result; + } + + svs::index::vamana::VamanaSearchParameters + make_search_parameters(const VamanaIndex::SearchParams* params) const { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + + // Copy default search parameters + auto search_params = default_search_params_; + // Update with user-specified parameters + if (params) { + set_if_specified(search_params.search_window_size, params->search_window_size); + set_if_specified( + search_params.search_buffer_capacity, params->search_buffer_capacity + ); + set_if_specified(search_params.prefetch_lookahead, params->prefetch_lookahead); + set_if_specified(search_params.prefetch_step, params->prefetch_step); + } + + // Get current search parameters from the index + auto result = impl_->get_search_parameters(); + // Update with specified parameters + if (is_specified(search_params.search_window_size)) { + if (is_specified(search_params.search_buffer_capacity)) { + result.buffer_config( + {search_params.search_window_size, search_params.search_buffer_capacity} + ); + } else { + result.buffer_config(search_params.search_window_size); + } + } else if (is_specified(search_params.search_buffer_capacity)) { + result.buffer_config(search_params.search_buffer_capacity); + } + + set_if_specified(result.prefetch_lookahead_, search_params.prefetch_lookahead); + set_if_specified(result.prefetch_step_, search_params.prefetch_step); + + return result; + } + + template + static svs::Vamana* build_impl( + Tag&& tag, + MetricType metric, + const index::vamana::VamanaBuildParameters& parameters, + const svs::data::ConstSimpleDataView& data, + StorageArgs&&... storage_args + ) { + auto threadpool = default_threadpool(); + using storage_alloc_t = typename Tag::allocator_type; + auto allocator = storage::make_allocator(); + + auto storage = make_storage( + std::forward(tag), + data, + threadpool, + allocator, + std::forward(storage_args)... + ); + + svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); + return distance_dispatcher([&](auto&& distance) { + return new svs::Vamana(svs::Vamana::build( + parameters, + std::move(storage), + std::forward(distance), + std::move(threadpool) + )); + }); + } + + virtual void init_impl(const data::ConstSimpleDataView& data) { + impl_.reset(storage::dispatch_storage_kind( + get_storage_kind(), + [&](auto&& tag, const data::ConstSimpleDataView& data) { + using Tag = std::decay_t; + return build_impl( + std::forward(tag), + this->metric_type_, + this->vamana_build_parameters(), + data + ); + }, + data + )); + get_impl()->set_search_parameters(make_search_parameters(&default_search_params_)); + } + + // Constructor used during loading + VamanaIndexImpl( + std::unique_ptr&& impl, MetricType metric, StorageKind storage_kind + ) + : dim_{0} + , metric_type_{metric} + , storage_kind_{storage_kind} + , build_params_{} + , default_search_params_{} + , impl_{std::move(impl)} { + if (impl_) { + dim_ = impl_->dimensions(); + const auto& buffer_config = impl_->get_search_parameters().buffer_config_; + default_search_params_ = { + buffer_config.get_search_window_size(), buffer_config.get_total_capacity()}; + build_params_ = VamanaIndex::BuildParams{ + impl_->get_graph_max_degree(), + impl_->get_prune_to(), + impl_->get_alpha(), + impl_->get_construction_window_size(), + impl_->get_max_candidates(), + impl_->get_full_search_history()}; + } + } + + template + static svs::Vamana* load_impl_t(Tag&& tag, std::istream& stream, MetricType metric) { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir); + + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw StatusException{ + ErrorCode::RUNTIME_ERROR, + "Invalid Vamana index archive: missing config directory!"}; + } + + const auto graph_path = tempdir.get() / "graph"; + if (!fs::is_directory(graph_path)) { + throw StatusException{ + ErrorCode::RUNTIME_ERROR, + "Invalid Vamana index archive: missing graph directory!"}; + } + + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw StatusException{ + ErrorCode::RUNTIME_ERROR, + "Invalid Vamana index archive: missing data directory!"}; + } + + auto storage = storage::load_storage(std::forward(tag), data_path); + auto threadpool = default_threadpool(); + + svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); + + return distance_dispatcher([&](auto&& distance) { + return new svs::Vamana(svs::Vamana::assemble( + config_path, + svs::GraphLoader{graph_path}, + std::move(storage), + std::forward(distance), + std::move(threadpool) + )); + }); + } + + public: + static VamanaIndexImpl* + load(std::istream& stream, MetricType metric, StorageKind storage_kind) { + return storage::dispatch_storage_kind( + storage_kind, + [&](auto&& tag, std::istream& stream, MetricType metric) { + using Tag = std::decay_t; + std::unique_ptr impl{ + load_impl_t(std::forward(tag), stream, metric)}; + + return new VamanaIndexImpl(std::move(impl), metric, storage_kind); + }, + stream, + metric + ); + } + + // Data members + protected: + size_t dim_; + MetricType metric_type_; + StorageKind storage_kind_; + VamanaIndex::BuildParams build_params_; + VamanaIndex::SearchParams default_search_params_; + std::unique_ptr impl_; +}; + +#ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC +struct VamanaIndexLeanVecImpl : public VamanaIndexImpl { + using LeanVecMatricesType = LeanVecTrainingDataImpl::LeanVecMatricesType; + using allocator_type = svs::lib::Allocator; + + VamanaIndexLeanVecImpl( + std::unique_ptr&& impl, MetricType metric, StorageKind storage_kind + ) + : VamanaIndexImpl{std::move(impl), metric, storage_kind} + , leanvec_dims_{0} + , leanvec_matrices_{std::nullopt} { + check_storage_kind(storage_kind); + } + + VamanaIndexLeanVecImpl( + size_t dim, + MetricType metric, + StorageKind storage_kind, + const LeanVecTrainingDataImpl& training_data, + const VamanaIndex::BuildParams& params, + const VamanaIndex::SearchParams& default_search_params + ) + : VamanaIndexImpl{dim, metric, storage_kind, params, default_search_params} + , leanvec_dims_{training_data.get_leanvec_dims()} + , leanvec_matrices_{training_data.get_leanvec_matrices()} { + check_storage_kind(storage_kind); + } + + VamanaIndexLeanVecImpl( + size_t dim, + MetricType metric, + StorageKind storage_kind, + size_t leanvec_dims, + const VamanaIndex::BuildParams& params, + const VamanaIndex::SearchParams& default_search_params + ) + : VamanaIndexImpl{dim, metric, storage_kind, params, default_search_params} + , leanvec_dims_{leanvec_dims} + , leanvec_matrices_{std::nullopt} { + check_storage_kind(storage_kind); + } + + template + static auto dispatch_leanvec_storage_kind(StorageKind kind, F&& f, Args&&... args) { + switch (kind) { + case StorageKind::LeanVec4x4: + return f( + storage::StorageType{}, + std::forward(args)... + ); + case StorageKind::LeanVec4x8: + return f( + storage::StorageType{}, + std::forward(args)... + ); + case StorageKind::LeanVec8x8: + return f( + storage::StorageType{}, + std::forward(args)... + ); + default: + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, "SVS LeanVec storage kind required"}; + } + } + + void init_impl(const data::ConstSimpleDataView& data) override { + assert(storage::is_leanvec_storage(this->storage_kind_)); + impl_.reset(dispatch_leanvec_storage_kind( + this->storage_kind_, + [&](auto&& tag, const data::ConstSimpleDataView& data) { + using Tag = std::decay_t; + return VamanaIndexImpl::build_impl( + std::forward(tag), + this->metric_type_, + this->vamana_build_parameters(), + data, + leanvec_dims_, + leanvec_matrices_ + ); + }, + data + )); + impl_->set_search_parameters(make_search_parameters(&default_search_params_)); + } + + protected: + size_t leanvec_dims_; + std::optional leanvec_matrices_; + + StorageKind check_storage_kind(StorageKind kind) { + if (!storage::is_leanvec_storage(kind)) { + throw StatusException( + ErrorCode::INVALID_ARGUMENT, "SVS LeanVec storage kind required" + ); + } + if (!svs::detail::lvq_leanvec_enabled()) { + throw StatusException( + ErrorCode::NOT_IMPLEMENTED, + "LeanVec storage kind requested but not supported by CPU" + ); + } + return kind; + } +}; +#endif // SVS_RUNTIME_HAVE_LVQ_LEANVEC + +} // namespace runtime +} // namespace svs diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index ec6f309d9..201375d3c 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -92,10 +92,9 @@ void write_and_read_index( svs::runtime::v0::Status status = build_func(&index); // Stop here if storage kind is not supported on this platform - if constexpr (std::is_same_v) { + if constexpr (std::is_base_of_v) { if (storage_kind.has_value()) { - if (!svs::runtime::v0::DynamicVamanaIndex::check_storage_kind(*storage_kind) - .ok()) { + if (!Index::check_storage_kind(*storage_kind).ok()) { CATCH_REQUIRE(!status.ok()); return; } @@ -105,7 +104,7 @@ void write_and_read_index( CATCH_REQUIRE(index != nullptr); // Add data to index - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v) { status = index->add(n, xb.data()); } else { std::vector labels(n); @@ -675,3 +674,210 @@ CATCH_TEST_CASE("SetIfSpecifiedUtility", "[runtime]") { CATCH_REQUIRE(target == false); } } + +CATCH_TEST_CASE("WriteAndReadStaticIndexSVS", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + }; + write_and_read_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP32 + ); +} + +CATCH_TEST_CASE("WriteAndReadStaticIndexSVSFP16", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP16, + build_params + ); + }; + write_and_read_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP16 + ); +} + +CATCH_TEST_CASE("WriteAndReadStaticIndexSVSSQI8", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::SQI8, + build_params + ); + }; + write_and_read_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::SQI8 + ); +} + +CATCH_TEST_CASE("WriteAndReadStaticIndexSVSLVQ4x4", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::LVQ4x4, + build_params + ); + }; + write_and_read_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LVQ4x4 + ); +} + +CATCH_TEST_CASE("WriteAndReadStaticIndexSVSVamanaLeanVec4x4", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndexLeanVec::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::LeanVec4x4, + 32, + build_params + ); + }; + write_and_read_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LeanVec4x4 + ); +} + +CATCH_TEST_CASE("StaticIndexLeanVecWithTrainingData", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + const size_t leanvec_dims = 32; + // Build LeanVec index with explicit training + svs::runtime::v0::VamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + + // Prepare training data + svs::runtime::v0::LeanVecTrainingData* training_data = nullptr; + svs::runtime::v0::Status status = svs::runtime::v0::LeanVecTrainingData::build( + &training_data, test_d, test_n, test_data.data(), leanvec_dims + ); + if (!svs::runtime::v0::VamanaIndexLeanVec::check_storage_kind( + svs::runtime::v0::StorageKind::LeanVec4x4 + ) + .ok()) { + CATCH_REQUIRE(!status.ok()); + CATCH_SKIP("Storage kind is not supported, skipping test."); + } + + status = svs::runtime::v0::VamanaIndexLeanVec::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::LeanVec4x4, + training_data, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + status = index->add(test_n, test_data.data()); + CATCH_REQUIRE(status.ok()); + + svs::runtime::v0::VamanaIndex::destroy(index); + svs::runtime::v0::LeanVecTrainingData::destroy(training_data); +} + +CATCH_TEST_CASE("SearchWithIDFilterStatic", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + // Build index + svs::runtime::v0::VamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::VamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data + status = index->add(test_n, test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Second attempt to add data should fail on static index + status = index->add(test_n, test_data.data()); + CATCH_REQUIRE(!status.ok()); + + const int nq = 8; + const float* xq = test_data.data(); + const int k = 10; + + size_t min_id = test_n / 5; + size_t max_id = test_n * 4 / 5; + test_utils::IDFilterRange selector(min_id, max_id); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + status = index->search( + nq, xq, k, distances.data(), result_labels.data(), nullptr, &selector + ); + CATCH_REQUIRE(status.ok()); + + // All returned labels must fall inside the selected range + for (int i = 0; i < nq * k; ++i) { + CATCH_REQUIRE(result_labels[i] >= min_id); + CATCH_REQUIRE(result_labels[i] < max_id); + } + + svs::runtime::v0::VamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("RangeSearchFunctionalStatic", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + // Build index + svs::runtime::v0::VamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::VamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data + status = index->add(test_n, test_data.data()); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + + // Small radius search + test_utils::TestResultsAllocator allocator_small; + status = index->range_search(nq, xq, 0.05f, allocator_small); + CATCH_REQUIRE(status.ok()); + + // Larger radius to exercise loop continuation + test_utils::TestResultsAllocator allocator_big; + status = index->range_search(nq, xq, 5.0f, allocator_big); + CATCH_REQUIRE(status.ok()); + + svs::runtime::v0::VamanaIndex::destroy(index); +} diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index 0fcb31bbb..712dc2e68 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -584,6 +584,9 @@ template class Blocked { Alloc allocator_{}; }; +template inline constexpr bool is_blocked_v = false; +template inline constexpr bool is_blocked_v> = true; + /// /// @brief A specialization of ``SimpleData`` for large-scale dynamic datasets. ///