From c6c42c4b87527457d2818e3b1548bd5cba5ddd8f Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Wed, 4 Mar 2026 15:00:42 +0100 Subject: [PATCH 1/7] Native serialization to a stream for FlatIndex (#280) Reopening of https://github.com/intel/ScalableVectorSearch/pull/275 for developer branch --------- Co-authored-by: Dmitry Razdoburdin Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- include/svs/core/data/io.h | 33 +++++++++ include/svs/core/data/simple.h | 87 +++++++++++++++++++++--- include/svs/core/io/native.h | 78 +++++++++++++-------- include/svs/index/flat/flat.h | 2 + include/svs/lib/archiver.h | 93 ++++++++++++++++++++++++++ include/svs/lib/file.h | 63 ++--------------- include/svs/lib/saveload/load.h | 56 ++++++++++++++++ include/svs/lib/saveload/save.h | 21 ++++++ include/svs/lib/stream.h | 73 ++++++++++++++++++++ include/svs/orchestrators/exhaustive.h | 10 +-- tests/svs/index/flat/flat.cpp | 78 +++++++++++++++++++++ 11 files changed, 493 insertions(+), 101 deletions(-) create mode 100644 include/svs/lib/archiver.h create mode 100644 include/svs/lib/stream.h diff --git a/include/svs/core/data/io.h b/include/svs/core/data/io.h index 6bb856fed..572ffa63a 100644 --- a/include/svs/core/data/io.h +++ b/include/svs/core/data/io.h @@ -79,6 +79,22 @@ void populate_impl( } } +template void populate(std::istream& is, Data& data) { + auto accessor = DefaultWriteAccessor(); + + size_t num_vectors = data.size(); + size_t dims = data.dimensions(); + + auto max_lines = Dynamic; + auto nvectors = std::min(num_vectors, max_lines); + + auto reader = lib::VectorReader(dims); + for (size_t i = 0; i < nvectors; ++i) { + reader.read(is); + accessor.set(data, i, reader.data()); + } +} + // Intercept the native file to perform dispatch on the actual file type. template void populate_impl( @@ -120,6 +136,15 @@ void save(const Dataset& data, const File& file, const lib::UUID& uuid = lib::Ze return save(data, accessor, file, uuid); } +template +void save(const Dataset& data, std::ostream& os) { + auto accessor = DefaultReadAccessor(); + auto writer = svs::io::v1::StreamWriter(os); + for (size_t i = 0; i < data.size(); ++i) { + writer << accessor.get(data, i); + } +} + /// /// @brief Save the dataset as a "*vecs" file. /// @@ -169,6 +194,14 @@ lib::lazy_result_t load_dataset(const File& file, const F& la return load_impl(detail::to_native(file), default_accessor, lazy); } +template F> +lib::lazy_result_t +load_dataset(std::istream& is, const F& lazy, size_t num_vectors, size_t dims) { + auto data = lazy(num_vectors, dims); + populate(is, data); + return data; +} + // Return whether or not a file is directly loadable via file-extension. inline bool special_by_file_extension(std::string_view path) { return (path.ends_with("svs") || path.ends_with("vecs") || path.ends_with("bin")); diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index 0fcb31bbb..cf3c5df24 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -75,24 +75,42 @@ class GenericSerializer { } template - static lib::SaveTable save(const Data& data, const lib::SaveContext& ctx) { + static lib::SaveTable save_table(const Data& data) { using T = typename Data::element_type; - // UUID used to identify the file. - auto uuid = lib::UUID{}; - auto filename = ctx.generate_name("data"); - io::save(data, io::NativeFile(filename), uuid); - return lib::SaveTable( + auto table = lib::SaveTable( serialization_schema, save_version, { {"name", "uncompressed"}, - {"binary_file", lib::save(filename.filename())}, {"dims", lib::save(data.dimensions())}, {"num_vectors", lib::save(data.size())}, - {"uuid", uuid.str()}, {"eltype", lib::save(datatype_v)}, } ); + return table; + } + + template + static lib::SaveTable + save_table(const Data& data, const FileName_t& filename, const lib::UUID& uuid) { + auto table = save_table(data); + table.insert("binary_file", filename); + table.insert("uuid", uuid.str()); + return table; + } + + template + static lib::SaveTable save(const Data& data, const lib::SaveContext& ctx) { + // UUID used to identify the file. + auto uuid = lib::UUID{}; + auto filename = ctx.generate_name("data"); + io::save(data, io::NativeFile(filename), uuid); + return save_table(data, lib::save(filename.filename()), uuid); + } + + template + static void save(const Data& data, std::ostream& os) { + io::save(data, os); } template F> @@ -116,6 +134,25 @@ class GenericSerializer { } return io::load_dataset(binaryfile.value(), lazy); } + + template F> + static lib::lazy_result_t + load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) { + auto datatype = lib::load_at(table, "eltype"); + if (datatype != datatype_v) { + throw ANNEXCEPTION( + "Trying to load an uncompressed dataset with element types {} to a dataset " + "with element types {}.", + name(datatype), + name>() + ); + } + + size_t num_vectors = lib::load_at(table, "num_vectors"); + size_t dims = lib::load_at(table, "dims"); + + return io::load_dataset(is, lazy, num_vectors, dims); + } }; struct Matcher { @@ -405,6 +442,10 @@ class SimpleData { return GenericSerializer::save(*this, ctx); } + void save(std::ostream& os) const { return GenericSerializer::save(*this, os); } + + lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); } + static bool check_load_compatibility(std::string_view schema, lib::Version version) { return GenericSerializer::check_compatibility(schema, version); } @@ -431,6 +472,20 @@ class SimpleData { ); } + static SimpleData load( + const lib::ContextFreeLoadTable& table, + std::istream& is, + const allocator_type& allocator = {} + ) + requires(!is_view) + { + return GenericSerializer::load( + table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { + return SimpleData(n_elements, n_dimensions, allocator); + }) + ); + } + /// /// @brief Try to automatically load the dataset. /// @@ -805,6 +860,10 @@ class SimpleData> { return GenericSerializer::save(*this, ctx); } + void save(std::ostream& os) const { return GenericSerializer::save(*this, os); } + + lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); } + static bool check_load_compatibility(std::string_view schema, lib::Version version) { return GenericSerializer::check_compatibility(schema, version); } @@ -818,6 +877,18 @@ class SimpleData> { ); } + static SimpleData load( + const lib::ContextFreeLoadTable& table, + std::istream& is, + const Blocked& allocator = {} + ) { + return GenericSerializer::load( + table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) { + return SimpleData(n_elements, n_dimensions, allocator); + }) + ); + } + static SimpleData load(const std::filesystem::path& path, const Blocked& allocator = {}) { if (detail::is_likely_reload(path)) { diff --git a/include/svs/core/io/native.h b/include/svs/core/io/native.h index 0039128d3..0f6476686 100644 --- a/include/svs/core/io/native.h +++ b/include/svs/core/io/native.h @@ -344,28 +344,16 @@ struct Header { static_assert(sizeof(Header) == header_size, "Mismatch in Native io::v1 header sizes!"); static_assert(std::is_trivially_copyable_v
, "Header must be trivially copyable!"); -template class Writer { +// CRTP +template class Writer { public: - Writer( - const std::string& path, - size_t dimension, - lib::UUID uuid = lib::UUID(lib::ZeroInitializer()) - ) - : dimension_{dimension} - , uuid_{uuid} - , stream_{lib::open_write(path, std::ofstream::out | std::ofstream::binary)} { - // Write a temporary header. - stream_.seekp(0, std::ofstream::beg); - lib::write_binary(stream_, Header()); - } - - size_t dimensions() const { return dimension_; } void overwrite_num_vectors(size_t num_vectors) { vectors_written_ = num_vectors; } // TODO: Error checking to make sure the length is correct. template Writer& append(U&& v) { + std::ostream& os = static_cast(this)->stream(); for (const auto& i : v) { - lib::write_binary(stream_, lib::io_convert(i)); + lib::write_binary(os, lib::io_convert(i)); } ++vectors_written_; return *this; @@ -374,13 +362,37 @@ template class Writer { template requires std::is_same_v Writer& append(std::tuple&& v) { - lib::foreach (v, [&](const auto& x) { lib::write_binary(stream_, x); }); + std::ostream& os = static_cast(this)->stream(); + lib::foreach (v, [&](const auto& x) { lib::write_binary(os, x); }); ++vectors_written_; return *this; } template Writer& operator<<(U&& v) { return append(std::forward(v)); } + protected: + size_t vectors_written_ = 0; +}; + +template class FileWriter : public Writer> { + public: + FileWriter( + const std::string& path, + size_t dimension, + lib::UUID uuid = lib::UUID(lib::ZeroInitializer()) + ) + : dimension_{dimension} + , uuid_{uuid} + , stream_{lib::open_write(path, std::ofstream::out | std::ofstream::binary)} { + // Write a temporary header. + stream_.seekp(0, std::ofstream::beg); + lib::write_binary(stream_, Header()); + } + + std::ostream& stream() { return stream_; } + + size_t dimensions() const { return dimension_; } + void flush() { stream_.flush(); } void writeheader(bool resume = true) { @@ -388,7 +400,7 @@ template class Writer { // Write to the header the number of vectors actually written. stream_.seekp(0); assert(stream_.good()); - lib::write_binary(stream_, Header(vectors_written_, dimension_, uuid_)); + lib::write_binary(stream_, Header(this->vectors_written_, dimension_, uuid_)); if (resume) { stream_.seekp(position, std::ofstream::beg); } @@ -402,20 +414,30 @@ template class Writer { // // We delete the copy constructor and copy assignment operators because // `std::ofstream` isn't copyable anyways. - Writer(const Writer&) = delete; - Writer& operator=(const Writer&) = delete; - Writer(Writer&&) = delete; - Writer& operator=(Writer&&) = delete; + FileWriter(const FileWriter&) = delete; + FileWriter& operator=(const FileWriter&) = delete; + FileWriter(FileWriter&&) = delete; + FileWriter& operator=(FileWriter&&) = delete; // Write the header for the file. - ~Writer() noexcept { writeheader(); } + ~FileWriter() noexcept { writeheader(); } private: size_t dimension_; lib::UUID uuid_; std::ofstream stream_; size_t writes_this_vector_ = 0; - size_t vectors_written_ = 0; +}; + +template class StreamWriter : public Writer> { + public: + StreamWriter(std::ostream& os) + : stream_{os} {} + + std::ostream& stream() { return stream_; } + + private: + std::ostream& stream_; }; /// @@ -449,13 +471,13 @@ class NativeFile { } template - Writer writer( + FileWriter writer( lib::Type SVS_UNUSED(type), size_t dimension, lib::UUID uuid = lib::ZeroUUID ) const { - return Writer(path_, dimension, uuid); + return FileWriter(path_, dimension, uuid); } - Writer<> writer(size_t dimensions, lib::UUID uuid = lib::ZeroUUID) const { + FileWriter<> writer(size_t dimensions, lib::UUID uuid = lib::ZeroUUID) const { return writer(lib::Type(), dimensions, uuid); } @@ -715,7 +737,7 @@ class NativeFile { public: using compatible_file_types = lib::Types; - template using Writer = v1::Writer; + template using Writer = v1::FileWriter; explicit NativeFile(std::filesystem::path path) : path_{std::move(path)} {} diff --git a/include/svs/index/flat/flat.h b/include/svs/index/flat/flat.h index 187fc7440..d81e7cf54 100644 --- a/include/svs/index/flat/flat.h +++ b/include/svs/index/flat/flat.h @@ -522,6 +522,8 @@ class FlatIndex { void save(const std::filesystem::path& data_directory) const { lib::save_to_disk(data_, data_directory); } + + void save(std::ostream& os) const { lib::save_to_stream(data_, os); } }; /// diff --git a/include/svs/lib/archiver.h b/include/svs/lib/archiver.h new file mode 100644 index 000000000..ce3ca438f --- /dev/null +++ b/include/svs/lib/archiver.h @@ -0,0 +1,93 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// svs +#include "svs/lib/exception.h" + +// stl +#include +#include +#include +#include +#include + +namespace svs::lib { + +// CRTP +template struct Archiver { + using size_type = uint64_t; + + // TODO: Define CACHELINE_BYTES in a common place + // rather than duplicating it here and in prefetch.h + static constexpr auto CACHELINE_BYTES = 64; + + static size_type write_size(std::ostream& os, size_type size) { + os.write(reinterpret_cast(&size), sizeof(size)); + if (!os) { + throw ANNEXCEPTION("Error writing to stream!"); + } + return sizeof(size); + } + + static size_type read_size(std::istream& is, size_type& size) { + is.read(reinterpret_cast(&size), sizeof(size)); + if (!is) { + throw ANNEXCEPTION("Error reading from stream!"); + } + return sizeof(size); + } + + static size_type write_name(std::ostream& os, const std::string& name) { + auto bytes = write_size(os, name.size()); + os.write(name.data(), name.size()); + if (!os) { + throw ANNEXCEPTION("Error writing to stream!"); + } + return bytes + name.size(); + } + + static size_type read_name(std::istream& is, std::string& name) { + size_type size = 0; + auto bytes = read_size(is, size); + name.resize(size); + is.read(name.data(), size); + if (!is) { + throw ANNEXCEPTION("Error reading from stream!"); + } + return bytes + size; + } + + static void read_from_istream(std::istream& in, std::ostream& out, size_t data_size) { + // Copy the data in chunks. + constexpr size_t buffer_size = 1 << 13; // 8KB buffer + alignas(CACHELINE_BYTES) char buffer[buffer_size]; + + size_t bytes_remaining = data_size; + while (bytes_remaining > 0) { + size_t to_read = std::min(buffer_size, bytes_remaining); + in.read(buffer, to_read); + if (!in) { + throw ANNEXCEPTION("Error reading from stream!"); + } + out.write(buffer, to_read); + bytes_remaining -= to_read; + } + } +}; + +} // namespace svs::lib diff --git a/include/svs/lib/file.h b/include/svs/lib/file.h index 937bc1502..c4df66099 100644 --- a/include/svs/lib/file.h +++ b/include/svs/lib/file.h @@ -17,6 +17,7 @@ #pragma once // svs +#include "svs/lib/archiver.h" #include "svs/lib/exception.h" #include "svs/lib/uuid.h" @@ -151,50 +152,9 @@ struct UniqueTempDirectory { // Uses a simple custom binary format. // Not meant to be super efficient, just a simple way to serialize a directory // structure to a stream. -struct DirectoryArchiver { - using size_type = uint64_t; - - // TODO: Define CACHELINE_BYTES in a common place - // rather than duplicating it here and in prefetch.h - static constexpr auto CACHELINE_BYTES = 64; +struct DirectoryArchiver : Archiver { static constexpr size_type magic_number = 0x5e2d58d9f3b4a6c1; - static size_type write_size(std::ostream& os, size_type size) { - os.write(reinterpret_cast(&size), sizeof(size)); - if (!os) { - throw ANNEXCEPTION("Error writing to stream!"); - } - return sizeof(size); - } - - static size_type read_size(std::istream& is, size_type& size) { - is.read(reinterpret_cast(&size), sizeof(size)); - if (!is) { - throw ANNEXCEPTION("Error reading from stream!"); - } - return sizeof(size); - } - - static size_type write_name(std::ostream& os, const std::string& name) { - auto bytes = write_size(os, name.size()); - os.write(name.data(), name.size()); - if (!os) { - throw ANNEXCEPTION("Error writing to stream!"); - } - return bytes + name.size(); - } - - static size_type read_name(std::istream& is, std::string& name) { - size_type size = 0; - auto bytes = read_size(is, size); - name.resize(size); - is.read(name.data(), size); - if (!is) { - throw ANNEXCEPTION("Error reading from stream!"); - } - return bytes + size; - } - static size_type write_file( std::ostream& stream, const std::filesystem::path& path, @@ -262,22 +222,9 @@ struct DirectoryArchiver { throw ANNEXCEPTION("Error opening file {} for writing!", path); } - // Copy the data in chunks. - constexpr size_t buffer_size = 1 << 13; // 8KB buffer - alignas(CACHELINE_BYTES) char buffer[buffer_size]; - - size_t bytes_remaining = filesize; - while (bytes_remaining > 0) { - size_t to_read = std::min(buffer_size, bytes_remaining); - stream.read(buffer, to_read); - if (!stream) { - throw ANNEXCEPTION("Error reading from stream!"); - } - out.write(buffer, to_read); - if (!out) { - throw ANNEXCEPTION("Error writing to file {}!", path); - } - bytes_remaining -= to_read; + read_from_istream(stream, out, filesize); + if (!out) { + throw ANNEXCEPTION("Error writing to file {}!", path); } return header_bytes + filesize; diff --git a/include/svs/lib/saveload/load.h b/include/svs/lib/saveload/load.h index 767e02afa..c848e80c0 100644 --- a/include/svs/lib/saveload/load.h +++ b/include/svs/lib/saveload/load.h @@ -22,6 +22,11 @@ // stl #include +#include "svs/core/io/native.h" +#include "svs/lib/file.h" +#include "svs/lib/readwrite.h" +#include "svs/lib/stream.h" + namespace svs::lib { /// @@ -828,6 +833,42 @@ inline SerializedObject begin_deserialization(const std::filesystem::path& fullp std::move(table), lib::LoadContext{fullpath.parent_path(), version}}; } +inline ContextFreeSerializedObject begin_deserialization(std::istream& stream) { + lib::StreamArchiver::size_type magic = 0; + lib::StreamArchiver::read_size(stream, magic); + if (magic == lib::DirectoryArchiver::magic_number) { + // Backward compatibility mode for older versions: + // Previously, SVS serialized models using an intermediate file, + // so some dummy information was added to the stream. + lib::StreamArchiver::size_type num_files = 0; + lib::StreamArchiver::read_size(stream, num_files); + + std::string file_name; + lib::StreamArchiver::read_name(stream, file_name); + } else if (magic != lib::StreamArchiver::magic_number) { + throw ANNEXCEPTION("Invalid magic number in stream deserialization!"); + } + + if (!stream) { + throw ANNEXCEPTION("Error reading from stream!"); + } + + auto table = lib::StreamArchiver::read_table(stream); + + if (magic == lib::DirectoryArchiver::magic_number) { + // Backward compatibility mode for older versions: + // Previously, SVS serialized models using an intermediate file, + // so some dummy information was added to the stream. + std::string file_name; + lib::StreamArchiver::read_name(stream, file_name); + + lib::StreamArchiver::size_type file_size = 0; + lib::StreamArchiver::read_size(stream, file_size); + lib::read_binary(stream); + } + return ContextFreeSerializedObject{std::move(table)}; +} + } // namespace detail inline SerializedObject begin_deserialization(const std::filesystem::path& path) { @@ -877,6 +918,21 @@ T load_from_disk(const std::filesystem::path& path, Args&&... args) { return lib::load_from_disk(Loader(), path, SVS_FWD(args)...); } +///// load_from_stream +template +T load_from_stream(const Loader& loader, std::istream& stream, Args&&... args) { + // At this point, we will try the saving/loading framework to load the object. + // Here we go! + return lib::load( + loader, detail::begin_deserialization(stream), stream, SVS_FWD(args)... + ); +} + +template +T load_from_stream(std::istream& stream, Args&&... args) { + return lib::load_from_stream(Loader(), stream, SVS_FWD(args)...); +} + ///// load_from_file template diff --git a/include/svs/lib/saveload/save.h b/include/svs/lib/saveload/save.h index 60f556e77..fe7694c45 100644 --- a/include/svs/lib/saveload/save.h +++ b/include/svs/lib/saveload/save.h @@ -22,6 +22,7 @@ // svs #include "svs/lib/file.h" #include "svs/lib/readwrite.h" +#include "svs/lib/stream.h" #include "svs/lib/version.h" // stl @@ -319,6 +320,17 @@ void save_node_to_file( auto file = svs::lib::open_write(path, std::ios_base::out); file << top_table << "\n"; } + +template +void save_node_to_stream( + Nodelike&& node, std::ostream& os, const lib::Version& version = CURRENT_SAVE_VERSION +) { + auto top_table = toml::table( + {{config_version_key, version.str()}, {config_object_key, SVS_FWD(node)}} + ); + + StreamArchiver::write_table(os, top_table); +} } // namespace detail /// @@ -365,4 +377,13 @@ template void save_to_file(const T& x, const std::filesystem::path& detail::save_node_to_file(lib::save(x), path); } +template void save_to_stream(const T& x, std::ostream& os) { + lib::StreamArchiver::write_size(os, lib::StreamArchiver::magic_number); + + auto save_table = x.save_table(); + detail::save_node_to_stream(detail::exit_hook(save_table), os); + + x.save(os); +} + } // namespace svs::lib diff --git a/include/svs/lib/stream.h b/include/svs/lib/stream.h new file mode 100644 index 000000000..9b9325335 --- /dev/null +++ b/include/svs/lib/stream.h @@ -0,0 +1,73 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// saveload +#include "svs/lib/saveload/core.h" + +// svs +#include "svs/lib/archiver.h" +#include "svs/lib/exception.h" + +// stl +#include +#include + +namespace svs::lib::detail { +template auto get_buffer_size(T& ss) { + if constexpr (requires { ss.rdbuf()->view(); }) { + return ss.rdbuf()->view().size(); + } else { + return ss.str().size(); + } +} +} // namespace svs::lib::detail + +namespace svs::lib { + +struct StreamArchiver : Archiver { + // SVS_STRM + static constexpr size_type magic_number = 0x5356535f5354524d; + + static auto read_table(std::istream& is) { + std::uint64_t tablesize = 0; + read_size(is, tablesize); + + std::stringstream ss; + read_from_istream(is, ss, tablesize); + + return toml::parse(ss); + } + + static void write_table(std::ostream& os, const toml::table& table) { + std::stringstream ss; + ss << table << "\n"; + + // The best way to get the table size is a c++20 feature: + // ss.rdbuf()->view().size(), + // but Apple's Clang 15 doesn't support std::stringbuf::view() + lib::StreamArchiver::size_type tablesize = detail::get_buffer_size(ss); + + lib::StreamArchiver::write_size(os, tablesize); + os << ss.rdbuf(); + if (!os) { + throw ANNEXCEPTION("Error writing to stream!"); + } + } +}; + +} // namespace svs::lib diff --git a/include/svs/orchestrators/exhaustive.h b/include/svs/orchestrators/exhaustive.h index b33b6cc4a..7fa969ead 100644 --- a/include/svs/orchestrators/exhaustive.h +++ b/include/svs/orchestrators/exhaustive.h @@ -24,6 +24,7 @@ #include "svs/core/distance.h" #include "svs/core/graph.h" #include "svs/lib/preprocessor.h" +#include "svs/lib/stream.h" #include "svs/lib/threads.h" #include "svs/orchestrators/manager.h" @@ -87,9 +88,7 @@ class FlatImpl : public manager::ManagerImpl { void save(std::ostream& stream) const override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_flat_save"}; - save(tempdir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current Vamana backend doesn't support saving!"); } @@ -196,11 +195,8 @@ class Flat : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_flat_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); return assemble( - lib::load_from_disk(tempdir, SVS_FWD(data_args)...), + lib::load_from_stream(stream, SVS_FWD(data_args)...), distance, threads::as_threadpool(std::move(threadpool_proto)) ); diff --git a/tests/svs/index/flat/flat.cpp b/tests/svs/index/flat/flat.cpp index 57f821781..d09532d8b 100644 --- a/tests/svs/index/flat/flat.cpp +++ b/tests/svs/index/flat/flat.cpp @@ -16,6 +16,11 @@ #include "svs/index/flat/flat.h" #include "svs/core/logging.h" +#include "svs/lib/file.h" +#include "svs/lib/saveload/load.h" + +// tests +#include "tests/utils/test_dataset.h" // catch2 #include "catch2/catch_test_macros.hpp" @@ -66,3 +71,76 @@ CATCH_TEST_CASE("FlatIndex Logging Test", "[logging]") { CATCH_REQUIRE(captured_logs.size() == 1); CATCH_REQUIRE(captured_logs[0] == "Test FlatIndex Logging"); } + +CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { + using Data_t = svs::data::SimpleData; + using Distance_t = svs::distance::DistanceL2; + using Index_t = svs::index::flat::FlatIndex; + + // Load test data + auto data = Data_t::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + + // Build index + Distance_t dist; + Index_t index = Index_t(std::move(data), dist, svs::threads::DefaultThreadPool(1)); + + size_t num_neighbors = 10; + auto results = svs::QueryResult(queries.size(), num_neighbors); + index.search(results.view(), queries.cview(), {}); + + CATCH_SECTION("Load Flat being serialized natively to stream") { + std::stringstream ss; + index.save(ss); + + Index_t loaded_index = Index_t( + svs::lib::load_from_stream(ss), dist, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + + CATCH_SECTION("Load Flat being serialized with intermediate files") { + std::stringstream ss; + + svs::lib::UniqueTempDirectory tempdir{"svs_flat_save"}; + index.save(tempdir); + svs::lib::DirectoryArchiver::pack(tempdir, ss); + + Index_t loaded_index = Index_t( + svs::lib::load_from_stream(ss), dist, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } +} From 4c0f19b75c6f9c4d1b599bc0fea854999224ff13 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Fri, 6 Mar 2026 12:18:50 +0100 Subject: [PATCH 2/7] Native serialization for DynamicFlat index (#281) This PR introduce native serialization for DynamicFlat index. Main changes are: 1. `auto_dynamic_assemble` now accepts lazy loader. That is mandatory for buffer-free deserialization. 2. new class `Deserializer' is introduced. It is responsible for conditional reading of overhead data (like names of temporary files) in case of legacy models. 3. `IDTranslator` is refactored to cover save and load to/from stream. --------- Co-authored-by: Dmitry Razdoburdin --- include/svs/core/data/simple.h | 21 +++- include/svs/core/translation.h | 66 ++++++++--- include/svs/index/flat/dynamic_flat.h | 84 ++++++++++++++ include/svs/index/flat/flat.h | 5 +- include/svs/lib/saveload/load.h | 108 ++++++++++++----- include/svs/lib/saveload/save.h | 17 ++- include/svs/orchestrators/dynamic_flat.h | 46 +++----- include/svs/orchestrators/exhaustive.h | 3 +- tests/svs/index/flat/dynamic_flat.cpp | 140 +++++++++++++++++++++++ tests/svs/index/flat/flat.cpp | 10 +- 10 files changed, 411 insertions(+), 89 deletions(-) diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index cf3c5df24..890cda372 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -136,8 +136,12 @@ class GenericSerializer { } template F> - static lib::lazy_result_t - load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) { + static lib::lazy_result_t load( + const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, + std::istream& is, + const F& lazy + ) { auto datatype = lib::load_at(table, "eltype"); if (datatype != datatype_v) { throw ANNEXCEPTION( @@ -151,6 +155,10 @@ class GenericSerializer { size_t num_vectors = lib::load_at(table, "num_vectors"); size_t dims = lib::load_at(table, "dims"); + deserializer.read_name(is); + deserializer.read_size(is); + deserializer.read_binary(is); + return io::load_dataset(is, lazy, num_vectors, dims); } }; @@ -474,13 +482,14 @@ class SimpleData { static SimpleData load( const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, std::istream& is, const allocator_type& allocator = {} ) requires(!is_view) { return GenericSerializer::load( - table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { + table, deserializer, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { return SimpleData(n_elements, n_dimensions, allocator); }) ); @@ -879,11 +888,15 @@ class SimpleData> { static SimpleData load( const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, std::istream& is, const Blocked& allocator = {} ) { return GenericSerializer::load( - table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) { + table, + deserializer, + is, + lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) { return SimpleData(n_elements, n_dimensions, allocator); }) ); diff --git a/include/svs/core/translation.h b/include/svs/core/translation.h index a3c4bca34..1b1fde9c6 100644 --- a/include/svs/core/translation.h +++ b/include/svs/core/translation.h @@ -324,27 +324,36 @@ class IDTranslator { "external_to_internal_translation"; static constexpr lib::Version save_version = lib::Version(0, 0, 0); - lib::SaveTable save(const lib::SaveContext& ctx) const { - auto filename = ctx.generate_name("id_translation", "binary"); - // Save the translations to a file. - auto stream = lib::open_write(filename); - for (auto i = begin(), iend = end(); i != iend; ++i) { - // N.B.: Apparently `std::pair` of integers is not trivially copyable ... - lib::write_binary(stream, i->first); - lib::write_binary(stream, i->second); - } + lib::SaveTable save_table() const { return lib::SaveTable( serialization_schema, save_version, {{"kind", kind}, {"num_points", lib::save(size())}, {"external_id_type", lib::save(datatype_v)}, - {"internal_id_type", lib::save(datatype_v)}, - {"filename", lib::save(filename.filename())}} + {"internal_id_type", lib::save(datatype_v)}} ); } - static IDTranslator load(const lib::LoadTable& table) { + void save(std::ostream& os) const { + for (auto i = begin(), iend = end(); i != iend; ++i) { + // N.B.: Apparently `std::pair` of integers is not trivially copyable ... + lib::write_binary(os, i->first); + lib::write_binary(os, i->second); + } + } + + lib::SaveTable save(const lib::SaveContext& ctx) const { + auto filename = ctx.generate_name("id_translation", "binary"); + // Save the translations to a file. + auto os = lib::open_write(filename); + save(os); + auto table = save_table(); + table.insert("filename", lib::save(filename.filename())); + return table; + } + + static void validate(const lib::ContextFreeLoadTable& table) { if (kind != lib::load_at(table, "kind")) { throw ANNEXCEPTION("Mismatched kind!"); } @@ -357,21 +366,42 @@ class IDTranslator { if (internal_id_name != lib::load_at(table, "internal_id_type")) { throw ANNEXCEPTION("Mismatched internal id types!"); } + } - // Now that we've more-or-less validated the metadata, time to start loading - // the points. + static IDTranslator load(const lib::ContextFreeLoadTable& table, std::istream& is) { auto num_points = lib::load_at(table, "num_points"); + auto translator = IDTranslator{}; - auto resolved = table.resolve_at("filename"); - auto stream = lib::open_read(resolved); for (size_t i = 0; i < num_points; ++i) { - auto external_id = lib::read_binary(stream); - auto internal_id = lib::read_binary(stream); + auto external_id = lib::read_binary(is); + auto internal_id = lib::read_binary(is); translator.insert_translation(external_id, internal_id); } return translator; } + static IDTranslator load( + const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, + std::istream& is + ) { + IDTranslator::validate(table); + deserializer.read_name(is); + deserializer.read_size(is); + + return IDTranslator::load(table, is); + } + + static IDTranslator load(const lib::LoadTable& table) { + IDTranslator::validate(table); + + // Now that we've more-or-less validated the metadata, time to start loading + // the points. + auto resolved = table.resolve_at("filename"); + auto is = lib::open_read(resolved); + return IDTranslator::load(table, is); + } + private: template void check( diff --git a/include/svs/index/flat/dynamic_flat.h b/include/svs/index/flat/dynamic_flat.h index 868054ba1..26946a220 100644 --- a/include/svs/index/flat/dynamic_flat.h +++ b/include/svs/index/flat/dynamic_flat.h @@ -36,6 +36,7 @@ #include "svs/lib/invoke.h" #include "svs/lib/misc.h" #include "svs/lib/preprocessor.h" +#include "svs/lib/stream.h" #include "svs/lib/threads.h" namespace svs::index::flat { @@ -403,6 +404,26 @@ template class DynamicFlatIndex { // Save the dataset in the separate data directory lib::save_to_disk(data_, data_directory); } + + void save(std::ostream& os) { + compact(); + + lib::begin_serialization(os); + // Save data structures and translation to config directory + lib::SaveTable save_table = lib::SaveTable( + "dynamic_flat_config", + save_version, + { + {"name", name()}, + {"translation", lib::detail::exit_hook(translator_.save_table())}, + } + ); + lib::save_to_stream(save_table, os); + translator_.save(os); + + lib::save_to_stream(data_, os); + } + constexpr std::string_view name() const { return "dynamic flat index"; } ///// Thread Pool Management @@ -767,4 +788,67 @@ auto auto_dynamic_assemble( ); } +auto load_translator(const lib::detail::Deserializer& deserializer, std::istream& is) { + auto table = lib::detail::begin_deserialization(deserializer, is); + auto translator = IDTranslator::load( + table.template cast().at("translation").template cast(), + deserializer, + is + ); + return translator; +} + +template +auto auto_dynamic_assemble( + const lib::detail::Deserializer& deserializer, + std::istream& is, + LazyDataLoader&& data_loader, + Distance distance, + ThreadPoolProto threadpool_proto, + // Set this to `true` to use the identity map for ID translation. + // This allows us to read files generated by the static index construction routines + // to easily benchmark the static versus dynamic implementation. + // + // This is an internal API and should not be considered officially supported nor stable. + bool SVS_UNUSED(debug_load_from_static) = false, + svs::logging::logger_ptr logger = svs::logging::get() +) { + IDTranslator translator; + // In legacy deserialization the order of directories isn't determined. + auto name = deserializer.read_name_in_advance(is); + + // We have to hardcode the file_name for legacy mode, since it was hardcoded when legacy + // model was serialized + bool translator_before_data = + (name == "config/svs_config.toml") || deserializer.is_native(); + if (translator_before_data) { + translator = load_translator(deserializer, is); + } + + // Load the dataset + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + auto data = svs::detail::dispatch_load(data_loader(), threadpool); + auto datasize = data.size(); + + if (!translator_before_data) { + translator = load_translator(deserializer, is); + } + + // Validate the translator + auto translator_size = translator.size(); + if (translator_size != datasize) { + throw ANNEXCEPTION( + "Translator has {} IDs but should have {}", translator_size, datasize + ); + } + + return DynamicFlatIndex( + std::move(data), + std::move(translator), + std::move(distance), + std::move(threadpool), + std::move(logger) + ); +} + } // namespace svs::index::flat diff --git a/include/svs/index/flat/flat.h b/include/svs/index/flat/flat.h index d81e7cf54..925fe68a5 100644 --- a/include/svs/index/flat/flat.h +++ b/include/svs/index/flat/flat.h @@ -523,7 +523,10 @@ class FlatIndex { lib::save_to_disk(data_, data_directory); } - void save(std::ostream& os) const { lib::save_to_stream(data_, os); } + void save(std::ostream& os) const { + lib::begin_serialization(os); + lib::save_to_stream(data_, os); + } }; /// diff --git a/include/svs/lib/saveload/load.h b/include/svs/lib/saveload/load.h index c848e80c0..09e072dee 100644 --- a/include/svs/lib/saveload/load.h +++ b/include/svs/lib/saveload/load.h @@ -833,39 +833,78 @@ inline SerializedObject begin_deserialization(const std::filesystem::path& fullp std::move(table), lib::LoadContext{fullpath.parent_path(), version}}; } -inline ContextFreeSerializedObject begin_deserialization(std::istream& stream) { - lib::StreamArchiver::size_type magic = 0; - lib::StreamArchiver::read_size(stream, magic); - if (magic == lib::DirectoryArchiver::magic_number) { - // Backward compatibility mode for older versions: - // Previously, SVS serialized models using an intermediate file, - // so some dummy information was added to the stream. - lib::StreamArchiver::size_type num_files = 0; - lib::StreamArchiver::read_size(stream, num_files); - - std::string file_name; - lib::StreamArchiver::read_name(stream, file_name); - } else if (magic != lib::StreamArchiver::magic_number) { - throw ANNEXCEPTION("Invalid magic number in stream deserialization!"); +class Deserializer { + enum SerializationScheme { native, legacy }; + SerializationScheme scheme_; + + mutable bool skip_next_name_ = false; + + explicit Deserializer(const SerializationScheme& scheme) + : scheme_(scheme) {} + + public: + static Deserializer build(std::istream& stream) { + lib::StreamArchiver::size_type magic = 0; + lib::StreamArchiver::read_size(stream, magic); + if (magic == lib::StreamArchiver::magic_number) { + return Deserializer(SerializationScheme::native); + } else if (magic == lib::DirectoryArchiver::magic_number) { + // Backward compatibility mode for older versions: + // Previously, SVS serialized models using an intermediate file, + // so some dummy information was added to the stream. + lib::StreamArchiver::size_type num_files = 0; + lib::StreamArchiver::read_size(stream, num_files); + + return Deserializer(SerializationScheme::legacy); + } else { + throw ANNEXCEPTION("Invalid magic number in stream deserialization!"); + } + } + + bool is_native() const { return scheme_ == SerializationScheme::native; } + + std::string read_name_in_advance(std::istream& stream) const { + std::string name; + if (scheme_ == SerializationScheme::legacy) { + lib::StreamArchiver::read_name(stream, name); + skip_next_name_ = true; + } + return name; + } + + void read_name(std::istream& stream) const { + if (scheme_ == SerializationScheme::legacy) { + if (!skip_next_name_) { + std::string name; + lib::StreamArchiver::read_name(stream, name); + } + skip_next_name_ = false; + } } + void read_size(std::istream& stream) const { + if (scheme_ == SerializationScheme::legacy) { + lib::StreamArchiver::size_type size = 0; + lib::StreamArchiver::read_size(stream, size); + } + } + + template void read_binary(std::istream& stream) const { + if (scheme_ == SerializationScheme::legacy) { + lib::read_binary(stream); + } + } +}; + +inline ContextFreeSerializedObject +begin_deserialization(const Deserializer& deserializer, std::istream& stream) { + deserializer.read_name(stream); if (!stream) { throw ANNEXCEPTION("Error reading from stream!"); } auto table = lib::StreamArchiver::read_table(stream); - if (magic == lib::DirectoryArchiver::magic_number) { - // Backward compatibility mode for older versions: - // Previously, SVS serialized models using an intermediate file, - // so some dummy information was added to the stream. - std::string file_name; - lib::StreamArchiver::read_name(stream, file_name); - - lib::StreamArchiver::size_type file_size = 0; - lib::StreamArchiver::read_size(stream, file_size); - lib::read_binary(stream); - } return ContextFreeSerializedObject{std::move(table)}; } @@ -920,17 +959,28 @@ T load_from_disk(const std::filesystem::path& path, Args&&... args) { ///// load_from_stream template -T load_from_stream(const Loader& loader, std::istream& stream, Args&&... args) { +T load_from_stream( + const Loader& loader, + const detail::Deserializer& deserializer, + std::istream& stream, + Args&&... args +) { // At this point, we will try the saving/loading framework to load the object. // Here we go! return lib::load( - loader, detail::begin_deserialization(stream), stream, SVS_FWD(args)... + loader, + detail::begin_deserialization(deserializer, stream), + deserializer, + stream, + SVS_FWD(args)... ); } template -T load_from_stream(std::istream& stream, Args&&... args) { - return lib::load_from_stream(Loader(), stream, SVS_FWD(args)...); +T load_from_stream( + const detail::Deserializer& deserializer, std::istream& stream, Args&&... args +) { + return lib::load_from_stream(Loader(), deserializer, stream, SVS_FWD(args)...); } ///// load_from_file diff --git a/include/svs/lib/saveload/save.h b/include/svs/lib/saveload/save.h index fe7694c45..c609151e1 100644 --- a/include/svs/lib/saveload/save.h +++ b/include/svs/lib/saveload/save.h @@ -377,13 +377,20 @@ template void save_to_file(const T& x, const std::filesystem::path& detail::save_node_to_file(lib::save(x), path); } -template void save_to_stream(const T& x, std::ostream& os) { +inline void begin_serialization(std::ostream& os) { lib::StreamArchiver::write_size(os, lib::StreamArchiver::magic_number); +} - auto save_table = x.save_table(); - detail::save_node_to_stream(detail::exit_hook(save_table), os); - - x.save(os); +template void save_to_stream(const T& x, std::ostream& os) { + if constexpr (requires { x.save_table(); }) { + auto save_table = x.save_table(); + detail::save_node_to_stream(detail::exit_hook(save_table), os); + x.save(os); + } else if constexpr (std::is_same_v) { + detail::save_node_to_stream(detail::exit_hook(x), os); + } else { + static_assert(sizeof(T) == 0, "Type not stream-serializable"); + } } } // namespace svs::lib diff --git a/include/svs/orchestrators/dynamic_flat.h b/include/svs/orchestrators/dynamic_flat.h index e06efb451..250edf10c 100644 --- a/include/svs/orchestrators/dynamic_flat.h +++ b/include/svs/orchestrators/dynamic_flat.h @@ -123,13 +123,7 @@ class DynamicFlatImpl // Stream-based save implementation void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_dynflat_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current DynamicFlat backend doesn't support saving!"); } @@ -282,28 +276,22 @@ class DynamicFlat : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_dynflat_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); - - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION( - "Invalid Dynamic Flat index archive: missing config directory!" - ); - } - - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid Dynamic Flat index archive: missing data directory!" - ); - } - - return assemble( - config_path, - lib::load_from_disk(data_path, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)) + auto deserializer = svs::lib::detail::Deserializer::build(stream); + return DynamicFlat( + AssembleTag(), + manager::as_typelist(), + index::flat::auto_dynamic_assemble( + deserializer, + stream, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream( + deserializer, stream, SVS_FWD(data_args)... + ); + }, + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ) ); } diff --git a/include/svs/orchestrators/exhaustive.h b/include/svs/orchestrators/exhaustive.h index 7fa969ead..bf46c1347 100644 --- a/include/svs/orchestrators/exhaustive.h +++ b/include/svs/orchestrators/exhaustive.h @@ -195,8 +195,9 @@ class Flat : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { + auto deserializer = svs::lib::detail::Deserializer::build(stream); return assemble( - lib::load_from_stream(stream, SVS_FWD(data_args)...), + lib::load_from_stream(deserializer, stream, SVS_FWD(data_args)...), distance, threads::as_threadpool(std::move(threadpool_proto)) ); diff --git a/tests/svs/index/flat/dynamic_flat.cpp b/tests/svs/index/flat/dynamic_flat.cpp index f9f99e702..76c4d4c65 100644 --- a/tests/svs/index/flat/dynamic_flat.cpp +++ b/tests/svs/index/flat/dynamic_flat.cpp @@ -214,3 +214,143 @@ CATCH_TEST_CASE("Testing Flat Index", "[dynamic_flat]") { test_loop(index, reference, queries, div(reference.size(), modify_fraction), 2, 6); } + +CATCH_TEST_CASE("DynamicFlat Index Save and Load", "[dynamic_flat][index][saveload]") { +#if defined(NDEBUG) + const float initial_fraction = 0.25; + const float modify_fraction = 0.05; +#else + const float initial_fraction = 0.05; + const float modify_fraction = 0.005; +#endif + const size_t num_threads = 10; + + // Load the base dataset and queries. + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto num_points = data.size(); + auto queries = test_dataset::queries(); + + auto reference = svs::misc::ReferenceDataset( + std::move(data), + Distance(), + num_threads, + div(num_points, 0.5 * modify_fraction), + NUM_NEIGHBORS, + queries, + 0x12345678 + ); + + auto num_indices_to_add = div(reference.size(), initial_fraction); + + // Construct a blocked dataset consisting of initial fraction of the base dataset. + auto data_mutable = svs::data::BlockedData(num_indices_to_add, N); + std::vector initial_indices{}; + { + auto [vectors, indices] = reference.generate(num_indices_to_add); + // Copy assign ``initial_indices`` + auto num_points_added = indices.size(); + CATCH_REQUIRE(vectors.size() == num_points_added); + CATCH_REQUIRE(num_points_added <= num_indices_to_add); + CATCH_REQUIRE(num_points_added > num_indices_to_add - reference.bucket_size()); + + initial_indices = indices; + if (vectors.size() != num_indices_to_add || indices.size() != num_indices_to_add) { + throw ANNEXCEPTION("Something when horribly wrong!"); + } + + for (size_t i = 0; i < num_indices_to_add; ++i) { + data_mutable.set_datum(i, vectors.get_datum(i)); + } + } + + using Data_t = svs::data::BlockedData; + using Distance_t = svs::distance::DistanceL2; + using Index_t = svs::index::flat::DynamicFlatIndex; + + Distance_t dist; + auto index = Index_t(std::move(data_mutable), initial_indices, dist, num_threads); + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search(results.view(), queries.cview(), {}); + + reference.configure_extra_checks(true); + CATCH_REQUIRE(reference.extra_checks_enabled()); + + CATCH_SECTION("Load DynamicFlat being serialized natively to stream") { + std::stringstream stream; + index.save(stream); + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + Index_t loaded_index = svs::index::flat::auto_dynamic_assemble( + deserializer, + stream, + // lazy-loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + dist, + svs::threads::as_threadpool(num_threads) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load DynamicFlat being serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_dynflat_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + Index_t loaded_index = svs::index::flat::auto_dynamic_assemble( + deserializer, + stream, + // lazy-loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + dist, + svs::threads::as_threadpool(num_threads) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } +} diff --git a/tests/svs/index/flat/flat.cpp b/tests/svs/index/flat/flat.cpp index d09532d8b..29879b664 100644 --- a/tests/svs/index/flat/flat.cpp +++ b/tests/svs/index/flat/flat.cpp @@ -93,8 +93,11 @@ CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { std::stringstream ss; index.save(ss); + auto deserializer = svs::lib::detail::Deserializer::build(ss); Index_t loaded_index = Index_t( - svs::lib::load_from_stream(ss), dist, svs::threads::DefaultThreadPool(1) + svs::lib::load_from_stream(deserializer, ss), + dist, + svs::threads::DefaultThreadPool(1) ); CATCH_REQUIRE(loaded_index.size() == index.size()); @@ -122,8 +125,11 @@ CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { index.save(tempdir); svs::lib::DirectoryArchiver::pack(tempdir, ss); + auto deserializer = svs::lib::detail::Deserializer::build(ss); Index_t loaded_index = Index_t( - svs::lib::load_from_stream(ss), dist, svs::threads::DefaultThreadPool(1) + svs::lib::load_from_stream(deserializer, ss), + dist, + svs::threads::DefaultThreadPool(1) ); CATCH_REQUIRE(loaded_index.size() == index.size()); From 620ac9f93d01e3dfdf6796b262994b97248e87b9 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Fri, 13 Mar 2026 15:44:55 +0100 Subject: [PATCH 3/7] Native serialization for Vamana index (#285) This PR introduce native serialization for `Vamana` index. Main changes are: 1. New overload of `svs::index::vamana::auto_assemble` required for direct deserialization accepts lazy loaders and call them in a flexible order to cover legacy serilized models. 2. `save_table` method is renamed to `metadata` to avoid confusions, as far as it doesn't save anything. 3. Some minor refactoring to streamline the logic. 4. Serialization for `MutableVamana` is also implemented to avoid compilation errors. Deserialization and related tests for `MutableVamana` are expected later. --------- Co-authored-by: Dmitry Razdoburdin Co-authored-by: Rafik Saliev --- include/svs/core/data/simple.h | 12 +-- include/svs/core/graph/graph.h | 79 +++++++++++++-- include/svs/core/translation.h | 15 +-- include/svs/index/flat/dynamic_flat.h | 60 ++++++----- include/svs/index/vamana/dynamic_index.h | 49 ++++++--- include/svs/index/vamana/index.h | 80 ++++++++++++++- include/svs/lib/saveload/load.h | 31 +++++- include/svs/lib/saveload/save.h | 17 ++-- include/svs/orchestrators/vamana.h | 73 +++++++------- include/svs/quantization/scalar/scalar.h | 25 +++++ tests/svs/index/vamana/index.cpp | 121 +++++++++++++++++++++++ 11 files changed, 451 insertions(+), 111 deletions(-) diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index 890cda372..8f15ff717 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -75,7 +75,7 @@ class GenericSerializer { } template - static lib::SaveTable save_table(const Data& data) { + static lib::SaveTable metadata(const Data& data) { using T = typename Data::element_type; auto table = lib::SaveTable( serialization_schema, @@ -92,8 +92,8 @@ class GenericSerializer { template static lib::SaveTable - save_table(const Data& data, const FileName_t& filename, const lib::UUID& uuid) { - auto table = save_table(data); + metadata(const Data& data, const FileName_t& filename, const lib::UUID& uuid) { + auto table = metadata(data); table.insert("binary_file", filename); table.insert("uuid", uuid.str()); return table; @@ -105,7 +105,7 @@ class GenericSerializer { auto uuid = lib::UUID{}; auto filename = ctx.generate_name("data"); io::save(data, io::NativeFile(filename), uuid); - return save_table(data, lib::save(filename.filename()), uuid); + return metadata(data, lib::save(filename.filename()), uuid); } template @@ -452,7 +452,7 @@ class SimpleData { void save(std::ostream& os) const { return GenericSerializer::save(*this, os); } - lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); } + lib::SaveTable metadata() const { return GenericSerializer::metadata(*this); } static bool check_load_compatibility(std::string_view schema, lib::Version version) { return GenericSerializer::check_compatibility(schema, version); @@ -871,7 +871,7 @@ class SimpleData> { void save(std::ostream& os) const { return GenericSerializer::save(*this, os); } - lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); } + lib::SaveTable metadata() const { return GenericSerializer::metadata(*this); } static bool check_load_compatibility(std::string_view schema, lib::Version version) { return GenericSerializer::check_compatibility(schema, version); diff --git a/include/svs/core/graph/graph.h b/include/svs/core/graph/graph.h index 48d3e8048..68357d157 100644 --- a/include/svs/core/graph/graph.h +++ b/include/svs/core/graph/graph.h @@ -276,22 +276,36 @@ template class SimpleGrap ///// Saving static constexpr lib::Version save_version = lib::Version(0, 0, 0); static constexpr std::string_view serialization_schema = "default_graph"; - lib::SaveTable save(const lib::SaveContext& ctx) const { - auto uuid = lib::UUID{}; - auto filename = ctx.generate_name("graph"); - io::save(data_, io::NativeFile(filename), uuid); - return lib::SaveTable( + + lib::SaveTable metadata() const { + auto table = lib::SaveTable( serialization_schema, save_version, {{"name", "graph"}, - {"binary_file", lib::save(filename.filename())}, {"max_degree", lib::save(max_degree())}, {"num_vertices", lib::save(n_nodes())}, - {"uuid", lib::save(uuid.str())}, {"eltype", lib::save(datatype_v)}} ); + return table; + } + + template + lib::SaveTable metadata(const FileName& filename, const lib::UUID& uuid) const { + auto table = metadata(); + table.insert("binary_file", filename); + table.insert("uuid", uuid.str()); + return table; + } + + lib::SaveTable save(const lib::SaveContext& ctx) const { + auto uuid = lib::UUID{}; + auto filename = ctx.generate_name("graph"); + io::save(data_, io::NativeFile(filename), uuid); + return metadata(lib::save(filename.filename()), uuid); } + void save(std::ostream& os) const { io::save(data_, os); } + protected: template F, typename... Args> static lib::lazy_result_t @@ -317,6 +331,39 @@ template class SimpleGrap return lazy(data_type::load(binaryfile.value(), std::forward(args)...)); } + template F, typename... AllocArgs> + static lib::lazy_result_t load( + const lib::ContextFreeLoadTable& table, + const F& lazy, + const lib::detail::Deserializer& deserializer, + std::istream& is, + AllocArgs&&... alloc_args + ) { + // Perform a sanity check on the element type. + // Make sure we're loading the correct kind. + auto eltype = lib::load_at(table, "eltype"); + if (eltype != datatype_v) { + throw ANNEXCEPTION( + "Trying to load a graph with adjacency list types {} to a graph with " + "adjacency list types {}.", + name(eltype), + name>() + ); + } + + size_t num_vertices = lib::load_at(table, "num_vertices"); + size_t max_degree = lib::load_at(table, "max_degree"); + + // Skip legacy stream metadata (no-op for native serialization). + deserializer.read_name(is); + deserializer.read_size(is); + deserializer.read_binary(is); + + auto data = data_type(num_vertices, max_degree + 1, alloc_args...); + io::populate(is, data); + return lazy(std::move(data)); + } + protected: data_type data_; Idx max_degree_; @@ -366,6 +413,16 @@ class SimpleGraph : public SimpleGraphBase(deserializer, is, allocator); + } }; template diff --git a/include/svs/core/translation.h b/include/svs/core/translation.h index 1b1fde9c6..ae2dc4880 100644 --- a/include/svs/core/translation.h +++ b/include/svs/core/translation.h @@ -380,16 +380,17 @@ class IDTranslator { return translator; } - static IDTranslator load( - const lib::ContextFreeLoadTable& table, - const lib::detail::Deserializer& deserializer, - std::istream& is - ) { - IDTranslator::validate(table); + static IDTranslator + load(const lib::detail::Deserializer& deserializer, std::istream& is) { + auto table = lib::detail::read_metadata(deserializer, is); + auto translation = table.template cast() + .at("translation") + .template cast(); + IDTranslator::validate(translation); deserializer.read_name(is); deserializer.read_size(is); - return IDTranslator::load(table, is); + return IDTranslator::load(translation, is); } static IDTranslator load(const lib::LoadTable& table) { diff --git a/include/svs/index/flat/dynamic_flat.h b/include/svs/index/flat/dynamic_flat.h index 26946a220..849b80b7f 100644 --- a/include/svs/index/flat/dynamic_flat.h +++ b/include/svs/index/flat/dynamic_flat.h @@ -788,16 +788,6 @@ auto auto_dynamic_assemble( ); } -auto load_translator(const lib::detail::Deserializer& deserializer, std::istream& is) { - auto table = lib::detail::begin_deserialization(deserializer, is); - auto translator = IDTranslator::load( - table.template cast().at("translation").template cast(), - deserializer, - is - ); - return translator; -} - template auto auto_dynamic_assemble( const lib::detail::Deserializer& deserializer, @@ -813,38 +803,44 @@ auto auto_dynamic_assemble( bool SVS_UNUSED(debug_load_from_static) = false, svs::logging::logger_ptr logger = svs::logging::get() ) { - IDTranslator translator; - // In legacy deserialization the order of directories isn't determined. - auto name = deserializer.read_name_in_advance(is); - - // We have to hardcode the file_name for legacy mode, since it was hardcoded when legacy - // model was serialized - bool translator_before_data = - (name == "config/svs_config.toml") || deserializer.is_native(); - if (translator_before_data) { - translator = load_translator(deserializer, is); - } - - // Load the dataset - auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); - auto data = svs::detail::dispatch_load(data_loader(), threadpool); - auto datasize = data.size(); - - if (!translator_before_data) { - translator = load_translator(deserializer, is); + using Data = decltype(data_loader()); + auto config_loader = [&] { return IDTranslator::load(deserializer, is); }; + + std::optional config; + std::optional data; + + if (deserializer.is_native()) { + // Order is always config->data. + config.emplace(config_loader()); + data.emplace(data_loader()); + } else { + // Directory packing order is filesystem-dependent. + // Read 2 data blocks: config and data in a corresponding order. + for (int data_block_idx = 0; data_block_idx < 2; ++data_block_idx) { + auto name = deserializer.read_name_in_advance(is); + if (name.starts_with("config/")) { + config.emplace(config_loader()); + } else if (name.starts_with("data/")) { + data.emplace(data_loader()); + } else { + throw ANNEXCEPTION("The stream is corrupted!"); + } + } } // Validate the translator - auto translator_size = translator.size(); + auto translator_size = config->size(); + auto datasize = data->size(); if (translator_size != datasize) { throw ANNEXCEPTION( "Translator has {} IDs but should have {}", translator_size, datasize ); } + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); return DynamicFlatIndex( - std::move(data), - std::move(translator), + std::move(*data), + std::move(*config), std::move(distance), std::move(threadpool), std::move(logger) diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 169be1995..3a1778fe1 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -988,6 +988,18 @@ class MutableVamanaIndex { ///// Saving + VamanaIndexParameters parameters() const { + return { + entry_point_.front(), + {alpha_, + graph_.max_degree(), + get_construction_window_size(), + get_max_candidates(), + prune_to_, + get_full_search_history()}, + get_search_parameters()}; + } + static constexpr lib::Version save_version = lib::Version(0, 0, 0); void save( const std::filesystem::path& config_directory, @@ -1003,22 +1015,12 @@ class MutableVamanaIndex { lib::save_to_disk( lib::SaveOverride([&](const lib::SaveContext& ctx) { // Save the construction parameters. - auto parameters = VamanaIndexParameters{ - entry_point_.front(), - {alpha_, - graph_.max_degree(), - get_construction_window_size(), - get_max_candidates(), - prune_to_, - get_full_search_history()}, - get_search_parameters()}; - return lib::SaveTable( "vamana_dynamic_auxiliary_parameters", save_version, { {"name", lib::save(name())}, - {"parameters", lib::save(parameters, ctx)}, + {"parameters", lib::save(parameters(), ctx)}, {"translation", lib::save(translator_, ctx)}, } ); @@ -1032,6 +1034,31 @@ class MutableVamanaIndex { lib::save_to_disk(graph_, graph_directory); } + void save(std::ostream& os) { + // Post-consolidation, all entries should be "valid". + // Therefore, we don't need to save the slot metadata. + consolidate(); + compact(); + + lib::begin_serialization(os); + auto save_table = lib::SaveTable( + "vamana_dynamic_auxiliary_parameters", + save_version, + { + {"name", lib::save(name())}, + {"parameters", lib::save(parameters())}, + {"translation", lib::detail::exit_hook(translator_.save_table())}, + } + ); + lib::save_to_stream(save_table, os); + translator_.save(os); + + // Save the dataset. + lib::save_to_stream(data_, os); + // Save the graph. + lib::save_to_stream(graph_, os); + } + ///// ///// Calibrate ///// diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index b7c136645..02a93c615 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -85,8 +85,7 @@ struct VamanaIndexParameters { static constexpr lib::Version save_version = lib::Version(0, 0, 3); static constexpr std::string_view serialization_schema = "vamana_index_parameters"; - // Save and Reload. - lib::SaveTable save() const { + lib::SaveTable metadata() const { return lib::SaveTable( serialization_schema, save_version, @@ -97,6 +96,9 @@ struct VamanaIndexParameters { ); } + // Save to Table and Reload. + lib::SaveTable save() const { return metadata(); } + static bool check_load_compatibility(std::string_view schema, lib::Version version) { return schema == serialization_schema && version <= save_version; } @@ -814,6 +816,20 @@ class VamanaIndex { lib::save_to_disk(graph_, graph_directory); } + void save(std::ostream& os) const { + // Construct and save runtime parameters. + auto parameters = VamanaIndexParameters{ + entry_point_.front(), build_parameters_, get_search_parameters()}; + + lib::begin_serialization(os); + // Config + lib::save_to_stream(parameters, os); + // Data + lib::save_to_stream(data_, os); + // // Graph + lib::save_to_stream(graph_, os); + } + ///// Calibration // Return the maximum degree of the graph. @@ -1006,6 +1022,66 @@ auto auto_assemble( return index; } +template < + typename LazyGraphLoader, + typename LazyDataLoader, + typename Distance, + typename ThreadPoolProto> +auto auto_assemble( + const lib::detail::Deserializer& deserializer, + std::istream& is, + LazyGraphLoader graph_loader, + LazyDataLoader data_loader, + Distance distance, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() +) { + using Data = decltype(data_loader()); + using Graph = decltype(graph_loader()); + auto config_loader = [&] { + return lib::load_from_stream(deserializer, is); + }; + + std::optional config; + std::optional data; + std::optional graph; + + if (deserializer.is_native()) { + // Order is always config->data->graph. + config.emplace(config_loader()); + data.emplace(data_loader()); + graph.emplace(graph_loader()); + } else { + // Directory packing order is filesystem-dependent. + // Read 3 data blocks: config, data and graph in a corresponding order + for (int data_block_idx = 0; data_block_idx < 3; ++data_block_idx) { + auto name = deserializer.read_name_in_advance(is); + if (name.starts_with("config/")) { + config.emplace(config_loader()); + } else if (name.starts_with("data/")) { + data.emplace(data_loader()); + } else if (name.starts_with("graph/")) { + graph.emplace(graph_loader()); + } else { + throw ANNEXCEPTION("The stream is corrupted!"); + } + } + } + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + // Extract the index type of the provided graph. + using I = typename Graph::index_type; + auto index = VamanaIndex{ + std::move(*graph), + std::move(*data), + I{}, + std::move(distance), + std::move(threadpool), + std::move(logger)}; + index.apply(*config); + return index; +} + /// @brief Verify parameters and set defaults if needed template void verify_and_set_default_index_parameters( diff --git a/include/svs/lib/saveload/load.h b/include/svs/lib/saveload/load.h index 09e072dee..c3e5591b8 100644 --- a/include/svs/lib/saveload/load.h +++ b/include/svs/lib/saveload/load.h @@ -883,6 +883,10 @@ class Deserializer { } void read_size(std::istream& stream) const { + if (skip_next_name_) { + throw ANNEXCEPTION("Error in deserialization: read_size() shouldn't follow " + "read_name_in_advance()!"); + } if (scheme_ == SerializationScheme::legacy) { lib::StreamArchiver::size_type size = 0; lib::StreamArchiver::read_size(stream, size); @@ -897,7 +901,7 @@ class Deserializer { }; inline ContextFreeSerializedObject -begin_deserialization(const Deserializer& deserializer, std::istream& stream) { +read_metadata(const Deserializer& deserializer, std::istream& stream) { deserializer.read_name(stream); if (!stream) { throw ANNEXCEPTION("Error reading from stream!"); @@ -957,6 +961,10 @@ T load_from_disk(const std::filesystem::path& path, Args&&... args) { return lib::load_from_disk(Loader(), path, SVS_FWD(args)...); } +template +concept LoadableFromTable = requires(const T& x +) { T::load(std::declval(), std::declval()...); }; + ///// load_from_stream template T load_from_stream( @@ -964,12 +972,25 @@ T load_from_stream( const detail::Deserializer& deserializer, std::istream& stream, Args&&... args -) { - // At this point, we will try the saving/loading framework to load the object. - // Here we go! +) + requires LoadableFromTable +{ + // Object is loadable from it's toml::table + return lib::load(loader, detail::read_metadata(deserializer, stream), SVS_FWD(args)...); +} + +template +T load_from_stream( + const Loader& loader, + const detail::Deserializer& deserializer, + std::istream& stream, + Args&&... args +) + requires(!LoadableFromTable) +{ return lib::load( loader, - detail::begin_deserialization(deserializer, stream), + detail::read_metadata(deserializer, stream), deserializer, stream, SVS_FWD(args)... diff --git a/include/svs/lib/saveload/save.h b/include/svs/lib/saveload/save.h index c609151e1..edd627b75 100644 --- a/include/svs/lib/saveload/save.h +++ b/include/svs/lib/saveload/save.h @@ -326,7 +326,7 @@ void save_node_to_stream( Nodelike&& node, std::ostream& os, const lib::Version& version = CURRENT_SAVE_VERSION ) { auto top_table = toml::table( - {{config_version_key, version.str()}, {config_object_key, SVS_FWD(node)}} + {{config_version_key, version.str()}, {config_object_key, SVS_FWD(exit_hook(node))}} ); StreamArchiver::write_table(os, top_table); @@ -381,13 +381,16 @@ inline void begin_serialization(std::ostream& os) { lib::StreamArchiver::write_size(os, lib::StreamArchiver::magic_number); } +inline void save_to_stream(const lib::SaveTable& x, std::ostream& os) { + detail::save_node_to_stream(x, os); +} + template void save_to_stream(const T& x, std::ostream& os) { - if constexpr (requires { x.save_table(); }) { - auto save_table = x.save_table(); - detail::save_node_to_stream(detail::exit_hook(save_table), os); - x.save(os); - } else if constexpr (std::is_same_v) { - detail::save_node_to_stream(detail::exit_hook(x), os); + if constexpr (requires { x.metadata(); }) { + save_to_stream(x.metadata(), os); + if constexpr (requires { x.save(os); }) { + x.save(os); + } } else { static_assert(sizeof(T) == 0, "Type not stream-serializable"); } diff --git a/include/svs/orchestrators/vamana.h b/include/svs/orchestrators/vamana.h index 6b698c4f9..65905cb8c 100644 --- a/include/svs/orchestrators/vamana.h +++ b/include/svs/orchestrators/vamana.h @@ -177,15 +177,7 @@ class VamanaImpl : public manager::ManagerImpl { void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_vamana_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto graph_dir = tempdir.get() / "graph"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(graph_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, graph_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current Vamana backend doesn't support saving!"); } @@ -474,32 +466,45 @@ class Vamana : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); - - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!"); - } - - const auto graph_path = tempdir.get() / "graph"; - if (!fs::is_directory(graph_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!"); - } - - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + using GraphType = svs::GraphLoader<>::return_type; + if constexpr (std::is_same_v) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return make_vamana>( + AssembleTag(), + deserializer, + stream, + // lazy-loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream( + deserializer, stream, SVS_FWD(data_args)... + ); + }, + distance_function, + std::move(threadpool) + ); + }); + } else { + return make_vamana>( + AssembleTag(), + deserializer, + stream, + // lazy-loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream( + deserializer, stream, SVS_FWD(data_args)... + ); + }, + distance, + std::move(threadpool) + ); } - - return assemble( - config_path, - svs::GraphLoader{graph_path}, - lib::load_from_disk(data_path, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)) - ); } /// diff --git a/include/svs/quantization/scalar/scalar.h b/include/svs/quantization/scalar/scalar.h index a2244d1fd..992d533cf 100644 --- a/include/svs/quantization/scalar/scalar.h +++ b/include/svs/quantization/scalar/scalar.h @@ -497,6 +497,18 @@ class SQDataset { ); } + lib::SaveTable metadata() const { + return lib::SaveTable( + serialization_schema, + save_version, + {{"data", lib::detail::exit_hook(data_.metadata())}, + {"scale", lib::save(scale_)}, + {"bias", lib::save(bias_)}} + ); + } + + void save(std::ostream& os) const { data_.save(os); } + /// @brief Load dataset from a file. static SQDataset load(const lib::LoadTable& table, const allocator_type& allocator = {}) { @@ -506,6 +518,19 @@ class SQDataset { lib::load_at(table, "bias")}; } + /// @brief Load dataset from a stream. + static SQDataset load( + const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, + std::istream& is, + const allocator_type& allocator = {} + ) { + return SQDataset{ + SVS_LOAD_MEMBER_AT_(table, data, deserializer, is, allocator), + lib::load_at(table, "scale"), + lib::load_at(table, "bias")}; + } + /// @brief Prefetch data in the dataset. void prefetch(size_t i) const { data_.prefetch(i); } }; diff --git a/tests/svs/index/vamana/index.cpp b/tests/svs/index/vamana/index.cpp index b94b902b8..67c5aedd4 100644 --- a/tests/svs/index/vamana/index.cpp +++ b/tests/svs/index/vamana/index.cpp @@ -37,6 +37,7 @@ // svsbenchmark #include "svs-benchmark/benchmark.h" // stl +#include #include namespace { @@ -181,6 +182,126 @@ CATCH_TEST_CASE("Static VamanaIndex Per-Index Logging", "[logging]") { CATCH_REQUIRE(captured_logs[2].find("Batch Size:") != std::string::npos); } +CATCH_TEST_CASE("Vamana Index Save and Load", "[vamana][index][saveload]") { + const size_t N = 128; + using Eltype = float; + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto graph = svs::graphs::SimpleGraph(data.size(), 64); + svs::distance::DistanceL2 distance_function; + uint32_t entry_point = 0; + auto threadpool = svs::threads::DefaultThreadPool(1); + + // Build the VamanaIndex with the test logger + svs::index::vamana::VamanaBuildParameters buildParams(1.2, 64, 10, 20, 10, true); + svs::index::vamana::VamanaIndex index( + buildParams, + std::move(graph), + std::move(data), + entry_point, + distance_function, + std::move(threadpool) + ); + + const size_t NUM_NEIGHBORS = 10; + auto queries = test_dataset::queries(); + auto search_params = svs::index::vamana::VamanaSearchParameters{}; + search_params.buffer_config_ = svs::index::vamana::SearchBufferConfig{NUM_NEIGHBORS}; + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search(results.view(), queries.cview(), search_params); + + CATCH_SECTION("Load Vamana Index being serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_vamana_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto graph_dir = tempdir.get() / "graph"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(graph_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, graph_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + + using Data_t = svs::data::SimpleData; + using GraphType = svs::GraphLoader<>::return_type; + + auto loaded_index = svs::index::vamana::auto_assemble( + deserializer, + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy data loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + distance_function, + svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load Vamana Index being serialized natively to stream") { + std::stringstream stream; + index.save(stream); + + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + + using Data_t = svs::data::SimpleData; + using GraphType = svs::GraphLoader<>::return_type; + + auto loaded_index = svs::index::vamana::auto_assemble( + deserializer, + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy data loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + distance_function, + svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } +} + CATCH_TEST_CASE("Vamana Index Default Parameters", "[long][parameter][vamana]") { using Catch::Approx; std::filesystem::path data_path = test_dataset::data_svs_file(); From bf7a3f84bda35727e83e958b10bdcc7f20cf4d5e Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Thu, 19 Mar 2026 09:59:24 +0100 Subject: [PATCH 4/7] Native serialization for MutableVamana index (#286) This PR introduce native serialization for `MutableVamana` index. Main changes are: 1. New overload of svs::index::vamana::auto_dynamic_assemble required for direct deserialization accepts lazy loaders and call them in a flexible order to cover legacy serialized models. 2. The test file tests/svs/index/vamana/dynamic_index.cpp is returned to the test build, as far as it is the right place for serialization tests. 3. Some minor refactoring to streamline the logic. Co-authored-by: Dmitry Razdoburdin --- include/svs/core/graph/graph.h | 35 +++-- include/svs/core/translation.h | 15 +-- include/svs/index/flat/dynamic_flat.h | 8 +- include/svs/index/vamana/dynamic_index.h | 88 +++++++++++++ include/svs/orchestrators/dynamic_vamana.h | 68 ++++++---- tests/CMakeLists.txt | 2 +- tests/svs/index/vamana/dynamic_index.cpp | 141 +++++++++++++++++++++ 7 files changed, 313 insertions(+), 44 deletions(-) diff --git a/include/svs/core/graph/graph.h b/include/svs/core/graph/graph.h index 68357d157..036cdbc59 100644 --- a/include/svs/core/graph/graph.h +++ b/include/svs/core/graph/graph.h @@ -354,14 +354,18 @@ template class SimpleGrap size_t num_vertices = lib::load_at(table, "num_vertices"); size_t max_degree = lib::load_at(table, "max_degree"); - // Skip legacy stream metadata (no-op for native serialization). - deserializer.read_name(is); - deserializer.read_size(is); - deserializer.read_binary(is); - - auto data = data_type(num_vertices, max_degree + 1, alloc_args...); - io::populate(is, data); - return lazy(std::move(data)); + // Build a table compatible with GenericSerializer + auto data_table = toml::table{ + {lib::config_schema_key, data::GenericSerializer::serialization_schema}, + {lib::config_version_key, data::GenericSerializer::save_version.str()}, + {"eltype", lib::save(datatype_v)}, + {"num_vectors", lib::save(num_vertices)}, + {"dims", lib::save(max_degree + 1)}, + }; + + return lazy(data_type::load( + lib::ContextFreeLoadTable(data_table), deserializer, is, alloc_args... + )); } protected: @@ -471,6 +475,16 @@ class SimpleBlockedGraph return parent_type::load(table, lazy); } + static constexpr SimpleBlockedGraph load( + const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, + std::istream& is + ) { + auto lazy = + lib::Lazy([](data_type data) { return SimpleBlockedGraph(std::move(data)); }); + return parent_type::load(table, lazy, deserializer, is); + } + static constexpr SimpleBlockedGraph load(const std::filesystem::path& path) { if (data::detail::is_likely_reload(path)) { return lib::load_from_disk(path); @@ -478,6 +492,11 @@ class SimpleBlockedGraph return SimpleBlockedGraph(data_type::load(path)); } } + + static constexpr SimpleBlockedGraph + load(const lib::detail::Deserializer& deserializer, std::istream& is) { + return lib::load_from_stream(deserializer, is); + } }; } // namespace svs::graphs diff --git a/include/svs/core/translation.h b/include/svs/core/translation.h index ae2dc4880..1b1fde9c6 100644 --- a/include/svs/core/translation.h +++ b/include/svs/core/translation.h @@ -380,17 +380,16 @@ class IDTranslator { return translator; } - static IDTranslator - load(const lib::detail::Deserializer& deserializer, std::istream& is) { - auto table = lib::detail::read_metadata(deserializer, is); - auto translation = table.template cast() - .at("translation") - .template cast(); - IDTranslator::validate(translation); + static IDTranslator load( + const lib::ContextFreeLoadTable& table, + const lib::detail::Deserializer& deserializer, + std::istream& is + ) { + IDTranslator::validate(table); deserializer.read_name(is); deserializer.read_size(is); - return IDTranslator::load(translation, is); + return IDTranslator::load(table, is); } static IDTranslator load(const lib::LoadTable& table) { diff --git a/include/svs/index/flat/dynamic_flat.h b/include/svs/index/flat/dynamic_flat.h index 849b80b7f..1873eb50e 100644 --- a/include/svs/index/flat/dynamic_flat.h +++ b/include/svs/index/flat/dynamic_flat.h @@ -804,7 +804,13 @@ auto auto_dynamic_assemble( svs::logging::logger_ptr logger = svs::logging::get() ) { using Data = decltype(data_loader()); - auto config_loader = [&] { return IDTranslator::load(deserializer, is); }; + auto config_loader = [&] { + auto table = lib::detail::read_metadata(deserializer, is); + auto translation = table.template cast() + .at("translation") + .template cast(); + return IDTranslator::load(translation, deserializer, is); + }; std::optional config; std::optional data; diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 3a1778fe1..4ed55fe6a 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -1456,4 +1456,92 @@ auto auto_dynamic_assemble( std::move(logger)}; } +template < + typename LazyGraphLoader, + typename LazyDataLoader, + typename Distance, + typename ThreadPoolProto> +auto auto_dynamic_assemble( + const lib::detail::Deserializer& deserializer, + std::istream& is, + LazyGraphLoader graph_loader, + LazyDataLoader data_loader, + Distance distance, + ThreadPoolProto threadpool_proto, + bool SVS_UNUSED(debug_load_from_static) = false, + svs::logging::logger_ptr logger = svs::logging::get() +) { + using Data = decltype(data_loader()); + using Graph = decltype(graph_loader()); + + // The config loader reads the combined TOML (parameters + translation) + // and the translator binary data. + auto config_loader = [&]() -> detail::VamanaStateLoader { + auto table = lib::detail::read_metadata(deserializer, is); + + auto parameters = lib::load( + table.template cast().at("parameters").template cast() + ); + + auto translation = table.template cast() + .at("translation") + .template cast(); + + auto translator = IDTranslator::load(translation, deserializer, is); + + return detail::VamanaStateLoader{std::move(parameters), std::move(translator)}; + }; + + std::optional config; + std::optional data; + std::optional graph; + + if (deserializer.is_native()) { + // Order is always config->data->graph. + config.emplace(config_loader()); + data.emplace(data_loader()); + graph.emplace(graph_loader()); + } else { + // Directory packing order is filesystem-dependent. + // Read 3 data blocks: config, data and graph in a corresponding order. + for (int data_block_idx = 0; data_block_idx < 3; ++data_block_idx) { + auto name = deserializer.read_name_in_advance(is); + if (name.starts_with("config/")) { + config.emplace(config_loader()); + } else if (name.starts_with("data/")) { + data.emplace(data_loader()); + } else if (name.starts_with("graph/")) { + graph.emplace(graph_loader()); + } else { + throw ANNEXCEPTION("The stream is corrupted!"); + } + } + } + + auto datasize = data->size(); + auto graphsize = graph->n_nodes(); + if (datasize != graphsize) { + throw ANNEXCEPTION( + "Reloaded data has {} nodes while the graph has {} nodes!", datasize, graphsize + ); + } + + auto translator_size = config->translator_.size(); + if (translator_size != datasize) { + throw ANNEXCEPTION( + "Translator has {} IDs but should have {}", translator_size, datasize + ); + } + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + return MutableVamanaIndex{ + config->parameters_, + std::move(*data), + std::move(*graph), + std::move(distance), + std::move(config->translator_), + std::move(threadpool), + std::move(logger)}; +} + } // namespace svs::index::vamana diff --git a/include/svs/orchestrators/dynamic_vamana.h b/include/svs/orchestrators/dynamic_vamana.h index da1af7b19..5b93b3a52 100644 --- a/include/svs/orchestrators/dynamic_vamana.h +++ b/include/svs/orchestrators/dynamic_vamana.h @@ -355,33 +355,49 @@ class DynamicVamana : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); - - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!"); - } - - const auto graph_path = tempdir.get() / "graph"; - if (!fs::is_directory(graph_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!"); - } - - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + using GraphType = svs::GraphLoader<>::return_type; + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return make_dynamic_vamana>( + index::vamana::auto_dynamic_assemble( + deserializer, + stream, + // lazy graph loader + [&]() -> GraphType { + return GraphType::load(deserializer, stream); + }, + // lazy data loader + [&]() -> Data { + return lib::load_from_stream( + deserializer, stream, SVS_FWD(data_args)... + ); + }, + distance_function, + std::move(threadpool) + ) + ); + }); + } else { + return make_dynamic_vamana>( + index::vamana::auto_dynamic_assemble( + deserializer, + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy data loader + [&]() -> Data { + return lib::load_from_stream( + deserializer, stream, SVS_FWD(data_args)... + ); + }, + distance, + std::move(threadpool) + ) + ); } - - return assemble( - config_path, - svs::GraphLoader{graph_path}, - lib::load_from_disk(data_path, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)), - false - ); } /// @copydoc svs::Vamana::batch_iterator diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 95555eb15..265b37ef8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -148,7 +148,7 @@ set(TEST_SOURCES # Global scalar quantization ${TEST_DIR}/svs/quantization/scalar/scalar.cpp - # # ${TEST_DIR}/svs/index/vamana/dynamic_index.cpp + ${TEST_DIR}/svs/index/vamana/dynamic_index.cpp ) ##### diff --git a/tests/svs/index/vamana/dynamic_index.cpp b/tests/svs/index/vamana/dynamic_index.cpp index 177258879..91039fdb1 100644 --- a/tests/svs/index/vamana/dynamic_index.cpp +++ b/tests/svs/index/vamana/dynamic_index.cpp @@ -23,6 +23,8 @@ #include #include #include +#include +#include #include #include @@ -32,11 +34,14 @@ // catch2 #include "catch2/catch_test_macros.hpp" +#include // tests #include "tests/utils/test_dataset.h" #include "tests/utils/utils.h" +// The MutableVamanaIndex "Soft Deletion" test uses outdated API. +#if 0 namespace { template auto copy_dataset(const T& data) { auto copy = svs::data::SimplePolymorphicData{ @@ -246,3 +251,139 @@ CATCH_TEST_CASE("MutableVamanaIndex", "[graph_index]") { << post_add_time << " seconds." << std::endl; } } +#endif + +CATCH_TEST_CASE( + "MutableVamana Index Save and Load", "[graph_index][dynamic_index][saveload]" +) { + const size_t num_threads = 2; + using Distance = svs::distance::DistanceL2; + + auto data = test_dataset::data_blocked_f32(); + std::vector indices(data.size()); + std::iota(indices.begin(), indices.end(), 0); + + svs::index::vamana::VamanaBuildParameters parameters{1.2, 64, 10, 20, 10, true}; + auto index = svs::index::vamana::MutableVamanaIndex( + parameters, std::move(data), indices, Distance(), num_threads + ); + + const size_t num_neighbors = 10; + auto queries = test_dataset::queries(); + auto search_params = svs::index::vamana::VamanaSearchParameters{}; + search_params.buffer_config_ = svs::index::vamana::SearchBufferConfig{num_neighbors}; + auto results = svs::QueryResult(queries.size(), num_neighbors); + index.search(results.view(), queries.cview(), search_params); + + CATCH_SECTION("Load MutableVamana Index being serialized natively to stream") { + std::stringstream stream; + index.save(stream); + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + + using Data_t = svs::data::BlockedData; + using GraphType = svs::graphs::SimpleBlockedGraph; + + auto loaded = svs::index::vamana::auto_dynamic_assemble( + deserializer, + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy data loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + Distance(), + num_threads + ); + + CATCH_REQUIRE(loaded.size() == index.size()); + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); + CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha()); + CATCH_REQUIRE(loaded.get_graph_max_degree() == index.get_graph_max_degree()); + CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates()); + CATCH_REQUIRE( + loaded.get_construction_window_size() == + index.get_construction_window_size() + ); + CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to()); + CATCH_REQUIRE( + loaded.get_full_search_history() == index.get_full_search_history() + ); + index.on_ids([&](size_t e) { CATCH_REQUIRE(loaded.has_id(e)); }); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load MutableVamana Index being serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_dynvamana_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto graph_dir = tempdir.get() / "graph"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(graph_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, graph_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + auto deserializer = svs::lib::detail::Deserializer::build(stream); + + using Data_t = svs::data::BlockedData; + using GraphType = svs::graphs::SimpleBlockedGraph; + + auto loaded = svs::index::vamana::auto_dynamic_assemble( + deserializer, + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + // lazy data loader + [&]() -> Data_t { + return svs::lib::load_from_stream(deserializer, stream); + }, + Distance(), + num_threads + ); + + CATCH_REQUIRE(loaded.size() == index.size()); + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); + CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha()); + CATCH_REQUIRE(loaded.get_graph_max_degree() == index.get_graph_max_degree()); + CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates()); + CATCH_REQUIRE( + loaded.get_construction_window_size() == + index.get_construction_window_size() + ); + CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to()); + CATCH_REQUIRE( + loaded.get_full_search_history() == index.get_full_search_history() + ); + index.on_ids([&](size_t e) { CATCH_REQUIRE(loaded.has_id(e)); }); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } +} From a1ae1a4a8538c177ce515b41b656e9bfa834cac3 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Fri, 20 Mar 2026 10:48:43 +0100 Subject: [PATCH 5/7] Separate native and legacy deserialization paths (#292) This PR diverse deserialization paths for newly serialized models with a legacy ones. The main changes are: 1. Legacy models are now deserialized with intermediate files, as it was done before. 2. Native deserialization path is cleaned, as far as we don't need support of legacy models for this path. --------- Co-authored-by: Dmitry Razdoburdin --- include/svs/core/data/simple.h | 21 +---- include/svs/core/graph/graph.h | 32 +++----- include/svs/core/translation.h | 13 +-- include/svs/index/flat/dynamic_flat.h | 45 +++-------- include/svs/index/vamana/dynamic_index.h | 70 +++++----------- include/svs/index/vamana/index.h | 43 ++-------- include/svs/lib/file.h | 17 ++-- include/svs/lib/saveload/load.h | 94 +++------------------- include/svs/orchestrators/dynamic_flat.h | 52 ++++++++---- include/svs/orchestrators/dynamic_vamana.h | 82 ++++++++++++------- include/svs/orchestrators/exhaustive.h | 21 +++-- include/svs/orchestrators/vamana.h | 77 ++++++++++++------ include/svs/quantization/scalar/scalar.h | 3 +- tests/svs/index/flat/dynamic_flat.cpp | 27 ++----- tests/svs/index/flat/flat.cpp | 15 ++-- tests/svs/index/vamana/dynamic_index.cpp | 37 ++------- tests/svs/index/vamana/index.cpp | 37 ++------- 17 files changed, 263 insertions(+), 423 deletions(-) diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index 8f15ff717..7e546977d 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -136,12 +136,8 @@ class GenericSerializer { } template F> - static lib::lazy_result_t load( - const lib::ContextFreeLoadTable& table, - const lib::detail::Deserializer& deserializer, - std::istream& is, - const F& lazy - ) { + static lib::lazy_result_t + load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) { auto datatype = lib::load_at(table, "eltype"); if (datatype != datatype_v) { throw ANNEXCEPTION( @@ -155,10 +151,6 @@ class GenericSerializer { size_t num_vectors = lib::load_at(table, "num_vectors"); size_t dims = lib::load_at(table, "dims"); - deserializer.read_name(is); - deserializer.read_size(is); - deserializer.read_binary(is); - return io::load_dataset(is, lazy, num_vectors, dims); } }; @@ -482,14 +474,13 @@ class SimpleData { static SimpleData load( const lib::ContextFreeLoadTable& table, - const lib::detail::Deserializer& deserializer, std::istream& is, const allocator_type& allocator = {} ) requires(!is_view) { return GenericSerializer::load( - table, deserializer, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { + table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { return SimpleData(n_elements, n_dimensions, allocator); }) ); @@ -888,15 +879,11 @@ class SimpleData> { static SimpleData load( const lib::ContextFreeLoadTable& table, - const lib::detail::Deserializer& deserializer, std::istream& is, const Blocked& allocator = {} ) { return GenericSerializer::load( - table, - deserializer, - is, - lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) { + table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) { return SimpleData(n_elements, n_dimensions, allocator); }) ); diff --git a/include/svs/core/graph/graph.h b/include/svs/core/graph/graph.h index 036cdbc59..89e456e07 100644 --- a/include/svs/core/graph/graph.h +++ b/include/svs/core/graph/graph.h @@ -335,7 +335,6 @@ template class SimpleGrap static lib::lazy_result_t load( const lib::ContextFreeLoadTable& table, const F& lazy, - const lib::detail::Deserializer& deserializer, std::istream& is, AllocArgs&&... alloc_args ) { @@ -363,9 +362,9 @@ template class SimpleGrap {"dims", lib::save(max_degree + 1)}, }; - return lazy(data_type::load( - lib::ContextFreeLoadTable(data_table), deserializer, is, alloc_args... - )); + return lazy( + data_type::load(lib::ContextFreeLoadTable(data_table), is, alloc_args...) + ); } protected: @@ -419,12 +418,11 @@ class SimpleGraph : public SimpleGraphBase(deserializer, is, allocator); + static constexpr SimpleGraph load(std::istream& is, const Alloc& allocator = {}) { + return lib::load_from_stream(is, allocator); } }; @@ -475,14 +469,11 @@ class SimpleBlockedGraph return parent_type::load(table, lazy); } - static constexpr SimpleBlockedGraph load( - const lib::ContextFreeLoadTable& table, - const lib::detail::Deserializer& deserializer, - std::istream& is - ) { + static constexpr SimpleBlockedGraph + load(const lib::ContextFreeLoadTable& table, std::istream& is) { auto lazy = lib::Lazy([](data_type data) { return SimpleBlockedGraph(std::move(data)); }); - return parent_type::load(table, lazy, deserializer, is); + return parent_type::load(table, lazy, is); } static constexpr SimpleBlockedGraph load(const std::filesystem::path& path) { @@ -493,9 +484,8 @@ class SimpleBlockedGraph } } - static constexpr SimpleBlockedGraph - load(const lib::detail::Deserializer& deserializer, std::istream& is) { - return lib::load_from_stream(deserializer, is); + static constexpr SimpleBlockedGraph load(std::istream& is) { + return lib::load_from_stream(is); } }; diff --git a/include/svs/core/translation.h b/include/svs/core/translation.h index 1b1fde9c6..db65bf213 100644 --- a/include/svs/core/translation.h +++ b/include/svs/core/translation.h @@ -369,6 +369,7 @@ class IDTranslator { } static IDTranslator load(const lib::ContextFreeLoadTable& table, std::istream& is) { + IDTranslator::validate(table); auto num_points = lib::load_at(table, "num_points"); auto translator = IDTranslator{}; @@ -380,18 +381,6 @@ class IDTranslator { return translator; } - static IDTranslator load( - const lib::ContextFreeLoadTable& table, - const lib::detail::Deserializer& deserializer, - std::istream& is - ) { - IDTranslator::validate(table); - deserializer.read_name(is); - deserializer.read_size(is); - - return IDTranslator::load(table, is); - } - static IDTranslator load(const lib::LoadTable& table) { IDTranslator::validate(table); diff --git a/include/svs/index/flat/dynamic_flat.h b/include/svs/index/flat/dynamic_flat.h index 1873eb50e..71d72e10e 100644 --- a/include/svs/index/flat/dynamic_flat.h +++ b/include/svs/index/flat/dynamic_flat.h @@ -790,7 +790,6 @@ auto auto_dynamic_assemble( template auto auto_dynamic_assemble( - const lib::detail::Deserializer& deserializer, std::istream& is, LazyDataLoader&& data_loader, Distance distance, @@ -803,40 +802,16 @@ auto auto_dynamic_assemble( bool SVS_UNUSED(debug_load_from_static) = false, svs::logging::logger_ptr logger = svs::logging::get() ) { - using Data = decltype(data_loader()); - auto config_loader = [&] { - auto table = lib::detail::read_metadata(deserializer, is); - auto translation = table.template cast() - .at("translation") - .template cast(); - return IDTranslator::load(translation, deserializer, is); - }; - - std::optional config; - std::optional data; - - if (deserializer.is_native()) { - // Order is always config->data. - config.emplace(config_loader()); - data.emplace(data_loader()); - } else { - // Directory packing order is filesystem-dependent. - // Read 2 data blocks: config and data in a corresponding order. - for (int data_block_idx = 0; data_block_idx < 2; ++data_block_idx) { - auto name = deserializer.read_name_in_advance(is); - if (name.starts_with("config/")) { - config.emplace(config_loader()); - } else if (name.starts_with("data/")) { - data.emplace(data_loader()); - } else { - throw ANNEXCEPTION("The stream is corrupted!"); - } - } - } + auto table = lib::detail::read_metadata(is); + auto translation = + table.template cast().at("translation").template cast(); + IDTranslator translator = IDTranslator::load(translation, is); + + auto data = data_loader(); // Validate the translator - auto translator_size = config->size(); - auto datasize = data->size(); + auto translator_size = translator.size(); + auto datasize = data.size(); if (translator_size != datasize) { throw ANNEXCEPTION( "Translator has {} IDs but should have {}", translator_size, datasize @@ -845,8 +820,8 @@ auto auto_dynamic_assemble( auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); return DynamicFlatIndex( - std::move(*data), - std::move(*config), + std::move(data), + std::move(translator), std::move(distance), std::move(threadpool), std::move(logger) diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 4ed55fe6a..a5b88b4db 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -1462,7 +1462,6 @@ template < typename Distance, typename ThreadPoolProto> auto auto_dynamic_assemble( - const lib::detail::Deserializer& deserializer, std::istream& is, LazyGraphLoader graph_loader, LazyDataLoader data_loader, @@ -1471,62 +1470,31 @@ auto auto_dynamic_assemble( bool SVS_UNUSED(debug_load_from_static) = false, svs::logging::logger_ptr logger = svs::logging::get() ) { - using Data = decltype(data_loader()); - using Graph = decltype(graph_loader()); - - // The config loader reads the combined TOML (parameters + translation) + // Read the combined TOML (parameters + translation) // and the translator binary data. - auto config_loader = [&]() -> detail::VamanaStateLoader { - auto table = lib::detail::read_metadata(deserializer, is); + auto table = lib::detail::read_metadata(is); - auto parameters = lib::load( - table.template cast().at("parameters").template cast() - ); + auto parameters = lib::load( + table.template cast().at("parameters").template cast() + ); - auto translation = table.template cast() - .at("translation") - .template cast(); - - auto translator = IDTranslator::load(translation, deserializer, is); - - return detail::VamanaStateLoader{std::move(parameters), std::move(translator)}; - }; - - std::optional config; - std::optional data; - std::optional graph; - - if (deserializer.is_native()) { - // Order is always config->data->graph. - config.emplace(config_loader()); - data.emplace(data_loader()); - graph.emplace(graph_loader()); - } else { - // Directory packing order is filesystem-dependent. - // Read 3 data blocks: config, data and graph in a corresponding order. - for (int data_block_idx = 0; data_block_idx < 3; ++data_block_idx) { - auto name = deserializer.read_name_in_advance(is); - if (name.starts_with("config/")) { - config.emplace(config_loader()); - } else if (name.starts_with("data/")) { - data.emplace(data_loader()); - } else if (name.starts_with("graph/")) { - graph.emplace(graph_loader()); - } else { - throw ANNEXCEPTION("The stream is corrupted!"); - } - } - } + auto translation = + table.template cast().at("translation").template cast(); + + auto translator = IDTranslator::load(translation, is); + + auto data = data_loader(); + auto graph = graph_loader(); - auto datasize = data->size(); - auto graphsize = graph->n_nodes(); + auto datasize = data.size(); + auto graphsize = graph.n_nodes(); if (datasize != graphsize) { throw ANNEXCEPTION( "Reloaded data has {} nodes while the graph has {} nodes!", datasize, graphsize ); } - auto translator_size = config->translator_.size(); + auto translator_size = translator.size(); if (translator_size != datasize) { throw ANNEXCEPTION( "Translator has {} IDs but should have {}", translator_size, datasize @@ -1535,11 +1503,11 @@ auto auto_dynamic_assemble( auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); return MutableVamanaIndex{ - config->parameters_, - std::move(*data), - std::move(*graph), + parameters, + std::move(data), + std::move(graph), std::move(distance), - std::move(config->translator_), + std::move(translator), std::move(threadpool), std::move(logger)}; } diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index 02a93c615..70a921353 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -1028,7 +1028,6 @@ template < typename Distance, typename ThreadPoolProto> auto auto_assemble( - const lib::detail::Deserializer& deserializer, std::istream& is, LazyGraphLoader graph_loader, LazyDataLoader data_loader, @@ -1036,49 +1035,21 @@ auto auto_assemble( ThreadPoolProto threadpool_proto, svs::logging::logger_ptr logger = svs::logging::get() ) { - using Data = decltype(data_loader()); - using Graph = decltype(graph_loader()); - auto config_loader = [&] { - return lib::load_from_stream(deserializer, is); - }; - - std::optional config; - std::optional data; - std::optional graph; - - if (deserializer.is_native()) { - // Order is always config->data->graph. - config.emplace(config_loader()); - data.emplace(data_loader()); - graph.emplace(graph_loader()); - } else { - // Directory packing order is filesystem-dependent. - // Read 3 data blocks: config, data and graph in a corresponding order - for (int data_block_idx = 0; data_block_idx < 3; ++data_block_idx) { - auto name = deserializer.read_name_in_advance(is); - if (name.starts_with("config/")) { - config.emplace(config_loader()); - } else if (name.starts_with("data/")) { - data.emplace(data_loader()); - } else if (name.starts_with("graph/")) { - graph.emplace(graph_loader()); - } else { - throw ANNEXCEPTION("The stream is corrupted!"); - } - } - } + VamanaIndexParameters config = lib::load_from_stream(is); + auto data = data_loader(); + auto graph = graph_loader(); auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); // Extract the index type of the provided graph. - using I = typename Graph::index_type; + using I = typename decltype(graph)::index_type; auto index = VamanaIndex{ - std::move(*graph), - std::move(*data), + std::move(graph), + std::move(data), I{}, std::move(distance), std::move(threadpool), std::move(logger)}; - index.apply(*config); + index.apply(config); return index; } diff --git a/include/svs/lib/file.h b/include/svs/lib/file.h index c4df66099..7d949ae4d 100644 --- a/include/svs/lib/file.h +++ b/include/svs/lib/file.h @@ -257,18 +257,16 @@ struct DirectoryArchiver : Archiver { return total_bytes; } - static size_t unpack(std::istream& stream, const std::filesystem::path& root) { + static size_t + unpack(std::istream& stream, const std::filesystem::path& root, size_type magic) { namespace fs = std::filesystem; - // Read and verify the magic number. - size_type magic = 0; - auto total_bytes = read_size(stream, magic); if (magic != magic_number) { throw ANNEXCEPTION("Invalid magic number in directory unpacking!"); } size_type num_files = 0; - total_bytes += read_size(stream, num_files); + auto total_bytes = read_size(stream, num_files); if (!stream) { throw ANNEXCEPTION("Error reading from stream!"); } @@ -286,5 +284,14 @@ struct DirectoryArchiver : Archiver { return total_bytes; } + + static size_t unpack(std::istream& stream, const std::filesystem::path& root) { + // Read and verify the magic number. + size_type magic = 0; + auto total_bytes = read_size(stream, magic); + + total_bytes += unpack(stream, root, magic); + return total_bytes; + } }; } // namespace svs::lib diff --git a/include/svs/lib/saveload/load.h b/include/svs/lib/saveload/load.h index c3e5591b8..06f718d22 100644 --- a/include/svs/lib/saveload/load.h +++ b/include/svs/lib/saveload/load.h @@ -834,75 +834,25 @@ inline SerializedObject begin_deserialization(const std::filesystem::path& fullp } class Deserializer { - enum SerializationScheme { native, legacy }; - SerializationScheme scheme_; + lib::StreamArchiver::size_type magic_; - mutable bool skip_next_name_ = false; - - explicit Deserializer(const SerializationScheme& scheme) - : scheme_(scheme) {} + explicit Deserializer(lib::StreamArchiver::size_type magic) + : magic_(magic) {} public: static Deserializer build(std::istream& stream) { lib::StreamArchiver::size_type magic = 0; lib::StreamArchiver::read_size(stream, magic); - if (magic == lib::StreamArchiver::magic_number) { - return Deserializer(SerializationScheme::native); - } else if (magic == lib::DirectoryArchiver::magic_number) { - // Backward compatibility mode for older versions: - // Previously, SVS serialized models using an intermediate file, - // so some dummy information was added to the stream. - lib::StreamArchiver::size_type num_files = 0; - lib::StreamArchiver::read_size(stream, num_files); - - return Deserializer(SerializationScheme::legacy); - } else { - throw ANNEXCEPTION("Invalid magic number in stream deserialization!"); - } - } - - bool is_native() const { return scheme_ == SerializationScheme::native; } - - std::string read_name_in_advance(std::istream& stream) const { - std::string name; - if (scheme_ == SerializationScheme::legacy) { - lib::StreamArchiver::read_name(stream, name); - skip_next_name_ = true; - } - return name; - } - void read_name(std::istream& stream) const { - if (scheme_ == SerializationScheme::legacy) { - if (!skip_next_name_) { - std::string name; - lib::StreamArchiver::read_name(stream, name); - } - skip_next_name_ = false; - } + return Deserializer(magic); } - void read_size(std::istream& stream) const { - if (skip_next_name_) { - throw ANNEXCEPTION("Error in deserialization: read_size() shouldn't follow " - "read_name_in_advance()!"); - } - if (scheme_ == SerializationScheme::legacy) { - lib::StreamArchiver::size_type size = 0; - lib::StreamArchiver::read_size(stream, size); - } - } + auto magic() const { return magic_; } - template void read_binary(std::istream& stream) const { - if (scheme_ == SerializationScheme::legacy) { - lib::read_binary(stream); - } - } + bool is_native() const { return magic_ == lib::StreamArchiver::magic_number; } }; -inline ContextFreeSerializedObject -read_metadata(const Deserializer& deserializer, std::istream& stream) { - deserializer.read_name(stream); +inline ContextFreeSerializedObject read_metadata(std::istream& stream) { if (!stream) { throw ANNEXCEPTION("Error reading from stream!"); } @@ -967,41 +917,23 @@ concept LoadableFromTable = requires(const T& x ///// load_from_stream template -T load_from_stream( - const Loader& loader, - const detail::Deserializer& deserializer, - std::istream& stream, - Args&&... args -) +T load_from_stream(const Loader& loader, std::istream& stream, Args&&... args) requires LoadableFromTable { // Object is loadable from it's toml::table - return lib::load(loader, detail::read_metadata(deserializer, stream), SVS_FWD(args)...); + return lib::load(loader, detail::read_metadata(stream), SVS_FWD(args)...); } template -T load_from_stream( - const Loader& loader, - const detail::Deserializer& deserializer, - std::istream& stream, - Args&&... args -) +T load_from_stream(const Loader& loader, std::istream& stream, Args&&... args) requires(!LoadableFromTable) { - return lib::load( - loader, - detail::read_metadata(deserializer, stream), - deserializer, - stream, - SVS_FWD(args)... - ); + return lib::load(loader, detail::read_metadata(stream), stream, SVS_FWD(args)...); } template -T load_from_stream( - const detail::Deserializer& deserializer, std::istream& stream, Args&&... args -) { - return lib::load_from_stream(Loader(), deserializer, stream, SVS_FWD(args)...); +T load_from_stream(std::istream& stream, Args&&... args) { + return lib::load_from_stream(Loader(), stream, SVS_FWD(args)...); } ///// load_from_file diff --git a/include/svs/orchestrators/dynamic_flat.h b/include/svs/orchestrators/dynamic_flat.h index 250edf10c..fa7d35855 100644 --- a/include/svs/orchestrators/dynamic_flat.h +++ b/include/svs/orchestrators/dynamic_flat.h @@ -277,22 +277,46 @@ class DynamicFlat : public manager::IndexManager { DataLoaderArgs&&... data_args ) { auto deserializer = svs::lib::detail::Deserializer::build(stream); - return DynamicFlat( - AssembleTag(), - manager::as_typelist(), - index::flat::auto_dynamic_assemble( - deserializer, - stream, - // lazy-loader - [&]() -> Data { - return lib::load_from_stream( - deserializer, stream, SVS_FWD(data_args)... - ); - }, + if (deserializer.is_native()) { + return DynamicFlat( + AssembleTag(), + manager::as_typelist(), + index::flat::auto_dynamic_assemble( + stream, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream(stream, SVS_FWD(data_args)...); + }, + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ) + ); + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_dynflat_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); + + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION( + "Invalid Dynamic Flat index archive: missing config directory!" + ); + } + + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION( + "Invalid Dynamic Flat index archive: missing data directory!" + ); + } + + return assemble( + config_path, + lib::load_from_disk(data_path, SVS_FWD(data_args)...), distance, threads::as_threadpool(std::move(threadpool_proto)) - ) - ); + ); + } } ///// Distance diff --git a/include/svs/orchestrators/dynamic_vamana.h b/include/svs/orchestrators/dynamic_vamana.h index 5b93b3a52..045077387 100644 --- a/include/svs/orchestrators/dynamic_vamana.h +++ b/include/svs/orchestrators/dynamic_vamana.h @@ -356,46 +356,74 @@ class DynamicVamana : public manager::IndexManager { DataLoaderArgs&&... data_args ) { auto deserializer = svs::lib::detail::Deserializer::build(stream); - auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); - using GraphType = svs::GraphLoader<>::return_type; - if constexpr (std::is_same_v, DistanceType>) { - auto dispatcher = DistanceDispatcher(distance); - return dispatcher([&](auto distance_function) { + if (deserializer.is_native()) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + using GraphType = svs::GraphLoader<>::return_type; + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return make_dynamic_vamana>( + index::vamana::auto_dynamic_assemble( + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(stream); }, + // lazy data loader + [&]() -> Data { + return lib::load_from_stream( + stream, SVS_FWD(data_args)... + ); + }, + distance_function, + std::move(threadpool) + ) + ); + }); + } else { return make_dynamic_vamana>( index::vamana::auto_dynamic_assemble( - deserializer, stream, // lazy graph loader - [&]() -> GraphType { - return GraphType::load(deserializer, stream); - }, + [&]() -> GraphType { return GraphType::load(stream); }, // lazy data loader [&]() -> Data { return lib::load_from_stream( - deserializer, stream, SVS_FWD(data_args)... + stream, SVS_FWD(data_args)... ); }, - distance_function, + distance, std::move(threadpool) ) ); - }); + } } else { - return make_dynamic_vamana>( - index::vamana::auto_dynamic_assemble( - deserializer, - stream, - // lazy graph loader - [&]() -> GraphType { return GraphType::load(deserializer, stream); }, - // lazy data loader - [&]() -> Data { - return lib::load_from_stream( - deserializer, stream, SVS_FWD(data_args)... - ); - }, - distance, - std::move(threadpool) - ) + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); + + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!" + ); + } + + const auto graph_path = tempdir.get() / "graph"; + if (!fs::is_directory(graph_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!" + ); + } + + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); + } + + return assemble( + config_path, + svs::GraphLoader{graph_path}, + lib::load_from_disk(data_path, SVS_FWD(data_args)...), + distance, + threads::as_threadpool(std::move(threadpool_proto)), + false ); } } diff --git a/include/svs/orchestrators/exhaustive.h b/include/svs/orchestrators/exhaustive.h index bf46c1347..e22dab7a4 100644 --- a/include/svs/orchestrators/exhaustive.h +++ b/include/svs/orchestrators/exhaustive.h @@ -196,11 +196,22 @@ class Flat : public manager::IndexManager { DataLoaderArgs&&... data_args ) { auto deserializer = svs::lib::detail::Deserializer::build(stream); - return assemble( - lib::load_from_stream(deserializer, stream, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)) - ); + if (deserializer.is_native()) { + return assemble( + lib::load_from_stream(stream, SVS_FWD(data_args)...), + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ); + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_flat_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); + return assemble( + lib::load_from_disk(tempdir, SVS_FWD(data_args)...), + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ); + } } ///// Distance diff --git a/include/svs/orchestrators/vamana.h b/include/svs/orchestrators/vamana.h index 65905cb8c..ad12fd8c3 100644 --- a/include/svs/orchestrators/vamana.h +++ b/include/svs/orchestrators/vamana.h @@ -467,42 +467,69 @@ class Vamana : public manager::IndexManager { DataLoaderArgs&&... data_args ) { auto deserializer = svs::lib::detail::Deserializer::build(stream); - auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); - using GraphType = svs::GraphLoader<>::return_type; - if constexpr (std::is_same_v) { - auto dispatcher = DistanceDispatcher(distance); - return dispatcher([&](auto distance_function) { + if (deserializer.is_native()) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + using GraphType = svs::GraphLoader<>::return_type; + if constexpr (std::is_same_v) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return make_vamana>( + AssembleTag(), + stream, + // lazy-loader + [&]() -> GraphType { return GraphType::load(stream); }, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream( + stream, SVS_FWD(data_args)... + ); + }, + distance_function, + std::move(threadpool) + ); + }); + } else { return make_vamana>( AssembleTag(), - deserializer, stream, // lazy-loader - [&]() -> GraphType { return GraphType::load(deserializer, stream); }, + [&]() -> GraphType { return GraphType::load(stream); }, // lazy-loader [&]() -> Data { - return lib::load_from_stream( - deserializer, stream, SVS_FWD(data_args)... - ); + return lib::load_from_stream(stream, SVS_FWD(data_args)...); }, - distance_function, + distance, std::move(threadpool) ); - }); + } } else { - return make_vamana>( - AssembleTag(), - deserializer, - stream, - // lazy-loader - [&]() -> GraphType { return GraphType::load(deserializer, stream); }, - // lazy-loader - [&]() -> Data { - return lib::load_from_stream( - deserializer, stream, SVS_FWD(data_args)... - ); - }, + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); + + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!" + ); + } + + const auto graph_path = tempdir.get() / "graph"; + if (!fs::is_directory(graph_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!" + ); + } + + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); + } + + return assemble( + config_path, + svs::GraphLoader{graph_path}, + lib::load_from_disk(data_path, SVS_FWD(data_args)...), distance, - std::move(threadpool) + threads::as_threadpool(std::move(threadpool_proto)) ); } } diff --git a/include/svs/quantization/scalar/scalar.h b/include/svs/quantization/scalar/scalar.h index 992d533cf..5998e76a7 100644 --- a/include/svs/quantization/scalar/scalar.h +++ b/include/svs/quantization/scalar/scalar.h @@ -521,12 +521,11 @@ class SQDataset { /// @brief Load dataset from a stream. static SQDataset load( const lib::ContextFreeLoadTable& table, - const lib::detail::Deserializer& deserializer, std::istream& is, const allocator_type& allocator = {} ) { return SQDataset{ - SVS_LOAD_MEMBER_AT_(table, data, deserializer, is, allocator), + SVS_LOAD_MEMBER_AT_(table, data, is, allocator), lib::load_at(table, "scale"), lib::load_at(table, "bias")}; } diff --git a/tests/svs/index/flat/dynamic_flat.cpp b/tests/svs/index/flat/dynamic_flat.cpp index 76c4d4c65..900118e15 100644 --- a/tests/svs/index/flat/dynamic_flat.cpp +++ b/tests/svs/index/flat/dynamic_flat.cpp @@ -26,6 +26,7 @@ #include "svs/lib/threads.h" #include "svs/lib/timing.h" #include "svs/misc/dynamic_helper.h" +#include "svs/orchestrators/dynamic_flat.h" // tests #include "tests/utils/test_dataset.h" @@ -280,17 +281,8 @@ CATCH_TEST_CASE("DynamicFlat Index Save and Load", "[dynamic_flat][index][savelo std::stringstream stream; index.save(stream); { - auto deserializer = svs::lib::detail::Deserializer::build(stream); - Index_t loaded_index = svs::index::flat::auto_dynamic_assemble( - deserializer, - stream, - // lazy-loader - [&]() -> Data_t { - return svs::lib::load_from_stream(deserializer, stream); - }, - dist, - svs::threads::as_threadpool(num_threads) - ); + auto loaded_index = + svs::DynamicFlat::assemble(stream, dist, num_threads); CATCH_REQUIRE(loaded_index.size() == index.size()); CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); @@ -323,17 +315,8 @@ CATCH_TEST_CASE("DynamicFlat Index Save and Load", "[dynamic_flat][index][savelo svs::lib::DirectoryArchiver::pack(tempdir, stream); } { - auto deserializer = svs::lib::detail::Deserializer::build(stream); - Index_t loaded_index = svs::index::flat::auto_dynamic_assemble( - deserializer, - stream, - // lazy-loader - [&]() -> Data_t { - return svs::lib::load_from_stream(deserializer, stream); - }, - dist, - svs::threads::as_threadpool(num_threads) - ); + auto loaded_index = + svs::DynamicFlat::assemble(stream, dist, num_threads); CATCH_REQUIRE(loaded_index.size() == index.size()); CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); diff --git a/tests/svs/index/flat/flat.cpp b/tests/svs/index/flat/flat.cpp index 29879b664..a5e1c801f 100644 --- a/tests/svs/index/flat/flat.cpp +++ b/tests/svs/index/flat/flat.cpp @@ -18,6 +18,7 @@ #include "svs/core/logging.h" #include "svs/lib/file.h" #include "svs/lib/saveload/load.h" +#include "svs/orchestrators/exhaustive.h" // tests #include "tests/utils/test_dataset.h" @@ -93,11 +94,8 @@ CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { std::stringstream ss; index.save(ss); - auto deserializer = svs::lib::detail::Deserializer::build(ss); - Index_t loaded_index = Index_t( - svs::lib::load_from_stream(deserializer, ss), - dist, - svs::threads::DefaultThreadPool(1) + auto loaded_index = svs::Flat::assemble( + ss, dist, svs::threads::DefaultThreadPool(1) ); CATCH_REQUIRE(loaded_index.size() == index.size()); @@ -125,11 +123,8 @@ CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { index.save(tempdir); svs::lib::DirectoryArchiver::pack(tempdir, ss); - auto deserializer = svs::lib::detail::Deserializer::build(ss); - Index_t loaded_index = Index_t( - svs::lib::load_from_stream(deserializer, ss), - dist, - svs::threads::DefaultThreadPool(1) + auto loaded_index = svs::Flat::assemble( + ss, dist, svs::threads::DefaultThreadPool(1) ); CATCH_REQUIRE(loaded_index.size() == index.size()); diff --git a/tests/svs/index/vamana/dynamic_index.cpp b/tests/svs/index/vamana/dynamic_index.cpp index 91039fdb1..798a584c2 100644 --- a/tests/svs/index/vamana/dynamic_index.cpp +++ b/tests/svs/index/vamana/dynamic_index.cpp @@ -31,6 +31,7 @@ // svs #include "svs/core/recall.h" #include "svs/lib/timing.h" +#include "svs/orchestrators/dynamic_vamana.h" // catch2 #include "catch2/catch_test_macros.hpp" @@ -279,22 +280,10 @@ CATCH_TEST_CASE( std::stringstream stream; index.save(stream); { - auto deserializer = svs::lib::detail::Deserializer::build(stream); - using Data_t = svs::data::BlockedData; - using GraphType = svs::graphs::SimpleBlockedGraph; - - auto loaded = svs::index::vamana::auto_dynamic_assemble( - deserializer, - stream, - // lazy graph loader - [&]() -> GraphType { return GraphType::load(deserializer, stream); }, - // lazy data loader - [&]() -> Data_t { - return svs::lib::load_from_stream(deserializer, stream); - }, - Distance(), - num_threads + + auto loaded = svs::DynamicVamana::assemble( + stream, Distance(), num_threads ); CATCH_REQUIRE(loaded.size() == index.size()); @@ -340,22 +329,10 @@ CATCH_TEST_CASE( svs::lib::DirectoryArchiver::pack(tempdir, stream); } { - auto deserializer = svs::lib::detail::Deserializer::build(stream); - using Data_t = svs::data::BlockedData; - using GraphType = svs::graphs::SimpleBlockedGraph; - - auto loaded = svs::index::vamana::auto_dynamic_assemble( - deserializer, - stream, - // lazy graph loader - [&]() -> GraphType { return GraphType::load(deserializer, stream); }, - // lazy data loader - [&]() -> Data_t { - return svs::lib::load_from_stream(deserializer, stream); - }, - Distance(), - num_threads + + auto loaded = svs::DynamicVamana::assemble( + stream, Distance(), num_threads ); CATCH_REQUIRE(loaded.size() == index.size()); diff --git a/tests/svs/index/vamana/index.cpp b/tests/svs/index/vamana/index.cpp index 67c5aedd4..bbc535c62 100644 --- a/tests/svs/index/vamana/index.cpp +++ b/tests/svs/index/vamana/index.cpp @@ -24,6 +24,7 @@ // svs #include "svs/index/vamana/build_params.h" #include "svs/lib/preprocessor.h" +#include "svs/orchestrators/vamana.h" // catch2 #include "catch2/catch_test_macros.hpp" @@ -224,22 +225,10 @@ CATCH_TEST_CASE("Vamana Index Save and Load", "[vamana][index][saveload]") { svs::lib::DirectoryArchiver::pack(tempdir, stream); } { - auto deserializer = svs::lib::detail::Deserializer::build(stream); - using Data_t = svs::data::SimpleData; - using GraphType = svs::GraphLoader<>::return_type; - - auto loaded_index = svs::index::vamana::auto_assemble( - deserializer, - stream, - // lazy graph loader - [&]() -> GraphType { return GraphType::load(deserializer, stream); }, - // lazy data loader - [&]() -> Data_t { - return svs::lib::load_from_stream(deserializer, stream); - }, - distance_function, - svs::threads::DefaultThreadPool(1) + + auto loaded_index = svs::Vamana::assemble( + stream, distance_function, svs::threads::DefaultThreadPool(1) ); CATCH_REQUIRE(loaded_index.size() == index.size()); @@ -265,22 +254,10 @@ CATCH_TEST_CASE("Vamana Index Save and Load", "[vamana][index][saveload]") { index.save(stream); { - auto deserializer = svs::lib::detail::Deserializer::build(stream); - using Data_t = svs::data::SimpleData; - using GraphType = svs::GraphLoader<>::return_type; - - auto loaded_index = svs::index::vamana::auto_assemble( - deserializer, - stream, - // lazy graph loader - [&]() -> GraphType { return GraphType::load(deserializer, stream); }, - // lazy data loader - [&]() -> Data_t { - return svs::lib::load_from_stream(deserializer, stream); - }, - distance_function, - svs::threads::DefaultThreadPool(1) + + auto loaded_index = svs::Vamana::assemble( + stream, distance_function, svs::threads::DefaultThreadPool(1) ); CATCH_REQUIRE(loaded_index.size() == index.size()); From 81c5d8219fbaace76c4510637921eb2ad09494c6 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Sun, 22 Mar 2026 12:14:06 +0100 Subject: [PATCH 6/7] native serialization/deserialization for ivf (#294) This PR introduce native serialization for IVF index. Main changes are: 1. New overload of svs::index::ivf::load_ivf_index accpting istream is introduced. 2. Added related tests. 3. save(std::ostream&) method for DynamicIVF is just a placeholder for now. Real implementation is expected later. Co-authored-by: Dmitry Razdoburdin --- include/svs/index/ivf/clustering.h | 115 ++++++++++++++++++++++------ include/svs/index/ivf/dynamic_ivf.h | 4 + include/svs/index/ivf/index.h | 54 +++++++++++++ include/svs/orchestrators/ivf.h | 71 +++++++++++------ tests/svs/index/ivf/index.cpp | 77 +++++++++++++++++++ 5 files changed, 274 insertions(+), 47 deletions(-) diff --git a/include/svs/index/ivf/clustering.h b/include/svs/index/ivf/clustering.h index cc9bb09b8..65828525b 100644 --- a/include/svs/index/ivf/clustering.h +++ b/include/svs/index/ivf/clustering.h @@ -410,6 +410,66 @@ class DenseClusteredDataset { static constexpr lib::Version save_version{0, 0, 0}; static constexpr std::string_view serialization_schema = "ivf_dense_clustered_dataset"; + lib::SaveTable metadata() const { + auto num_clusters = size(); + auto dims = dimensions(); + + return lib::SaveTable( + serialization_schema, + save_version, + {{"num_clusters", lib::save(num_clusters)}, + {"dimensions", lib::save(dims)}, + {"prefetch_offset", lib::save(prefetch_offset_)}, + {"index_type", lib::save(datatype_v)}} + ); + } + + void save(std::ostream& os) const { + auto num_clusters = size(); + + // Compute cluster sizes and ID offsets + std::vector cluster_sizes, ids_offsets; + calculate_cluster_sizes_and_offsets(&cluster_sizes, &ids_offsets); + + lib::write_binary(os, cluster_sizes); + for (size_t i = 0; i < num_clusters; ++i) { + lib::save_to_stream(clusters_[i].data_, os); + if (!clusters_[i].ids_.empty()) { + lib::write_binary(os, clusters_[i].ids_); + } + } + } + + template + static DenseClusteredDataset load( + const lib::ContextFreeLoadTable& table, + std::istream& is, + const Allocator& allocator = Allocator{} + ) { + auto num_clusters = lib::load_at(table, "num_clusters"); + [[maybe_unused]] auto dims = lib::load_at(table, "dimensions"); + auto prefetch_offset = lib::load_at(table, "prefetch_offset"); + check_saved_index_type(table); + + DenseClusteredDataset result; + result.prefetch_offset_ = prefetch_offset; + result.clusters_.reserve(num_clusters); + + std::vector cluster_sizes(num_clusters); + lib::read_binary(is, cluster_sizes); + + for (size_t i = 0; i < num_clusters; ++i) { + auto cluster_data = lib::load_from_stream(is, allocator); + size_t cluster_size = cluster_sizes[i]; + std::vector cluster_ids(cluster_size); + if (cluster_size > 0) { + lib::read_binary(is, cluster_ids); + } + result.clusters_.emplace_back(std::move(cluster_data), std::move(cluster_ids)); + } + return result; + } + /// @brief Save the DenseClusteredDataset to disk. /// /// Saves all cluster data using the existing save mechanisms for each data type @@ -427,16 +487,8 @@ class DenseClusteredDataset { auto dims = dimensions(); // Compute cluster sizes and ID offsets - std::vector cluster_sizes(num_clusters); - std::vector ids_offsets(num_clusters + 1); - - size_t ids_offset = 0; - for (size_t i = 0; i < num_clusters; ++i) { - cluster_sizes[i] = clusters_[i].size(); - ids_offsets[i] = ids_offset; - ids_offset += cluster_sizes[i] * sizeof(I); - } - ids_offsets[num_clusters] = ids_offset; + std::vector cluster_sizes, ids_offsets; + calculate_cluster_sizes_and_offsets(&cluster_sizes, &ids_offsets); // Create a temporary directory for cluster data lib::UniqueTempDirectory tempdir{"svs_ivf_clusters_save"}; @@ -495,7 +547,7 @@ class DenseClusteredDataset { {"ids_file", lib::save(std::string("ids.bin"))}, {"cluster_sizes_file", lib::save(std::string("cluster_sizes.bin"))}, {"ids_offsets_file", lib::save(std::string("ids_offsets.bin"))}, - {"total_ids_bytes", lib::save(ids_offset)}} + {"total_ids_bytes", lib::save(ids_offsets[num_clusters])}} ); } @@ -528,17 +580,7 @@ class DenseClusteredDataset { // since each cluster's data type determines its own dimensions [[maybe_unused]] auto dims = lib::load_at(table, "dimensions"); auto prefetch_offset = lib::load_at(table, "prefetch_offset"); - - // Verify index type matches - auto saved_index_type = lib::load_at(table, "index_type"); - if (saved_index_type != datatype_v) { - throw ANNEXCEPTION( - "DenseClusteredDataset was saved using index type {} but we're trying to " - "reload it using {}!", - saved_index_type, - datatype_v - ); - } + check_saved_index_type(table); auto base_dir = table.context().get_directory(); @@ -603,6 +645,35 @@ class DenseClusteredDataset { } private: + void calculate_cluster_sizes_and_offsets( + std::vector* cluster_sizes, std::vector* ids_offsets + ) const { + auto num_clusters = size(); + cluster_sizes->resize(num_clusters); + ids_offsets->resize(num_clusters + 1); + + size_t ids_offset = 0; + for (size_t i = 0; i < num_clusters; ++i) { + (*cluster_sizes)[i] = clusters_[i].size(); + (*ids_offsets)[i] = ids_offset; + ids_offset += clusters_[i].size() * sizeof(I); + } + (*ids_offsets)[num_clusters] = ids_offset; + } + + static void check_saved_index_type(const lib::ContextFreeLoadTable& table) { + // Verify index type matches + auto saved_index_type = lib::load_at(table, "index_type"); + if (saved_index_type != datatype_v) { + throw ANNEXCEPTION( + "DenseClusteredDataset was saved using index type {} but we're trying to " + "reload it using {}!", + saved_index_type, + datatype_v + ); + } + } + std::vector> clusters_; size_t prefetch_offset_ = 8; }; diff --git a/include/svs/index/ivf/dynamic_ivf.h b/include/svs/index/ivf/dynamic_ivf.h index 92fc86ed3..40ad32c3f 100644 --- a/include/svs/index/ivf/dynamic_ivf.h +++ b/include/svs/index/ivf/dynamic_ivf.h @@ -761,6 +761,10 @@ class DynamicIVFIndex { lib::save_to_disk(clusters_, clusters_dir); } + void save(std::ostream& SVS_UNUSED(os)) const { + throw ANNEXCEPTION("Placeholder; not implemented!"); + } + private: ///// Helper Methods ///// diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index 263b18686..00f45735a 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -465,6 +465,7 @@ class IVFIndex { /// Each directory may be created as a side-effect of this method call provided that /// the parent directory exists. /// + void save( const std::filesystem::path& config_directory, const std::filesystem::path& data_directory @@ -504,6 +505,27 @@ class IVFIndex { lib::save_to_disk(cluster_, clusters_dir); } + void save(std::ostream& os) const { + lib::begin_serialization(os); + + // Save config + auto data_type_config = DataTypeTraits::get_config(); + data_type_config.centroid_type = datatype_v; + auto save_table = lib::SaveTable( + serialization_schema, + save_version, + {{"name", lib::save(name())}, {"num_clusters", lib::save(num_clusters())}} + ); + save_table.insert("data_type_config", lib::save(data_type_config)); + lib::save_to_stream(save_table, os); + + // Save centroids + lib::save_to_stream(centroids_, os); + + // Save clusters + lib::save_to_stream(cluster_, os); + } + private: ///// Core Components ///// Centroids centroids_; @@ -964,4 +986,36 @@ auto load_ivf_index( return index; } +template < + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadpoolProto> +auto load_ivf_index( + std::istream& is, + Distance distance, + ThreadpoolProto threadpool_proto, + const size_t intra_query_thread_count = 1, + svs::logging::logger_ptr logger = svs::logging::get() +) { + using centroids_type = data::SimpleData; + using data_type = typename DataType::lib_alloc_data_type; + using cluster_type = DenseClusteredDataset; + + lib::detail::read_metadata(is); + auto centroids = lib::load_from_stream(is); + auto clusters = lib::load_from_stream(is); + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + + return IVFIndex( + std::move(centroids), + std::move(clusters), + std::move(distance), + std::move(threadpool), + intra_query_thread_count, + logger + ); +} + } // namespace svs::index::ivf diff --git a/include/svs/orchestrators/ivf.h b/include/svs/orchestrators/ivf.h index 90c980780..87cb1b229 100644 --- a/include/svs/orchestrators/ivf.h +++ b/include/svs/orchestrators/ivf.h @@ -116,13 +116,7 @@ class IVFImpl : public manager::ManagerImpl { void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_ivf_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current IVF backend doesn't support saving!"); } @@ -380,27 +374,54 @@ class IVF : public manager::IndexManager { ThreadpoolProto threadpool_proto, size_t intra_query_threads = 1 ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_ivf_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + if (deserializer.is_native()) { + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return IVF( + std::in_place, + manager::as_typelist{}, + index::ivf::load_ivf_index( + stream, + std::move(distance_function), + std::move(threadpool_proto), + intra_query_threads + ) + ); + }); + } else { + return IVF( + std::in_place, + manager::as_typelist{}, + index::ivf::load_ivf_index( + stream, distance, std::move(threadpool_proto), intra_query_threads + ) + ); + } + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_ivf_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!"); - } + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!"); + } - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!"); - } + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!"); + } - return assemble( - config_path, - data_path, - distance, - std::move(threadpool_proto), - intra_query_threads - ); + return assemble( + config_path, + data_path, + distance, + std::move(threadpool_proto), + intra_query_threads + ); + } } ///// Building diff --git a/tests/svs/index/ivf/index.cpp b/tests/svs/index/ivf/index.cpp index 7c78a3add..b15f26040 100644 --- a/tests/svs/index/ivf/index.cpp +++ b/tests/svs/index/ivf/index.cpp @@ -16,6 +16,7 @@ // header under test #include "svs/index/ivf/index.h" +#include "svs/orchestrators/ivf.h" // tests #include "tests/utils/test_dataset.h" @@ -33,6 +34,7 @@ // stl #include +#include CATCH_TEST_CASE("IVF Index Single Search", "[ivf][index][single_search]") { namespace ivf = svs::index::ivf; @@ -276,6 +278,81 @@ CATCH_TEST_CASE("IVF Index Save and Load", "[ivf][index][saveload]") { svs_test::cleanup_temp_directory(); } + CATCH_SECTION("Load IVF Index serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_ivf_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + using DataType = svs::data::SimpleData; + + auto loaded_ivf = svs::IVF::assemble( + stream, + distance, + svs::threads::as_threadpool(num_threads), + num_inner_threads + ); + + CATCH_REQUIRE(loaded_ivf.size() == index.size()); + CATCH_REQUIRE(loaded_ivf.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_ivf.search(loaded_results.view(), batch_queries, search_params); + + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE( + loaded_results.index(q, i) == original_results.index(q, i) + ); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(original_results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load IVF Index serialized natively to stream") { + std::stringstream stream; + index.save(stream); + + { + using DataType = svs::data::SimpleData; + + auto loaded_ivf = svs::IVF::assemble( + stream, + distance, + svs::threads::as_threadpool(num_threads), + num_inner_threads + ); + + CATCH_REQUIRE(loaded_ivf.size() == index.size()); + CATCH_REQUIRE(loaded_ivf.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_ivf.search(loaded_results.view(), batch_queries, search_params); + + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE( + loaded_results.index(q, i) == original_results.index(q, i) + ); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(original_results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + CATCH_SECTION("Save and load DenseClusteredDataset") { // Prepare temp directory auto tempdir = svs_test::prepare_temp_directory_v2(); From be0b99d9314bb4a7b44a867c815be4364b3c2c27 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Mon, 23 Mar 2026 06:11:00 -0700 Subject: [PATCH 7/7] initial --- include/svs/index/inverted/clustering.h | 80 +++++++++++++---------- include/svs/index/inverted/memory_based.h | 44 +++++++++++++ include/svs/orchestrators/inverted.h | 31 +++++++++ tests/svs/index/inverted/memory_based.cpp | 75 +++++++++++++++++++++ 4 files changed, 195 insertions(+), 35 deletions(-) diff --git a/include/svs/index/inverted/clustering.h b/include/svs/index/inverted/clustering.h index db82e1593..de5585c29 100644 --- a/include/svs/index/inverted/clustering.h +++ b/include/svs/index/inverted/clustering.h @@ -571,6 +571,36 @@ template class Clustering { // Saving and Loading. static constexpr lib::Version save_version{0, 0, 0}; static constexpr std::string_view serialization_schema = "clustering"; + + lib::SaveTable metadata() const { + return lib::SaveTable( + serialization_schema, + save_version, + {{"integer_type", lib::save(datatype_v)}, + {"num_clusters", lib::save(size())}} + ); + } + + void save(std::ostream& os) const { + for (const auto& [id, cluster] : *this) { + cluster.serialize(os); + } + } + + static Clustering + load(const lib::ContextFreeLoadTable& table, std::istream& stream) { + auto saved_integer_type = lib::load_at(table, "integer_type"); + if (saved_integer_type != datatype_v) { + throw ANNEXCEPTION("Clustering was saved using {} but we're trying to reload it using {}!", saved_integer_type, datatype_v); + } + auto num_clusters = lib::load_at(table, "num_clusters"); + auto clustering = Clustering(); + for (size_t i = 0; i < num_clusters; ++i) { + clustering.insert(Cluster::deserialize(stream)); + } + return clustering; + } + lib::SaveTable save(const lib::SaveContext& ctx) const { // Serialize all clusters into an auxiliary file. auto fullpath = ctx.generate_name("clustering", "bin"); @@ -582,48 +612,28 @@ template class Clustering { } } - return lib::SaveTable( - serialization_schema, - save_version, - {{"filepath", lib::save(fullpath.filename())}, - SVS_LIST_SAVE(filesize), - {"integer_type", lib::save(datatype_v)}, - {"num_clusters", lib::save(size())}} - ); + auto table = metadata(); + table.insert("filepath", lib::save(fullpath.filename())); + table.insert("filesize", lib::save(filesize)); + return table; + + return table; } static Clustering load(const lib::LoadTable& table) { - // Ensure we have the correct integer type when decoding. - auto saved_integer_type = lib::load_at(table, "integer_type"); - if (saved_integer_type != datatype_v) { - auto type = datatype_v; + auto expected_filesize = lib::load_at(table, "filesize"); + auto file = table.resolve_at("filepath"); + size_t actual_filesize = std::filesystem::file_size(file); + if (actual_filesize != expected_filesize) { throw ANNEXCEPTION( - "Clustering was saved using {} but we're trying to reload it using {}!", - saved_integer_type, - type + "Expected cluster file size to be {}. Instead, it is {}!", + actual_filesize, + expected_filesize ); } - auto num_clusters = lib::load_at(table, "num_clusters"); - auto expected_filesize = lib::load_at(table, "filesize"); - auto clustering = Clustering(); - { - auto file = table.resolve_at("filepath"); - size_t actual_filesize = std::filesystem::file_size(file); - if (actual_filesize != expected_filesize) { - throw ANNEXCEPTION( - "Expected cluster file size to be {}. Instead, it is {}!", - actual_filesize, - expected_filesize - ); - } - - auto io = lib::open_read(file); - for (size_t i = 0; i < num_clusters; ++i) { - clustering.insert(Cluster::deserialize(io)); - } - } - return clustering; + auto io = lib::open_read(file); + return load(table, io); } private: diff --git a/include/svs/index/inverted/memory_based.h b/include/svs/index/inverted/memory_based.h index 3d3fc24c5..8b2e8c8e1 100644 --- a/include/svs/index/inverted/memory_based.h +++ b/include/svs/index/inverted/memory_based.h @@ -497,6 +497,8 @@ template class InvertedIndex { index_.save(index_config, graph, data); } + void save_primary_index(std::ostream& os) const { index_.save(os); } + ///// Accessors /// @brief Getter method for logger svs::logging::logger_ptr get_logger() const { return logger_; } @@ -655,4 +657,46 @@ auto assemble_from_clustering( ); } +template < + typename DataProto, + typename Distance, + StorageStrategy Strategy, + typename ThreadPoolProto> +auto assemble_from_clustering( + std::istream& is, + DataProto data_proto, + Distance distance, + Strategy strategy, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() +) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + auto original = svs::detail::dispatch_load(std::move(data_proto), threadpool); + auto clustering = lib::load_from_stream>(is); + auto ids = clustering.sorted_centroids(); + + // skip magic + svs::lib::detail::Deserializer::build(is); + auto index = index::vamana::auto_assemble( + is, + lib::Lazy([&]() { return GraphLoader::return_type::load(is); }), + lib::Lazy([&]() { + using T = typename std::decay_t::element_type; + constexpr size_t Ext = std::decay_t::extent; + return lib::load_from_stream>(is); + }), + distance, + 1, + logger + ); + + return InvertedIndex( + std::move(index), + strategy(original, clustering, HugepageAllocator()), + std::move(ids), + std::move(threadpool), + std::move(logger) + ); +} + } // namespace svs::index::inverted diff --git a/include/svs/orchestrators/inverted.h b/include/svs/orchestrators/inverted.h index 6b6e50470..7fb7e60da 100644 --- a/include/svs/orchestrators/inverted.h +++ b/include/svs/orchestrators/inverted.h @@ -34,6 +34,9 @@ class InvertedInterface { const std::filesystem::path& primary_data, const std::filesystem::path& primary_graph ) = 0; + + ///// Saving + virtual void save_primary_index(std::ostream& os) = 0; }; template @@ -72,6 +75,8 @@ class InvertedImpl : public manager::ManagerImpl { ) override { impl().save_primary_index(primary_config, primary_data, primary_graph); } + + void save_primary_index(std::ostream& os) override { impl().save_primary_index(os); } }; ///// @@ -106,6 +111,8 @@ class Inverted : public manager::IndexManager { impl_->save_primary_index(primary_config, primary_data, primary_graph); } + void save_primary_index(std::ostream& os) { impl_->save_primary_index(os); } + ///// Building template < manager::QueryTypeDefinition QueryTypes, @@ -168,6 +175,30 @@ class Inverted : public manager::IndexManager { std::move(threadpool_proto) )}; } + template < + manager::QueryTypeDefinition QueryTypes, + typename DataProto, + typename Distance, + typename ThreadPoolProto, + typename StorageStrategy = index::inverted::SparseStrategy> + static Inverted assemble_from_clustering( + std::istream& is, + DataProto data_proto, + Distance distance, + ThreadPoolProto threadpool_proto, + StorageStrategy strategy = {} + ) { + return Inverted{ + std::in_place, + manager::as_typelist{}, + index::inverted::assemble_from_clustering( + is, + std::move(data_proto), + std::move(distance), + std::move(strategy), + std::move(threadpool_proto) + )}; + } }; } // namespace svs diff --git a/tests/svs/index/inverted/memory_based.cpp b/tests/svs/index/inverted/memory_based.cpp index d418dd83d..28ff6b269 100644 --- a/tests/svs/index/inverted/memory_based.cpp +++ b/tests/svs/index/inverted/memory_based.cpp @@ -19,9 +19,11 @@ #include "spdlog/sinks/callback_sink.h" #include "svs-benchmark/datasets.h" #include "svs/lib/timing.h" +#include "svs/orchestrators/inverted.h" #include "tests/utils/inverted_reference.h" #include "tests/utils/test_dataset.h" #include +#include CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") { // Vector to store captured log messages @@ -73,3 +75,76 @@ CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") { CATCH_REQUIRE(captured_logs[0].find("Vamana Build Parameters:") != std::string::npos); CATCH_REQUIRE(captured_logs[1].find("Number of syncs") != std::string::npos); } + +namespace { +constexpr size_t NUM_NEIGHBORS = 10; + +template void test_stream_save_load(Strategy strategy) { + auto distance = svs::DistanceL2(); + constexpr auto distance_type = svs::distance_type_v; + auto expected_results = test_dataset::inverted::expected_build_results( + distance_type, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_parameters = expected_results.build_parameters_.value(); + + // Capture the clustering during build. + svs::index::inverted::Clustering clustering; + auto clustering_op = [&](const auto& c) { clustering = c; }; + + svs::Inverted index = svs::Inverted::build( + build_parameters, + svs::data::SimpleData::load(test_dataset::data_svs_file()), + distance, + 2, + strategy, + svs::index::inverted::PickRandomly{}, + clustering_op + ); + + auto queries = svs::data::SimpleData::load(test_dataset::query_file()); + auto parameters = index.get_search_parameters(); + auto results = index.search(queries, NUM_NEIGHBORS); + + // Serialize to stream. + std::stringstream ss; + svs::lib::save_to_stream(clustering, ss); + index.save_primary_index(ss); + + // Load from stream. + svs::Inverted loaded = svs::Inverted::assemble_from_clustering>( + ss, + svs::data::SimpleData::load(test_dataset::data_svs_file()), + distance, + 2, + strategy + ); + loaded.set_search_parameters(parameters); + + // Compare basic properties. + CATCH_REQUIRE(loaded.size() == index.size()); + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); + + // Compare search results element-wise. + auto loaded_results = loaded.search(queries, NUM_NEIGHBORS); + CATCH_REQUIRE(loaded_results.n_queries() == results.n_queries()); + CATCH_REQUIRE(loaded_results.n_neighbors() == results.n_neighbors()); + for (size_t q = 0; q < results.n_queries(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } +} +} // namespace + +CATCH_TEST_CASE("InvertedIndex Save and Load", "[saveload][inverted][index]") { + CATCH_SECTION("SparseStrategy") { + test_stream_save_load(svs::index::inverted::SparseStrategy()); + } + CATCH_SECTION("DenseStrategy") { + test_stream_save_load(svs::index::inverted::DenseStrategy()); + } +}