diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index 90d23d9aef..4782db59e5 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -278,6 +278,57 @@ if(CUVS_KNN_BENCH_USE_CUVS_BRUTE_FORCE) ) endif() +# Cluster assignment benchmark: brute force vs CAGRA for assigning vectors to clusters (IVF +# training) +if(CUVS_ANN_BENCH_USE_CUVS_CAGRA) + add_executable(CUVS_CLUSTER_ASSIGNMENT_BENCH src/cuvs/cuvs_cluster_assignment_bench.cu) + target_link_libraries( + CUVS_CLUSTER_ASSIGNMENT_BENCH + PRIVATE cuvs benchmark::benchmark $<$:CUDA::nvtx3> + $ + ) + target_include_directories( + CUVS_CLUSTER_ASSIGNMENT_BENCH + PUBLIC "$" + "$" + PRIVATE "$" + ) + set_target_properties( + CUVS_CLUSTER_ASSIGNMENT_BENCH + PROPERTIES CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) + install( + TARGETS CUVS_CLUSTER_ASSIGNMENT_BENCH + COMPONENT ann_bench + DESTINATION bin/ann + ) + add_dependencies(CUVS_ANN_BENCH_ALL CUVS_CLUSTER_ASSIGNMENT_BENCH) + + # IVF-PQ build benchmarks: k-means fit and extend cluster-assignment (brute vs CAGRA) + add_executable(CUVS_IVFPQ_BUILD_BENCH src/cuvs/cuvs_ivf_pq_build_bench.cu) + target_link_libraries(CUVS_IVFPQ_BUILD_BENCH PRIVATE cuvs benchmark::benchmark) + target_include_directories( + CUVS_IVFPQ_BUILD_BENCH PUBLIC "$" + "$" + ) + set_target_properties( + CUVS_IVFPQ_BUILD_BENCH + PROPERTIES CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON + ) + install( + TARGETS CUVS_IVFPQ_BUILD_BENCH + COMPONENT ann_bench + DESTINATION bin/ann + ) + add_dependencies(CUVS_ANN_BENCH_ALL CUVS_IVFPQ_BUILD_BENCH) +endif() + if(CUVS_ANN_BENCH_USE_CUVS_CAGRA) ConfigureAnnBench( NAME diff --git a/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu new file mode 100644 index 0000000000..791d5b5025 --- /dev/null +++ b/cpp/bench/ann/src/cuvs/cuvs_cluster_assignment_bench.cu @@ -0,0 +1,292 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + * + * Benchmark: brute force vs CAGRA-based cluster assignment for IVF training. + * Compares time to assign N vectors to K clusters (nearest centroid) using + * (1) brute force 1-NN and (2) CAGRA build on centroids + k=1 search. + */ +#include + +// kmeans_balanced.cuh is under cpp/src/; CUVS_CLUSTER_ASSIGNMENT_BENCH adds that to include path +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace { + +using namespace cuvs::cluster::kmeans_balanced; + +void init_random_data(raft::resources const& handle, + float* X, + int64_t n_rows, + int64_t dim, + float* centroids, + int64_t n_clusters) +{ + raft::random::RngState rng(12345ULL); + raft::random::uniform(handle, rng, X, n_rows * dim, float(-1), float(1)); + raft::random::uniform(handle, rng, centroids, n_clusters * dim, float(-1), float(1)); + raft::resource::sync_stream(handle); +} + +} // namespace + +static void BM_ClusterAssignment_BruteForce(benchmark::State& state) +{ + int64_t n_rows = static_cast(state.range(0)); + int64_t n_clusters = static_cast(state.range(1)); + int64_t dim = static_cast(state.range(2)); + + raft::device_resources handle; + rmm::device_uvector X(static_cast(n_rows) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector centroids(static_cast(n_clusters) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector labels(static_cast(n_rows), + raft::resource::get_cuda_stream(handle)); + + init_random_data(handle, X.data(), n_rows, dim, centroids.data(), n_clusters); + + cuvs::cluster::kmeans::balanced_params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + + auto X_view = raft::make_device_matrix_view(X.data(), n_rows, dim); + auto centers_view = + raft::make_device_matrix_view(centroids.data(), n_clusters, dim); + auto labels_view = raft::make_device_vector_view(labels.data(), n_rows); + + for (auto _ : state) { + predict(handle, params, X_view, centers_view, labels_view); + raft::resource::sync_stream(handle); + } + state.SetItemsProcessed(state.iterations() * n_rows); +} + +static void BM_ClusterAssignment_CAGRA(benchmark::State& state) +{ + int64_t n_rows = static_cast(state.range(0)); + int64_t n_clusters = static_cast(state.range(1)); + int64_t dim = static_cast(state.range(2)); + + raft::device_resources handle; + rmm::device_uvector X(static_cast(n_rows) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector centroids(static_cast(n_clusters) * static_cast(dim), + raft::resource::get_cuda_stream(handle)); + rmm::device_uvector labels(static_cast(n_rows), + raft::resource::get_cuda_stream(handle)); + + init_random_data(handle, X.data(), n_rows, dim, centroids.data(), n_clusters); + + cuvs::cluster::kmeans::balanced_params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + + // Same timing as assign_nearest_centroid_cagra_with_index_reuse with rebuild=true each iteration. + // float X/centroids only. + std::optional> cagra_index_opt; + + for (auto _ : state) { + cuvs::cluster::kmeans::detail::assign_nearest_centroid_cagra_with_index_reuse( + handle, + params, + centroids.data(), + n_clusters, + dim, + X.data(), + n_rows, + labels.data(), + &cagra_index_opt, + true); + raft::resource::sync_stream(handle); + } + state.SetItemsProcessed(state.iterations() * n_rows); +} + +// N = vectors to assign, K = number of clusters, D = dimension +// Small: 10K vectors, 1K clusters, 128 dim +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({10000, 1000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({10000, 1000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Medium: 100K vectors, 4K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({100000, 4000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({100000, 4000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Large K: 100K vectors, 16K clusters (brute force starts to hurt) +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({100000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({100000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Very large K: 500K vectors, 64K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({500000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({500000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Larger N: amortize CAGRA build over more queries +// 1M vectors, 4K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 4000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 4000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 16K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 64K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 2M vectors, 16K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({2000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({2000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 2M vectors, 64K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({2000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({2000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 5M vectors, 16K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({5000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({5000000, 16000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 5M vectors, 64K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({5000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({5000000, 65536, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// Hundreds of thousands of centroids (K = 100K, 200K, 500K, 1M) +// 1M vectors, 100K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 100000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 100000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 2M vectors, 100K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({2000000, 100000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({2000000, 100000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 200K clusters +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 200000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 200000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M vectors, 500K clusters (~2 vectors per cluster) +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({1000000, 500000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({1000000, 500000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 1M clusters with N > K (realistic: many vectors per cluster) +// 2M vectors, 1M clusters (~2 per cluster) +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({2000000, 1000000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({2000000, 1000000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +// 5M vectors, 1M clusters (~5 per cluster) +BENCHMARK(BM_ClusterAssignment_BruteForce) + ->Args({5000000, 1000000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); +BENCHMARK(BM_ClusterAssignment_CAGRA) + ->Args({5000000, 1000000, 128}) + ->Unit(benchmark::kMillisecond) + ->UseRealTime(); + +BENCHMARK_MAIN(); diff --git a/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu new file mode 100644 index 0000000000..a39edf3f81 --- /dev/null +++ b/cpp/bench/ann/src/cuvs/cuvs_ivf_pq_build_bench.cu @@ -0,0 +1,366 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + * + * IVF-PQ nearest-centroid lookup benchmarks (three scenarios): + * + * All compare brute force vs CAGRA for the same primitive: given centroid vectors, find the + * nearest centroid for each data vector. CAGRA never computes centroids; it only accelerates + * lookup. + * + * - Build fit: lookup during k-means fit E-step (`use_ann_for_build_fit`). Centroids move each + * EM iteration; CAGRA index is rebuilt periodically. Times full `build()`. + * - Build post-fit: lookup after fit to label the train subsample for PQ codebooks + * (`use_ann_for_build_postfit`). Fixed centroids; CAGRA built once per build. Times full + * `build()`. + * - Extend: fixed trained centroids (`use_ann_for_extend`). CAGRA built once per extend batch + * loop. Times `extend()` only; setup deserialize is not timed. + */ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace { + +struct BenchRanges { + int64_t n_rows; + uint32_t n_lists; + int64_t dim; +}; + +BenchRanges read_ranges(benchmark::State const& state) +{ + return BenchRanges{static_cast(state.range(0)), + static_cast(state.range(1)), + static_cast(state.range(2))}; +} + +void init_random_dataset(raft::resources const& handle, float* data, int64_t n_rows, int64_t dim) +{ + raft::random::RngState rng(12345ULL); + raft::random::uniform(handle, rng, data, n_rows * dim, float(-1), float(1)); + raft::resource::sync_stream(handle); +} + +/** Shared dataset + handle setup for both benchmarks. */ +struct DatasetFixture { + raft::device_resources handle; + rmm::device_uvector dataset; + raft::device_matrix_view view; + + explicit DatasetFixture(BenchRanges const& ranges) + : handle{}, + dataset(static_cast(ranges.n_rows) * static_cast(ranges.dim), + raft::resource::get_cuda_stream(handle)) + { + init_random_dataset(handle, dataset.data(), ranges.n_rows, ranges.dim); + raft::resource::set_cuda_stream_pool(handle, std::make_shared(1)); + view = raft::make_device_matrix_view( + dataset.data(), ranges.n_rows, ranges.dim); + } +}; + +void set_common_index_params(cuvs::neighbors::ivf_pq::index_params& params, uint32_t n_lists) +{ + params.n_lists = n_lists; + params.kmeans_n_iters = 3; + params.kmeans_trainset_fraction = 0.2; + params.metric = cuvs::distance::DistanceType::L2Expanded; +} + +/** build() fit E-step only: toggles use_ann_for_build_fit; post-fit and extend stay brute. */ +cuvs::neighbors::ivf_pq::index_params make_build_fit_lookup_params(uint32_t n_lists, + bool use_ann_for_build_fit) +{ + cuvs::neighbors::ivf_pq::index_params params; + set_common_index_params(params, n_lists); + params.add_data_on_build = true; + params.use_ann_for_build_fit = use_ann_for_build_fit; + params.use_ann_for_build_postfit = false; + params.use_ann_for_extend = false; + return params; +} + +/** build() post-fit predict only: toggles use_ann_for_build_postfit; fit E-step and extend stay + * brute. */ +cuvs::neighbors::ivf_pq::index_params make_build_postfit_lookup_params( + uint32_t n_lists, bool use_ann_for_build_postfit) +{ + cuvs::neighbors::ivf_pq::index_params params; + set_common_index_params(params, n_lists); + params.add_data_on_build = true; + params.use_ann_for_build_fit = false; + params.use_ann_for_build_postfit = use_ann_for_build_postfit; + params.use_ann_for_extend = false; + return params; +} + +cuvs::neighbors::ivf_pq::index_params make_extend_lookup_params(uint32_t n_lists, + bool use_ann_for_extend) +{ + cuvs::neighbors::ivf_pq::index_params params; + set_common_index_params(params, n_lists); + params.add_data_on_build = false; + params.use_ann_for_build_fit = false; + params.use_ann_for_build_postfit = false; + params.use_ann_for_extend = use_ann_for_extend; + return params; +} + +std::string serialize_index_blob(raft::resources const& handle, + cuvs::neighbors::ivf_pq::index const& index) +{ + std::ostringstream os(std::ios::binary); + cuvs::neighbors::ivf_pq::serialize(handle, os, index); + raft::resource::sync_stream(handle); + return os.str(); +} + +void deserialize_index_from_blob(raft::resources const& handle, + std::string const& blob, + cuvs::neighbors::ivf_pq::index* index) +{ + std::istringstream is(blob, std::ios::binary); + cuvs::neighbors::ivf_pq::deserialize(handle, is, index); + raft::resource::sync_stream(handle); +} + +/** Serialized empty trained indices for per-iteration extend reset (setup, not timed). */ +struct ExtendIndexSnapshots { + std::string empty_index_snapshot_bf; + std::string empty_index_snapshot_cagra; + + static ExtendIndexSnapshots create( + raft::resources const& handle, + cuvs::neighbors::ivf_pq::index_params const& params_bf, + cuvs::neighbors::ivf_pq::index_params const& params_cagra, + raft::device_matrix_view dataset_view) + { + ExtendIndexSnapshots snapshots; + auto idx_bf = cuvs::neighbors::ivf_pq::build(handle, params_bf, dataset_view); + auto idx_cagra = cuvs::neighbors::ivf_pq::build(handle, params_cagra, dataset_view); + snapshots.empty_index_snapshot_bf = serialize_index_blob(handle, idx_bf); + snapshots.empty_index_snapshot_cagra = serialize_index_blob(handle, idx_cagra); + return snapshots; + } + + void restore(raft::resources const& handle, + cuvs::neighbors::ivf_pq::index* index_bf, + cuvs::neighbors::ivf_pq::index* index_cagra) const + { + deserialize_index_from_blob(handle, empty_index_snapshot_bf, index_bf); + deserialize_index_from_blob(handle, empty_index_snapshot_cagra, index_cagra); + } +}; + +double time_synced_ms(raft::resources const& handle, std::function const& op) +{ + auto start = std::chrono::steady_clock::now(); + op(); + raft::resource::sync_stream(handle); + auto end = std::chrono::steady_clock::now(); + return 1e-6 * std::chrono::duration(end - start).count(); +} + +/** Run brute vs CAGRA paths each iteration; `before_timed_work` runs outside the chrono window. */ +template +void accumulate_bf_vs_cagra_ms(benchmark::State& state, + raft::resources const& handle, + PreIterationFn&& before_timed_work, + BfFn&& run_bf, + CagraFn&& run_cagra, + double& total_bf_ms, + double& total_cagra_ms) +{ + for (auto _ : state) { + before_timed_work(state); + total_bf_ms += time_synced_ms(handle, run_bf); + total_cagra_ms += time_synced_ms(handle, run_cagra); + } +} + +void set_speedup_counters(benchmark::State& state, + char const* speedup_key, + char const* bf_key, + char const* cagra_key, + double total_bf_ms, + double total_cagra_ms) +{ + if (total_cagra_ms > 0) { state.counters[speedup_key] = total_bf_ms / total_cagra_ms; } + state.counters[bf_key] = benchmark::Counter(total_bf_ms, benchmark::Counter::kAvgIterations); + state.counters[cagra_key] = + benchmark::Counter(total_cagra_ms, benchmark::Counter::kAvgIterations); +} + +void run_build_bf_vs_cagra_speedup(benchmark::State& state, + DatasetFixture const& fixture, + cuvs::neighbors::ivf_pq::index_params const& params_bf, + cuvs::neighbors::ivf_pq::index_params const& params_cagra, + char const* speedup_key, + char const* bf_key, + char const* cagra_key) +{ + double total_bf_ms = 0.0, total_cagra_ms = 0.0; + accumulate_bf_vs_cagra_ms( + state, + fixture.handle, + [](benchmark::State&) {}, + [&] { + auto idx = cuvs::neighbors::ivf_pq::build(fixture.handle, params_bf, fixture.view); + benchmark::DoNotOptimize(idx.size()); + }, + [&] { + auto idx = cuvs::neighbors::ivf_pq::build(fixture.handle, params_cagra, fixture.view); + benchmark::DoNotOptimize(idx.size()); + }, + total_bf_ms, + total_cagra_ms); + set_speedup_counters(state, speedup_key, bf_key, cagra_key, total_bf_ms, total_cagra_ms); +} + +} // namespace + +/** + * Full IVF-PQ build: brute vs CAGRA nearest-centroid lookup during k-means fit E-step only + * (`use_ann_for_build_fit`). Post-fit predict and add_data_on_build stay brute. Centroids move + * each fit iteration, so CAGRA is rebuilt across EM iterations. Times all of `build()`; PQ + * training and other steps dilute the measured speedup. + */ +static void BM_IVFPQ_BuildFit_NearestCentroidLookup_Speedup(benchmark::State& state) +{ + auto ranges = read_ranges(state); + DatasetFixture fixture(ranges); + run_build_bf_vs_cagra_speedup(state, + fixture, + make_build_fit_lookup_params(ranges.n_lists, false), + make_build_fit_lookup_params(ranges.n_lists, true), + "speedup_build_fit", + "bf_build_fit_ms", + "cagra_build_fit_ms"); +} + +/** + * Full IVF-PQ build: brute vs CAGRA nearest-centroid lookup during post-fit predict only + * (`use_ann_for_build_postfit`). Fit E-step and add_data_on_build stay brute. Centroids are + * fixed; CAGRA is built once per build for train-subsample labeling. Times all of `build()`; + * k-means fit, PQ training, and other steps dilute the measured speedup. + */ +static void BM_IVFPQ_BuildPostfit_NearestCentroidLookup_Speedup(benchmark::State& state) +{ + auto ranges = read_ranges(state); + DatasetFixture fixture(ranges); + run_build_bf_vs_cagra_speedup(state, + fixture, + make_build_postfit_lookup_params(ranges.n_lists, false), + make_build_postfit_lookup_params(ranges.n_lists, true), + "speedup_build_postfit", + "bf_build_postfit_ms", + "cagra_build_postfit_ms"); +} + +/** + * extend() only: brute vs CAGRA nearest-centroid lookup with fixed trained centroids + * (`use_ann_for_extend`). CAGRA is built once; new vectors are assigned via fast 1-NN search. + * Empty trained indices are restored each iteration via deserialize (not timed). + */ +static void BM_IVFPQ_Extend_NearestCentroidLookup_Speedup(benchmark::State& state) +{ + auto ranges = read_ranges(state); + DatasetFixture fixture(ranges); + auto params_bf = make_extend_lookup_params(ranges.n_lists, false); + auto params_cagra = make_extend_lookup_params(ranges.n_lists, true); + auto snapshots = + ExtendIndexSnapshots::create(fixture.handle, params_bf, params_cagra, fixture.view); + + cuvs::neighbors::ivf_pq::index index_bf(fixture.handle); + cuvs::neighbors::ivf_pq::index index_cagra(fixture.handle); + + double total_bf_ms = 0.0, total_cagra_ms = 0.0; + accumulate_bf_vs_cagra_ms( + state, + fixture.handle, + [&](benchmark::State& st) { + st.PauseTiming(); + snapshots.restore(fixture.handle, &index_bf, &index_cagra); + st.ResumeTiming(); + }, + [&] { cuvs::neighbors::ivf_pq::extend(fixture.handle, fixture.view, std::nullopt, &index_bf); }, + [&] { + cuvs::neighbors::ivf_pq::extend(fixture.handle, fixture.view, std::nullopt, &index_cagra); + }, + total_bf_ms, + total_cagra_ms); + set_speedup_counters( + state, "speedup_extend", "bf_extend_ms", "cagra_extend_ms", total_bf_ms, total_cagra_ms); +} + +constexpr int64_t kDim = 128; + +#define IVFPQ_LOOKUP_BENCH_ARGS(BM) \ + BENCHMARK(BM) \ + ->Args({327680, 65536, kDim}) \ + ->Unit(benchmark::kMillisecond) \ + ->UseRealTime() \ + ->ArgNames({"n_vectors", "n_lists", "dim"}); \ + BENCHMARK(BM) \ + ->Args({1000000, 200000, kDim}) \ + ->Unit(benchmark::kMillisecond) \ + ->UseRealTime() \ + ->ArgNames({"n_vectors", "n_lists", "dim"}); \ + BENCHMARK(BM) \ + ->Args({1500000, 300000, kDim}) \ + ->Unit(benchmark::kMillisecond) \ + ->UseRealTime() \ + ->ArgNames({"n_vectors", "n_lists", "dim"}); \ + BENCHMARK(BM) \ + ->Args({1750000, 350000, kDim}) \ + ->Unit(benchmark::kMillisecond) \ + ->UseRealTime() \ + ->ArgNames({"n_vectors", "n_lists", "dim"}); \ + BENCHMARK(BM) \ + ->Args({2000000, 400000, kDim}) \ + ->Unit(benchmark::kMillisecond) \ + ->UseRealTime() \ + ->ArgNames({"n_vectors", "n_lists", "dim"}); \ + BENCHMARK(BM) \ + ->Args({3000000, 600000, kDim}) \ + ->Unit(benchmark::kMillisecond) \ + ->UseRealTime() \ + ->ArgNames({"n_vectors", "n_lists", "dim"}); \ + BENCHMARK(BM) \ + ->Args({4000000, 800000, kDim}) \ + ->Unit(benchmark::kMillisecond) \ + ->UseRealTime() \ + ->ArgNames({"n_vectors", "n_lists", "dim"}); \ + BENCHMARK(BM) \ + ->Args({5000000, 1000000, kDim}) \ + ->Unit(benchmark::kMillisecond) \ + ->UseRealTime() \ + ->ArgNames({"n_vectors", "n_lists", "dim"}); + +// build() fit E-step only (full build timed). +IVFPQ_LOOKUP_BENCH_ARGS(BM_IVFPQ_BuildFit_NearestCentroidLookup_Speedup) + +// build() post-fit predict only (full build timed). +IVFPQ_LOOKUP_BENCH_ARGS(BM_IVFPQ_BuildPostfit_NearestCentroidLookup_Speedup) + +// extend(): fixed-centroid nearest-centroid lookup, CAGRA built once per extend. +IVFPQ_LOOKUP_BENCH_ARGS(BM_IVFPQ_Extend_NearestCentroidLookup_Speedup) + +BENCHMARK_MAIN(); diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index e2b4ea4a36..d28d2cde2d 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -161,6 +161,21 @@ struct balanced_params : base_params { * Number of training iterations */ uint32_t n_iters = 20; + + /** + * Use approximate nearest neighbor (CAGRA) for cluster assignment during k-means fit (E-step). + * When true and n_clusters is large, assignment uses a CAGRA index over centroids to speed + * up the E-step. The index is rebuilt every `ann_rebuild_interval` iterations to limit + * rebuild cost (batched centroid updates). + */ + bool use_ann_for_build_fit = false; + + /** + * Rebuild the ANN index used for fit assignment every this many iterations (when + * use_ann_for_build_fit is true). Larger values reduce index build cost but use staler + * centroids for assignment in between rebuilds. + */ + uint32_t ann_rebuild_interval = 3; }; /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index 57f8a258fb..cc2eddf0e9 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -131,6 +131,32 @@ struct index_params : cuvs::neighbors::index_params { */ uint32_t max_train_points_per_pq_code = 256; + /** + * Use CAGRA vs brute force for cluster assignment during extend() (and add_data_on_build). + * - std::nullopt (default): brute-force assignment. + * - true: use CAGRA for assignment. + * - false: brute-force assignment (explicit). + */ + std::optional use_ann_for_extend; + + /** + * Use CAGRA vs brute force for nearest-centroid lookup during `build()` k-means fit (E-step; + * centroids move each EM iter). This is the main ANN-accelerated step in build. + * - std::nullopt (default): brute-force assignment during fit. + * - true: ANN-based assignment during fit (CAGRA with periodic index rebuilds). + * - false: brute-force assignment during fit (explicit). + */ + std::optional use_ann_for_build_fit; + + /** + * Use CAGRA vs brute force for nearest-centroid lookup during `build()` post-fit predict + * (assign train subsample labels for PQ codebook training; centroids are fixed). + * - std::nullopt (default): brute-force assignment. + * - true: ANN-based assignment for the post-fit labeling step. + * - false: brute-force assignment (explicit). + */ + std::optional use_ann_for_build_postfit; + /** * Creates index_params based on shape of the input dataset. * Usage example: @@ -421,6 +447,10 @@ class index_iface { const raft::resources& res) const = 0; virtual raft::device_matrix_view centers_half( const raft::resources& res) const = 0; + + /** Stored IVF-PQ build param: CAGRA vs brute for extend/add_data cluster assignment (nullopt => + * brute). */ + virtual std::optional use_ann_for_extend() const = 0; }; /** @@ -656,6 +686,8 @@ class index : public index_iface, cuvs::neighbors::index { */ uint32_t get_list_size_in_bytes(uint32_t label) const override; + std::optional use_ann_for_extend() const noexcept override; + /** * @brief Construct index from implementation pointer. * diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index a290f7372f..5574aff525 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -6,7 +6,9 @@ #pragma once #include "kmeans_common.cuh" +#include "kmeans_predict_cagra.cuh" #include +#include #include "../../core/nvtx.hpp" #include "../../distance/distance.cuh" @@ -660,6 +662,7 @@ void balancing_em_iters(const raft::resources& handle, { auto stream = raft::resource::get_cuda_stream(handle); uint32_t balancing_counter = balancing_pullback; + std::optional> cagra_index_opt; for (uint32_t iter = 0; iter < n_iters; iter++) { // Balancing step - move the centers around to equalize cluster sizes // (but not on the first iteration) @@ -695,18 +698,38 @@ void balancing_em_iters(const raft::resources& handle, } default: break; } - // E: Expectation step - predict labels - predict(handle, - params, - cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - mapping_op, - device_memory, - dataset_norm); + // E: Expectation step - predict labels (optionally via CAGRA with batched index rebuild). + // build fit E-step: moving-centroid CAGRA with index reuse (use_ann_for_build_fit). + const bool use_cagra_for_cluster_assignment = + params.use_ann_for_build_fit && + n_clusters >= cuvs::cluster::kmeans::detail::kMinClustersForAnnFit && + std::is_same_v && std::is_same_v; + if (use_cagra_for_cluster_assignment) { + bool rebuild = (iter % params.ann_rebuild_interval == 0); + cuvs::cluster::kmeans::detail::assign_nearest_centroid_cagra_with_index_reuse( + handle, + params, + cluster_centers, + n_clusters, + dim, + reinterpret_cast(dataset), + n_rows, + cluster_labels, + &cagra_index_opt, + rebuild); + } else { + predict(handle, + params, + cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + mapping_op, + device_memory, + dataset_norm); + } // M: Maximization step - calculate optimal cluster centers calc_centers_and_sizes(handle, cluster_centers, diff --git a/cpp/src/cluster/detail/kmeans_predict_cagra.cuh b/cpp/src/cluster/detail/kmeans_predict_cagra.cuh new file mode 100644 index 0000000000..ca261c9813 --- /dev/null +++ b/cpp/src/cluster/detail/kmeans_predict_cagra.cuh @@ -0,0 +1,149 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + * + * Cluster assignment via CAGRA: assign each data point to nearest centroid using an + * approximate nearest neighbor search (CAGRA) over the centroids instead of brute force. + * Used for scaling IVF training when the number of clusters K is very large. + * + * Shared helpers (build index on centroids, 1-NN search -> labels) are used by: + * - assign_nearest_centroid_cagra_with_index_reuse (k-means fit: optional index reuse on shifting + * centroids) + * - ivf_pq extend / build post-fit predict (fixed centroids: build_cagra_index_for_centroids + + * assign_nearest_centroid_cagra) + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace cuvs::cluster::kmeans::detail { + +/** Default search params for 1-NN centroid assignment (max_queries = auto). */ +inline cuvs::neighbors::cagra::search_params default_cagra_centroid_search_params() +{ + cuvs::neighbors::cagra::search_params p; + p.max_queries = 0; // auto + return p; +} + +/** + * @brief Build a CAGRA index on centroid vectors (shared by extend, + * assign_nearest_centroid_cagra callers, and assign_nearest_centroid_cagra_with_index_reuse). + */ +inline cuvs::neighbors::cagra::index build_cagra_index_for_centroids( + raft::resources const& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view centroids) +{ + using namespace cuvs::neighbors::cagra; + int64_t n_clusters = centroids.extent(0); + size_t graph_degree = std::min(64, std::max(1, n_clusters - 1)); + size_t inter_degree = std::min(128, std::max(1, n_clusters - 1)); + index_params build_params; + build_params.metric = params.metric; + build_params.graph_degree = graph_degree; + build_params.intermediate_graph_degree = inter_degree; + build_params.attach_dataset_on_build = true; + return build(handle, build_params, centroids); +} + +/** + * @brief Assign each query to its nearest centroid using an existing CAGRA index built on + * centroids. Writes cluster labels (and optional per-query distances). Queries must be float + * row-major [n_queries, dim]. + * + * Uses explicit row_major matrix view and raw label pointer so this compiles when + * raft::make_device_*_view returns layout_c_contiguous mdspan (not assignable to the view types + * required by cagra::search / device_vector_view). + */ +template +void assign_nearest_centroid_cagra( + raft::resources const& handle, + cuvs::neighbors::cagra::search_params const& search_params, + cuvs::neighbors::cagra::index const& cagra_index, + raft::device_matrix_view queries, + LabelT* labels_out, + int64_t n_labels, + float* distances_out = nullptr) +{ + using namespace cuvs::neighbors::cagra; + int64_t n_rows = queries.extent(0); + RAFT_EXPECTS(n_labels == n_rows, + "assign_nearest_centroid_cagra: labels length must match n_queries"); + auto neighbors = raft::make_device_matrix(handle, n_rows, 1); + auto distances = raft::make_device_matrix(handle, n_rows, 1); + search(handle, search_params, cagra_index, queries, neighbors.view(), distances.view()); + auto neighbors_col = + raft::make_device_vector_view(neighbors.data_handle(), n_rows); + auto labels_view = raft::make_device_vector_view(labels_out, n_rows); + raft::linalg::map( + handle, raft::make_const_mdspan(neighbors_col), labels_view, raft::cast_op()); + if (distances_out != nullptr) { + raft::copy( + handle, + raft::make_device_vector_view(distances_out, n_rows), + raft::make_device_vector_view(distances.data_handle(), n_rows)); + } +} + +/** Minimum number of clusters to use ANN for k-means fit (below this, brute is faster). */ +constexpr uint32_t kMinClustersForAnnFit = 5000; + +/** + * @brief Assign each row to nearest centroid using CAGRA, reusing or rebuilding the index. + * + * When rebuild is true (or index is empty), builds the index on current centroids and stores it + * in *index_opt. Otherwise skips build and searches with the existing index: centroid vectors in + * memory may have shifted since that build (k-means M-step), so the graph still indexes a stale + * snapshot — assignments are intentionally approximate between rebuilds. + * + * Used by k-means fit when use_ann_for_build_fit and n_clusters >= kMinClustersForAnnFit. + */ +template +void assign_nearest_centroid_cagra_with_index_reuse( + raft::resources const& handle, + cuvs::cluster::kmeans::balanced_params const& params, + const float* centers, + IdxT n_clusters, + IdxT dim, + const float* dataset, + IdxT n_rows, + LabelT* labels, + std::optional>* index_opt, + bool rebuild) +{ + RAFT_EXPECTS( + centers != nullptr && dataset != nullptr && labels != nullptr && index_opt != nullptr, + "assign_nearest_centroid_cagra_with_index_reuse: null argument"); + RAFT_EXPECTS(n_clusters >= 1 && dim >= 1 && n_rows >= 1, + "assign_nearest_centroid_cagra_with_index_reuse: bad extents"); + + raft::device_matrix_view centers_view( + centers, static_cast(n_clusters), static_cast(dim)); + raft::device_matrix_view queries_view( + dataset, static_cast(n_rows), static_cast(dim)); + + if (rebuild || !index_opt->has_value()) { + *index_opt = build_cagra_index_for_centroids(handle, params, centers_view); + } + + assign_nearest_centroid_cagra(handle, + default_cagra_centroid_search_params(), + index_opt->value(), + queries_view, + labels, + static_cast(n_rows), + nullptr); +} + +} // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index c562ca9e00..c861936495 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -954,7 +954,8 @@ auto clone(const raft::resources& res, const index& source) -> index source.pq_bits(), source.pq_dim(), source.conservative_memory_allocation(), - source.codes_layout()); + source.codes_layout(), + source.use_ann_for_extend()); // Copy the independent parts using mutable accessors raft::copy(res, impl->list_sizes(), source.list_sizes()); @@ -1112,21 +1113,65 @@ void extend(raft::resources const& handle, n_clusters, stream); vec_batches.prefetch_next_batch(); - for (const auto& batch : vec_batches) { - auto batch_data_view = raft::make_device_matrix_view( - batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( - new_data_labels.data() + batch.offset(), batch.size()); - auto centers_view = raft::make_device_matrix_view( - cluster_centers.data(), n_clusters, index->dim()); - cuvs::cluster::kmeans::balanced_params kmeans_params; - kmeans_params.metric = coarse_clustering_metric(index->metric()); - cuvs::cluster::kmeans::predict( - handle, kmeans_params, batch_data_view, centers_view, batch_labels_view); - vec_batches.prefetch_next_batch(); - // User needs to make sure kernel finishes its work before we overwrite batch in the next - // iteration if different streams are used for kernel and copy. - raft::resource::sync_stream(handle); + cuvs::cluster::kmeans::balanced_params kmeans_params; + kmeans_params.metric = coarse_clustering_metric(index->metric()); + + // Default: brute-force assignment; set use_ann_for_extend to opt in to CAGRA. + // extend(): fixed-centroid CAGRA nearest-centroid assignment — build index once, assign per + // batch. + const bool use_cagra_for_cluster_assignment = index->use_ann_for_extend().value_or(false); + if (use_cagra_for_cluster_assignment) { + raft::device_matrix_view centers_view( + cluster_centers.data(), + static_cast(n_clusters), + static_cast(index->dim())); + const auto cagra_index = cuvs::cluster::kmeans::detail::build_cagra_index_for_centroids( + handle, kmeans_params, centers_view); + const auto cagra_search_params = + cuvs::cluster::kmeans::detail::default_cagra_centroid_search_params(); + + for (const auto& batch : vec_batches) { + auto batch_size = batch.size(); + rmm::device_uvector queries_float( + static_cast(batch_size) * static_cast(index->dim()), + stream, + device_memory); + auto batch_view = raft::make_device_matrix_view( + batch.data(), batch_size, index->dim()); + raft::linalg::map(handle, + raft::make_const_mdspan(batch_view), + raft::make_device_matrix_view( + queries_float.data(), batch_size, index->dim()), + utils::mapping{}); + raft::device_matrix_view queries_view( + queries_float.data(), + static_cast(batch_size), + static_cast(index->dim())); + cuvs::cluster::kmeans::detail::assign_nearest_centroid_cagra( + handle, + cagra_search_params, + cagra_index, + queries_view, + new_data_labels.data() + batch.offset(), + static_cast(batch_size)); + vec_batches.prefetch_next_batch(); + raft::resource::sync_stream(handle); + } + } else { + for (const auto& batch : vec_batches) { + auto batch_data_view = raft::make_device_matrix_view( + batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( + new_data_labels.data() + batch.offset(), batch.size()); + auto centers_view = raft::make_device_matrix_view( + cluster_centers.data(), n_clusters, index->dim()); + cuvs::cluster::kmeans::predict( + handle, kmeans_params, batch_data_view, centers_view, batch_labels_view); + vec_batches.prefetch_next_batch(); + // User needs to make sure kernel finishes its work before we overwrite batch in the next + // iteration if different streams are used for kernel and copy. + raft::resource::sync_stream(handle); + } } } @@ -1253,7 +1298,8 @@ auto build(raft::resources const& handle, params.pq_bits, params.pq_dim == 0 ? index::calculate_pq_dim(dim) : params.pq_dim, params.conservative_memory_allocation, - params.codes_layout); + params.codes_layout, + params.use_ann_for_extend); auto stream = raft::resource::get_cuda_stream(handle); utils::memzero( @@ -1332,6 +1378,12 @@ auto build(raft::resources const& handle, cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = coarse_clustering_metric(impl->metric()); + // Propagate use_ann_for_build_fit into k-means; CAGRA runs inside fit's E-step + // (balancing_em_iters → assign_nearest_centroid_cagra_with_index_reuse), not here. + if (params.use_ann_for_build_fit.value_or(false)) { + kmeans_params.use_ann_for_build_fit = true; + kmeans_params.ann_rebuild_interval = 3; + } if (impl->metric() == distance::DistanceType::CosineExpanded) { raft::linalg::row_normalize( @@ -1348,8 +1400,29 @@ auto build(raft::resources const& handle, } auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); - cuvs::cluster::kmeans::predict( - handle, kmeans_params, trainset_const_view, centers_const_view, labels_view); + // build post-fit: fixed-centroid CAGRA nearest-centroid assignment (use_ann_for_build_postfit). + const bool use_cagra_for_cluster_assignment = + params.use_ann_for_build_postfit.value_or(false) && + impl->n_lists() >= cuvs::cluster::kmeans::detail::kMinClustersForAnnFit; + if (use_cagra_for_cluster_assignment) { + raft::device_matrix_view centers_view( + cluster_centers, static_cast(impl->n_lists()), static_cast(impl->dim())); + raft::device_matrix_view queries_view( + trainset.data_handle(), + static_cast(n_rows_train), + static_cast(impl->dim())); + cuvs::cluster::kmeans::detail::assign_nearest_centroid_cagra( + handle, + cuvs::cluster::kmeans::detail::default_cagra_centroid_search_params(), + cuvs::cluster::kmeans::detail::build_cagra_index_for_centroids( + handle, kmeans_params, centers_view), + queries_view, + labels.data(), + static_cast(n_rows_train)); + } else { + cuvs::cluster::kmeans::predict( + handle, kmeans_params, trainset_const_view, centers_const_view, labels_view); + } // Make rotation matrix helpers::make_rotation_matrix(handle, impl->rotation_matrix(), params.force_random_rotation); @@ -1562,7 +1635,8 @@ auto build( index_params.pq_bits, pq_dim, index_params.conservative_memory_allocation, - index_params.codes_layout); + index_params.codes_layout, + index_params.use_ann_for_extend); utils::memzero( impl->accum_sorted_sizes().data_handle(), impl->accum_sorted_sizes().size(), stream); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh index f20f2bf81e..13b8b66400 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh @@ -25,7 +25,28 @@ namespace cuvs::neighbors::ivf_pq::detail { // Serialization version // Version 4 adds codes_layout field -constexpr int kSerializationVersion = 4; +// Version 5 adds use_ann_for_extend (tri-state int8: -1 nullopt, 0 false, 1 true) +constexpr int kSerializationVersion = 5; + +/** Serialize std::optional as int8_t (-1 = nullopt, 0/1 = false/true). */ +inline void serialize_optional_bool(raft::resources const& handle, + std::ostream& os, + std::optional v) +{ + int8_t flag = v.has_value() ? static_cast(v.value() ? 1 : 0) : int8_t{-1}; + raft::serialize_scalar(handle, os, flag); +} + +/** Deserialize std::optional from int8_t (-1 = nullopt, 0/1 = false/true). */ +inline std::optional deserialize_optional_bool(raft::resources const& handle, + std::istream& is) +{ + int8_t flag = raft::deserialize_scalar(handle, is); + if (flag == 1) { return true; } + if (flag == 0) { return false; } + RAFT_EXPECTS(flag == -1, "ivf_pq::deserialize: invalid use_ann_for_extend flag %d", int{flag}); + return std::nullopt; +} /** * Write the index to an output stream @@ -57,6 +78,7 @@ void serialize(raft::resources const& handle_, std::ostream& os, const index{index.pq_bits(), index.pq_dim(), true}; for (uint32_t label = 0; label < index.n_lists(); label++) { - auto& typed_list = static_cast&>(*index.lists()[label]); + auto typed_list = std::static_pointer_cast>(index.lists()[label]); ivf::serialize_list(handle_, os, typed_list, list_store_spec, sizes_host(label)); } } else { auto list_store_spec = list_spec_interleaved{index.pq_bits(), index.pq_dim(), true}; for (uint32_t label = 0; label < index.n_lists(); label++) { - auto& typed_list = static_cast&>(*index.lists()[label]); + auto typed_list = std::static_pointer_cast>(index.lists()[label]); ivf::serialize_list(handle_, os, typed_list, list_store_spec, sizes_host(label)); } } @@ -137,6 +158,7 @@ auto deserialize(raft::resources const& handle_, std::istream& is) -> index(handle_, is); auto codes_layout = raft::deserialize_scalar(handle_, is); auto n_lists = raft::deserialize_scalar(handle_, is); + auto use_ann_for_extend = deserialize_optional_bool(handle_, is); RAFT_LOG_DEBUG("n_rows %zu, dim %d, pq_dim %d, pq_bits %d, n_lists %d", static_cast(n_rows), @@ -174,8 +196,16 @@ auto deserialize(raft::resources const& handle_, std::istream& is) -> index>( - handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma, codes_layout); + auto impl = std::make_unique>(handle_, + metric, + codebook_kind, + n_lists, + dim, + pq_bits, + pq_dim, + cma, + codes_layout, + use_ann_for_extend); // Deserialize center/matrix data using mutable accessors raft::deserialize_mdspan(handle_, is, impl->pq_centers()); diff --git a/cpp/src/neighbors/ivf_pq_impl.hpp b/cpp/src/neighbors/ivf_pq_impl.hpp index 5c96755808..8c44442be0 100644 --- a/cpp/src/neighbors/ivf_pq_impl.hpp +++ b/cpp/src/neighbors/ivf_pq_impl.hpp @@ -20,7 +20,8 @@ class index_impl : public index_iface { uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation, - list_layout codes_layout = list_layout::INTERLEAVED); + list_layout codes_layout = list_layout::INTERLEAVED, + std::optional use_ann_for_extend = std::nullopt); ~index_impl() = default; index_impl(index_impl&&) = default; @@ -71,9 +72,11 @@ class index_impl : public index_iface { const raft::resources& res) const override; uint32_t get_list_size_in_bytes(uint32_t label) const override; + std::optional use_ann_for_extend() const noexcept override; protected: cuvs::distance::DistanceType metric_; + std::optional use_ann_for_extend_; codebook_gen codebook_kind_; list_layout codes_layout_; uint32_t dim_; @@ -113,7 +116,8 @@ class owning_impl : public index_impl { uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation, - list_layout codes_layout = list_layout::INTERLEAVED); + list_layout codes_layout = list_layout::INTERLEAVED, + std::optional use_ann_for_extend = std::nullopt); ~owning_impl() = default; owning_impl(owning_impl&&) = default; @@ -159,7 +163,8 @@ class view_impl : public index_impl { raft::device_matrix_view centers_view, raft::device_matrix_view centers_rot_view, raft::device_matrix_view rotation_matrix_view, - list_layout codes_layout = list_layout::INTERLEAVED); + list_layout codes_layout = list_layout::INTERLEAVED, + std::optional use_ann_for_extend = std::nullopt); ~view_impl() = default; view_impl(view_impl&&) = default; diff --git a/cpp/src/neighbors/ivf_pq_index.cu b/cpp/src/neighbors/ivf_pq_index.cu index 28b985eec8..d96ca6d26f 100644 --- a/cpp/src/neighbors/ivf_pq_index.cu +++ b/cpp/src/neighbors/ivf_pq_index.cu @@ -26,7 +26,8 @@ index_impl::index_impl(raft::resources const& handle, uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation, - list_layout codes_layout) + list_layout codes_layout, + std::optional use_ann_for_extend) : metric_(metric), codebook_kind_(codebook_kind), codes_layout_(codes_layout), @@ -34,6 +35,7 @@ index_impl::index_impl(raft::resources const& handle, pq_bits_(pq_bits), pq_dim_(pq_dim == 0 ? index::calculate_pq_dim(dim) : pq_dim), conservative_memory_allocation_(conservative_memory_allocation), + use_ann_for_extend_(use_ann_for_extend), lists_(n_lists), list_sizes_{raft::make_device_vector(handle, n_lists)}, data_ptrs_{raft::make_device_vector(handle, n_lists)}, @@ -122,6 +124,12 @@ bool index_impl::conservative_memory_allocation() const noexcept return conservative_memory_allocation_; } +template +std::optional index_impl::use_ann_for_extend() const noexcept +{ + return use_ann_for_extend_; +} + template std::vector>>& index_impl::lists() noexcept { @@ -198,7 +206,8 @@ owning_impl::owning_impl(raft::resources const& handle, uint32_t pq_bits, uint32_t pq_dim, bool conservative_memory_allocation, - list_layout codes_layout) + list_layout codes_layout, + std::optional use_ann_for_extend) : index_impl(handle, metric, codebook_kind, @@ -207,7 +216,8 @@ owning_impl::owning_impl(raft::resources const& handle, pq_bits, pq_dim, conservative_memory_allocation, - codes_layout), + codes_layout, + use_ann_for_extend), pq_centers_{raft::make_device_mdarray( handle, index::make_pq_centers_extents(dim, pq_dim, pq_bits, codebook_kind, n_lists))}, centers_{ @@ -248,7 +258,8 @@ view_impl::view_impl( raft::device_matrix_view centers_view, raft::device_matrix_view centers_rot_view, raft::device_matrix_view rotation_matrix_view, - list_layout codes_layout) + list_layout codes_layout, + std::optional use_ann_for_extend) : index_impl(handle, metric, codebook_kind, @@ -257,7 +268,8 @@ view_impl::view_impl( pq_bits, pq_dim, conservative_memory_allocation, - codes_layout), + codes_layout, + use_ann_for_extend), pq_centers_view_(pq_centers_view), centers_view_(centers_view), centers_rot_view_(centers_rot_view), @@ -595,6 +607,12 @@ uint32_t index::get_list_size_in_bytes(uint32_t label) const return impl_->get_list_size_in_bytes(label); } +template +std::optional index::use_ann_for_extend() const noexcept +{ + return impl_->use_ann_for_extend(); +} + template void index_impl::check_consistency() {