diff --git a/include/svs/index/ivf/clustering.h b/include/svs/index/ivf/clustering.h index cc9bb09b..65828525 100644 --- a/include/svs/index/ivf/clustering.h +++ b/include/svs/index/ivf/clustering.h @@ -410,6 +410,66 @@ class DenseClusteredDataset { static constexpr lib::Version save_version{0, 0, 0}; static constexpr std::string_view serialization_schema = "ivf_dense_clustered_dataset"; + lib::SaveTable metadata() const { + auto num_clusters = size(); + auto dims = dimensions(); + + return lib::SaveTable( + serialization_schema, + save_version, + {{"num_clusters", lib::save(num_clusters)}, + {"dimensions", lib::save(dims)}, + {"prefetch_offset", lib::save(prefetch_offset_)}, + {"index_type", lib::save(datatype_v)}} + ); + } + + void save(std::ostream& os) const { + auto num_clusters = size(); + + // Compute cluster sizes and ID offsets + std::vector cluster_sizes, ids_offsets; + calculate_cluster_sizes_and_offsets(&cluster_sizes, &ids_offsets); + + lib::write_binary(os, cluster_sizes); + for (size_t i = 0; i < num_clusters; ++i) { + lib::save_to_stream(clusters_[i].data_, os); + if (!clusters_[i].ids_.empty()) { + lib::write_binary(os, clusters_[i].ids_); + } + } + } + + template + static DenseClusteredDataset load( + const lib::ContextFreeLoadTable& table, + std::istream& is, + const Allocator& allocator = Allocator{} + ) { + auto num_clusters = lib::load_at(table, "num_clusters"); + [[maybe_unused]] auto dims = lib::load_at(table, "dimensions"); + auto prefetch_offset = lib::load_at(table, "prefetch_offset"); + check_saved_index_type(table); + + DenseClusteredDataset result; + result.prefetch_offset_ = prefetch_offset; + result.clusters_.reserve(num_clusters); + + std::vector cluster_sizes(num_clusters); + lib::read_binary(is, cluster_sizes); + + for (size_t i = 0; i < num_clusters; ++i) { + auto cluster_data = lib::load_from_stream(is, allocator); + size_t cluster_size = cluster_sizes[i]; + std::vector cluster_ids(cluster_size); + if (cluster_size > 0) { + lib::read_binary(is, cluster_ids); + } + result.clusters_.emplace_back(std::move(cluster_data), std::move(cluster_ids)); + } + return result; + } + /// @brief Save the DenseClusteredDataset to disk. /// /// Saves all cluster data using the existing save mechanisms for each data type @@ -427,16 +487,8 @@ class DenseClusteredDataset { auto dims = dimensions(); // Compute cluster sizes and ID offsets - std::vector cluster_sizes(num_clusters); - std::vector ids_offsets(num_clusters + 1); - - size_t ids_offset = 0; - for (size_t i = 0; i < num_clusters; ++i) { - cluster_sizes[i] = clusters_[i].size(); - ids_offsets[i] = ids_offset; - ids_offset += cluster_sizes[i] * sizeof(I); - } - ids_offsets[num_clusters] = ids_offset; + std::vector cluster_sizes, ids_offsets; + calculate_cluster_sizes_and_offsets(&cluster_sizes, &ids_offsets); // Create a temporary directory for cluster data lib::UniqueTempDirectory tempdir{"svs_ivf_clusters_save"}; @@ -495,7 +547,7 @@ class DenseClusteredDataset { {"ids_file", lib::save(std::string("ids.bin"))}, {"cluster_sizes_file", lib::save(std::string("cluster_sizes.bin"))}, {"ids_offsets_file", lib::save(std::string("ids_offsets.bin"))}, - {"total_ids_bytes", lib::save(ids_offset)}} + {"total_ids_bytes", lib::save(ids_offsets[num_clusters])}} ); } @@ -528,17 +580,7 @@ class DenseClusteredDataset { // since each cluster's data type determines its own dimensions [[maybe_unused]] auto dims = lib::load_at(table, "dimensions"); auto prefetch_offset = lib::load_at(table, "prefetch_offset"); - - // Verify index type matches - auto saved_index_type = lib::load_at(table, "index_type"); - if (saved_index_type != datatype_v) { - throw ANNEXCEPTION( - "DenseClusteredDataset was saved using index type {} but we're trying to " - "reload it using {}!", - saved_index_type, - datatype_v - ); - } + check_saved_index_type(table); auto base_dir = table.context().get_directory(); @@ -603,6 +645,35 @@ class DenseClusteredDataset { } private: + void calculate_cluster_sizes_and_offsets( + std::vector* cluster_sizes, std::vector* ids_offsets + ) const { + auto num_clusters = size(); + cluster_sizes->resize(num_clusters); + ids_offsets->resize(num_clusters + 1); + + size_t ids_offset = 0; + for (size_t i = 0; i < num_clusters; ++i) { + (*cluster_sizes)[i] = clusters_[i].size(); + (*ids_offsets)[i] = ids_offset; + ids_offset += clusters_[i].size() * sizeof(I); + } + (*ids_offsets)[num_clusters] = ids_offset; + } + + static void check_saved_index_type(const lib::ContextFreeLoadTable& table) { + // Verify index type matches + auto saved_index_type = lib::load_at(table, "index_type"); + if (saved_index_type != datatype_v) { + throw ANNEXCEPTION( + "DenseClusteredDataset was saved using index type {} but we're trying to " + "reload it using {}!", + saved_index_type, + datatype_v + ); + } + } + std::vector> clusters_; size_t prefetch_offset_ = 8; }; diff --git a/include/svs/index/ivf/dynamic_ivf.h b/include/svs/index/ivf/dynamic_ivf.h index 92fc86ed..40ad32c3 100644 --- a/include/svs/index/ivf/dynamic_ivf.h +++ b/include/svs/index/ivf/dynamic_ivf.h @@ -761,6 +761,10 @@ class DynamicIVFIndex { lib::save_to_disk(clusters_, clusters_dir); } + void save(std::ostream& SVS_UNUSED(os)) const { + throw ANNEXCEPTION("Placeholder; not implemented!"); + } + private: ///// Helper Methods ///// diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index 263b1868..00f45735 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -465,6 +465,7 @@ class IVFIndex { /// Each directory may be created as a side-effect of this method call provided that /// the parent directory exists. /// + void save( const std::filesystem::path& config_directory, const std::filesystem::path& data_directory @@ -504,6 +505,27 @@ class IVFIndex { lib::save_to_disk(cluster_, clusters_dir); } + void save(std::ostream& os) const { + lib::begin_serialization(os); + + // Save config + auto data_type_config = DataTypeTraits::get_config(); + data_type_config.centroid_type = datatype_v; + auto save_table = lib::SaveTable( + serialization_schema, + save_version, + {{"name", lib::save(name())}, {"num_clusters", lib::save(num_clusters())}} + ); + save_table.insert("data_type_config", lib::save(data_type_config)); + lib::save_to_stream(save_table, os); + + // Save centroids + lib::save_to_stream(centroids_, os); + + // Save clusters + lib::save_to_stream(cluster_, os); + } + private: ///// Core Components ///// Centroids centroids_; @@ -964,4 +986,36 @@ auto load_ivf_index( return index; } +template < + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadpoolProto> +auto load_ivf_index( + std::istream& is, + Distance distance, + ThreadpoolProto threadpool_proto, + const size_t intra_query_thread_count = 1, + svs::logging::logger_ptr logger = svs::logging::get() +) { + using centroids_type = data::SimpleData; + using data_type = typename DataType::lib_alloc_data_type; + using cluster_type = DenseClusteredDataset; + + lib::detail::read_metadata(is); + auto centroids = lib::load_from_stream(is); + auto clusters = lib::load_from_stream(is); + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + + return IVFIndex( + std::move(centroids), + std::move(clusters), + std::move(distance), + std::move(threadpool), + intra_query_thread_count, + logger + ); +} + } // namespace svs::index::ivf diff --git a/include/svs/orchestrators/ivf.h b/include/svs/orchestrators/ivf.h index 90c98078..87cb1b22 100644 --- a/include/svs/orchestrators/ivf.h +++ b/include/svs/orchestrators/ivf.h @@ -116,13 +116,7 @@ class IVFImpl : public manager::ManagerImpl { void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_ivf_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current IVF backend doesn't support saving!"); } @@ -380,27 +374,54 @@ class IVF : public manager::IndexManager { ThreadpoolProto threadpool_proto, size_t intra_query_threads = 1 ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_ivf_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + if (deserializer.is_native()) { + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return IVF( + std::in_place, + manager::as_typelist{}, + index::ivf::load_ivf_index( + stream, + std::move(distance_function), + std::move(threadpool_proto), + intra_query_threads + ) + ); + }); + } else { + return IVF( + std::in_place, + manager::as_typelist{}, + index::ivf::load_ivf_index( + stream, distance, std::move(threadpool_proto), intra_query_threads + ) + ); + } + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_ivf_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!"); - } + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!"); + } - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!"); - } + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!"); + } - return assemble( - config_path, - data_path, - distance, - std::move(threadpool_proto), - intra_query_threads - ); + return assemble( + config_path, + data_path, + distance, + std::move(threadpool_proto), + intra_query_threads + ); + } } ///// Building diff --git a/tests/svs/index/ivf/index.cpp b/tests/svs/index/ivf/index.cpp index 7c78a3ad..b15f2604 100644 --- a/tests/svs/index/ivf/index.cpp +++ b/tests/svs/index/ivf/index.cpp @@ -16,6 +16,7 @@ // header under test #include "svs/index/ivf/index.h" +#include "svs/orchestrators/ivf.h" // tests #include "tests/utils/test_dataset.h" @@ -33,6 +34,7 @@ // stl #include +#include CATCH_TEST_CASE("IVF Index Single Search", "[ivf][index][single_search]") { namespace ivf = svs::index::ivf; @@ -276,6 +278,81 @@ CATCH_TEST_CASE("IVF Index Save and Load", "[ivf][index][saveload]") { svs_test::cleanup_temp_directory(); } + CATCH_SECTION("Load IVF Index serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_ivf_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + using DataType = svs::data::SimpleData; + + auto loaded_ivf = svs::IVF::assemble( + stream, + distance, + svs::threads::as_threadpool(num_threads), + num_inner_threads + ); + + CATCH_REQUIRE(loaded_ivf.size() == index.size()); + CATCH_REQUIRE(loaded_ivf.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_ivf.search(loaded_results.view(), batch_queries, search_params); + + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE( + loaded_results.index(q, i) == original_results.index(q, i) + ); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(original_results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load IVF Index serialized natively to stream") { + std::stringstream stream; + index.save(stream); + + { + using DataType = svs::data::SimpleData; + + auto loaded_ivf = svs::IVF::assemble( + stream, + distance, + svs::threads::as_threadpool(num_threads), + num_inner_threads + ); + + CATCH_REQUIRE(loaded_ivf.size() == index.size()); + CATCH_REQUIRE(loaded_ivf.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_ivf.search(loaded_results.view(), batch_queries, search_params); + + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE( + loaded_results.index(q, i) == original_results.index(q, i) + ); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(original_results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + CATCH_SECTION("Save and load DenseClusteredDataset") { // Prepare temp directory auto tempdir = svs_test::prepare_temp_directory_v2();