diff --git a/include/svs/core/data/io.h b/include/svs/core/data/io.h index 6bb856fed..572ffa63a 100644 --- a/include/svs/core/data/io.h +++ b/include/svs/core/data/io.h @@ -79,6 +79,22 @@ void populate_impl( } } +template void populate(std::istream& is, Data& data) { + auto accessor = DefaultWriteAccessor(); + + size_t num_vectors = data.size(); + size_t dims = data.dimensions(); + + auto max_lines = Dynamic; + auto nvectors = std::min(num_vectors, max_lines); + + auto reader = lib::VectorReader(dims); + for (size_t i = 0; i < nvectors; ++i) { + reader.read(is); + accessor.set(data, i, reader.data()); + } +} + // Intercept the native file to perform dispatch on the actual file type. template void populate_impl( @@ -120,6 +136,15 @@ void save(const Dataset& data, const File& file, const lib::UUID& uuid = lib::Ze return save(data, accessor, file, uuid); } +template +void save(const Dataset& data, std::ostream& os) { + auto accessor = DefaultReadAccessor(); + auto writer = svs::io::v1::StreamWriter(os); + for (size_t i = 0; i < data.size(); ++i) { + writer << accessor.get(data, i); + } +} + /// /// @brief Save the dataset as a "*vecs" file. /// @@ -169,6 +194,14 @@ lib::lazy_result_t load_dataset(const File& file, const F& la return load_impl(detail::to_native(file), default_accessor, lazy); } +template F> +lib::lazy_result_t +load_dataset(std::istream& is, const F& lazy, size_t num_vectors, size_t dims) { + auto data = lazy(num_vectors, dims); + populate(is, data); + return data; +} + // Return whether or not a file is directly loadable via file-extension. inline bool special_by_file_extension(std::string_view path) { return (path.ends_with("svs") || path.ends_with("vecs") || path.ends_with("bin")); diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index 0fcb31bbb..7e546977d 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -75,24 +75,42 @@ class GenericSerializer { } template - static lib::SaveTable save(const Data& data, const lib::SaveContext& ctx) { + static lib::SaveTable metadata(const Data& data) { using T = typename Data::element_type; - // UUID used to identify the file. - auto uuid = lib::UUID{}; - auto filename = ctx.generate_name("data"); - io::save(data, io::NativeFile(filename), uuid); - return lib::SaveTable( + auto table = lib::SaveTable( serialization_schema, save_version, { {"name", "uncompressed"}, - {"binary_file", lib::save(filename.filename())}, {"dims", lib::save(data.dimensions())}, {"num_vectors", lib::save(data.size())}, - {"uuid", uuid.str()}, {"eltype", lib::save(datatype_v)}, } ); + return table; + } + + template + static lib::SaveTable + metadata(const Data& data, const FileName_t& filename, const lib::UUID& uuid) { + auto table = metadata(data); + table.insert("binary_file", filename); + table.insert("uuid", uuid.str()); + return table; + } + + template + static lib::SaveTable save(const Data& data, const lib::SaveContext& ctx) { + // UUID used to identify the file. + auto uuid = lib::UUID{}; + auto filename = ctx.generate_name("data"); + io::save(data, io::NativeFile(filename), uuid); + return metadata(data, lib::save(filename.filename()), uuid); + } + + template + static void save(const Data& data, std::ostream& os) { + io::save(data, os); } template F> @@ -116,6 +134,25 @@ class GenericSerializer { } return io::load_dataset(binaryfile.value(), lazy); } + + template F> + static lib::lazy_result_t + load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) { + auto datatype = lib::load_at(table, "eltype"); + if (datatype != datatype_v) { + throw ANNEXCEPTION( + "Trying to load an uncompressed dataset with element types {} to a dataset " + "with element types {}.", + name(datatype), + name>() + ); + } + + size_t num_vectors = lib::load_at(table, "num_vectors"); + size_t dims = lib::load_at(table, "dims"); + + return io::load_dataset(is, lazy, num_vectors, dims); + } }; struct Matcher { @@ -405,6 +442,10 @@ class SimpleData { return GenericSerializer::save(*this, ctx); } + void save(std::ostream& os) const { return GenericSerializer::save(*this, os); } + + lib::SaveTable metadata() const { return GenericSerializer::metadata(*this); } + static bool check_load_compatibility(std::string_view schema, lib::Version version) { return GenericSerializer::check_compatibility(schema, version); } @@ -431,6 +472,20 @@ class SimpleData { ); } + static SimpleData load( + const lib::ContextFreeLoadTable& table, + std::istream& is, + const allocator_type& allocator = {} + ) + requires(!is_view) + { + return GenericSerializer::load( + table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { + return SimpleData(n_elements, n_dimensions, allocator); + }) + ); + } + /// /// @brief Try to automatically load the dataset. /// @@ -805,6 +860,10 @@ class SimpleData> { return GenericSerializer::save(*this, ctx); } + void save(std::ostream& os) const { return GenericSerializer::save(*this, os); } + + lib::SaveTable metadata() const { return GenericSerializer::metadata(*this); } + static bool check_load_compatibility(std::string_view schema, lib::Version version) { return GenericSerializer::check_compatibility(schema, version); } @@ -818,6 +877,18 @@ class SimpleData> { ); } + static SimpleData load( + const lib::ContextFreeLoadTable& table, + std::istream& is, + const Blocked& allocator = {} + ) { + return GenericSerializer::load( + table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) { + return SimpleData(n_elements, n_dimensions, allocator); + }) + ); + } + static SimpleData load(const std::filesystem::path& path, const Blocked& allocator = {}) { if (detail::is_likely_reload(path)) { diff --git a/include/svs/core/graph/graph.h b/include/svs/core/graph/graph.h index 48d3e8048..89e456e07 100644 --- a/include/svs/core/graph/graph.h +++ b/include/svs/core/graph/graph.h @@ -276,22 +276,36 @@ template class SimpleGrap ///// Saving static constexpr lib::Version save_version = lib::Version(0, 0, 0); static constexpr std::string_view serialization_schema = "default_graph"; - lib::SaveTable save(const lib::SaveContext& ctx) const { - auto uuid = lib::UUID{}; - auto filename = ctx.generate_name("graph"); - io::save(data_, io::NativeFile(filename), uuid); - return lib::SaveTable( + + lib::SaveTable metadata() const { + auto table = lib::SaveTable( serialization_schema, save_version, {{"name", "graph"}, - {"binary_file", lib::save(filename.filename())}, {"max_degree", lib::save(max_degree())}, {"num_vertices", lib::save(n_nodes())}, - {"uuid", lib::save(uuid.str())}, {"eltype", lib::save(datatype_v)}} ); + return table; } + template + lib::SaveTable metadata(const FileName& filename, const lib::UUID& uuid) const { + auto table = metadata(); + table.insert("binary_file", filename); + table.insert("uuid", uuid.str()); + return table; + } + + lib::SaveTable save(const lib::SaveContext& ctx) const { + auto uuid = lib::UUID{}; + auto filename = ctx.generate_name("graph"); + io::save(data_, io::NativeFile(filename), uuid); + return metadata(lib::save(filename.filename()), uuid); + } + + void save(std::ostream& os) const { io::save(data_, os); } + protected: template F, typename... Args> static lib::lazy_result_t @@ -317,6 +331,42 @@ template class SimpleGrap return lazy(data_type::load(binaryfile.value(), std::forward(args)...)); } + template F, typename... AllocArgs> + static lib::lazy_result_t load( + const lib::ContextFreeLoadTable& table, + const F& lazy, + std::istream& is, + AllocArgs&&... alloc_args + ) { + // Perform a sanity check on the element type. + // Make sure we're loading the correct kind. + auto eltype = lib::load_at(table, "eltype"); + if (eltype != datatype_v) { + throw ANNEXCEPTION( + "Trying to load a graph with adjacency list types {} to a graph with " + "adjacency list types {}.", + name(eltype), + name>() + ); + } + + size_t num_vertices = lib::load_at(table, "num_vertices"); + size_t max_degree = lib::load_at(table, "max_degree"); + + // Build a table compatible with GenericSerializer + auto data_table = toml::table{ + {lib::config_schema_key, data::GenericSerializer::serialization_schema}, + {lib::config_version_key, data::GenericSerializer::save_version.str()}, + {"eltype", lib::save(datatype_v)}, + {"num_vectors", lib::save(num_vertices)}, + {"dims", lib::save(max_degree + 1)}, + }; + + return lazy( + data_type::load(lib::ContextFreeLoadTable(data_table), is, alloc_args...) + ); + } + protected: data_type data_; Idx max_degree_; @@ -366,6 +416,15 @@ class SimpleGraph : public SimpleGraphBase(is, allocator); + } }; template @@ -406,6 +469,13 @@ class SimpleBlockedGraph return parent_type::load(table, lazy); } + static constexpr SimpleBlockedGraph + load(const lib::ContextFreeLoadTable& table, std::istream& is) { + auto lazy = + lib::Lazy([](data_type data) { return SimpleBlockedGraph(std::move(data)); }); + return parent_type::load(table, lazy, is); + } + static constexpr SimpleBlockedGraph load(const std::filesystem::path& path) { if (data::detail::is_likely_reload(path)) { return lib::load_from_disk(path); @@ -413,6 +483,10 @@ class SimpleBlockedGraph return SimpleBlockedGraph(data_type::load(path)); } } + + static constexpr SimpleBlockedGraph load(std::istream& is) { + return lib::load_from_stream(is); + } }; } // namespace svs::graphs diff --git a/include/svs/core/io/native.h b/include/svs/core/io/native.h index 0039128d3..0f6476686 100644 --- a/include/svs/core/io/native.h +++ b/include/svs/core/io/native.h @@ -344,28 +344,16 @@ struct Header { static_assert(sizeof(Header) == header_size, "Mismatch in Native io::v1 header sizes!"); static_assert(std::is_trivially_copyable_v
, "Header must be trivially copyable!"); -template class Writer { +// CRTP +template class Writer { public: - Writer( - const std::string& path, - size_t dimension, - lib::UUID uuid = lib::UUID(lib::ZeroInitializer()) - ) - : dimension_{dimension} - , uuid_{uuid} - , stream_{lib::open_write(path, std::ofstream::out | std::ofstream::binary)} { - // Write a temporary header. - stream_.seekp(0, std::ofstream::beg); - lib::write_binary(stream_, Header()); - } - - size_t dimensions() const { return dimension_; } void overwrite_num_vectors(size_t num_vectors) { vectors_written_ = num_vectors; } // TODO: Error checking to make sure the length is correct. template Writer& append(U&& v) { + std::ostream& os = static_cast(this)->stream(); for (const auto& i : v) { - lib::write_binary(stream_, lib::io_convert(i)); + lib::write_binary(os, lib::io_convert(i)); } ++vectors_written_; return *this; @@ -374,13 +362,37 @@ template class Writer { template requires std::is_same_v Writer& append(std::tuple&& v) { - lib::foreach (v, [&](const auto& x) { lib::write_binary(stream_, x); }); + std::ostream& os = static_cast(this)->stream(); + lib::foreach (v, [&](const auto& x) { lib::write_binary(os, x); }); ++vectors_written_; return *this; } template Writer& operator<<(U&& v) { return append(std::forward(v)); } + protected: + size_t vectors_written_ = 0; +}; + +template class FileWriter : public Writer> { + public: + FileWriter( + const std::string& path, + size_t dimension, + lib::UUID uuid = lib::UUID(lib::ZeroInitializer()) + ) + : dimension_{dimension} + , uuid_{uuid} + , stream_{lib::open_write(path, std::ofstream::out | std::ofstream::binary)} { + // Write a temporary header. + stream_.seekp(0, std::ofstream::beg); + lib::write_binary(stream_, Header()); + } + + std::ostream& stream() { return stream_; } + + size_t dimensions() const { return dimension_; } + void flush() { stream_.flush(); } void writeheader(bool resume = true) { @@ -388,7 +400,7 @@ template class Writer { // Write to the header the number of vectors actually written. stream_.seekp(0); assert(stream_.good()); - lib::write_binary(stream_, Header(vectors_written_, dimension_, uuid_)); + lib::write_binary(stream_, Header(this->vectors_written_, dimension_, uuid_)); if (resume) { stream_.seekp(position, std::ofstream::beg); } @@ -402,20 +414,30 @@ template class Writer { // // We delete the copy constructor and copy assignment operators because // `std::ofstream` isn't copyable anyways. - Writer(const Writer&) = delete; - Writer& operator=(const Writer&) = delete; - Writer(Writer&&) = delete; - Writer& operator=(Writer&&) = delete; + FileWriter(const FileWriter&) = delete; + FileWriter& operator=(const FileWriter&) = delete; + FileWriter(FileWriter&&) = delete; + FileWriter& operator=(FileWriter&&) = delete; // Write the header for the file. - ~Writer() noexcept { writeheader(); } + ~FileWriter() noexcept { writeheader(); } private: size_t dimension_; lib::UUID uuid_; std::ofstream stream_; size_t writes_this_vector_ = 0; - size_t vectors_written_ = 0; +}; + +template class StreamWriter : public Writer> { + public: + StreamWriter(std::ostream& os) + : stream_{os} {} + + std::ostream& stream() { return stream_; } + + private: + std::ostream& stream_; }; /// @@ -449,13 +471,13 @@ class NativeFile { } template - Writer writer( + FileWriter writer( lib::Type SVS_UNUSED(type), size_t dimension, lib::UUID uuid = lib::ZeroUUID ) const { - return Writer(path_, dimension, uuid); + return FileWriter(path_, dimension, uuid); } - Writer<> writer(size_t dimensions, lib::UUID uuid = lib::ZeroUUID) const { + FileWriter<> writer(size_t dimensions, lib::UUID uuid = lib::ZeroUUID) const { return writer(lib::Type(), dimensions, uuid); } @@ -715,7 +737,7 @@ class NativeFile { public: using compatible_file_types = lib::Types; - template using Writer = v1::Writer; + template using Writer = v1::FileWriter; explicit NativeFile(std::filesystem::path path) : path_{std::move(path)} {} diff --git a/include/svs/core/translation.h b/include/svs/core/translation.h index a3c4bca34..db65bf213 100644 --- a/include/svs/core/translation.h +++ b/include/svs/core/translation.h @@ -324,27 +324,36 @@ class IDTranslator { "external_to_internal_translation"; static constexpr lib::Version save_version = lib::Version(0, 0, 0); - lib::SaveTable save(const lib::SaveContext& ctx) const { - auto filename = ctx.generate_name("id_translation", "binary"); - // Save the translations to a file. - auto stream = lib::open_write(filename); - for (auto i = begin(), iend = end(); i != iend; ++i) { - // N.B.: Apparently `std::pair` of integers is not trivially copyable ... - lib::write_binary(stream, i->first); - lib::write_binary(stream, i->second); - } + lib::SaveTable save_table() const { return lib::SaveTable( serialization_schema, save_version, {{"kind", kind}, {"num_points", lib::save(size())}, {"external_id_type", lib::save(datatype_v)}, - {"internal_id_type", lib::save(datatype_v)}, - {"filename", lib::save(filename.filename())}} + {"internal_id_type", lib::save(datatype_v)}} ); } - static IDTranslator load(const lib::LoadTable& table) { + void save(std::ostream& os) const { + for (auto i = begin(), iend = end(); i != iend; ++i) { + // N.B.: Apparently `std::pair` of integers is not trivially copyable ... + lib::write_binary(os, i->first); + lib::write_binary(os, i->second); + } + } + + lib::SaveTable save(const lib::SaveContext& ctx) const { + auto filename = ctx.generate_name("id_translation", "binary"); + // Save the translations to a file. + auto os = lib::open_write(filename); + save(os); + auto table = save_table(); + table.insert("filename", lib::save(filename.filename())); + return table; + } + + static void validate(const lib::ContextFreeLoadTable& table) { if (kind != lib::load_at(table, "kind")) { throw ANNEXCEPTION("Mismatched kind!"); } @@ -357,21 +366,31 @@ class IDTranslator { if (internal_id_name != lib::load_at(table, "internal_id_type")) { throw ANNEXCEPTION("Mismatched internal id types!"); } + } - // Now that we've more-or-less validated the metadata, time to start loading - // the points. + static IDTranslator load(const lib::ContextFreeLoadTable& table, std::istream& is) { + IDTranslator::validate(table); auto num_points = lib::load_at(table, "num_points"); + auto translator = IDTranslator{}; - auto resolved = table.resolve_at("filename"); - auto stream = lib::open_read(resolved); for (size_t i = 0; i < num_points; ++i) { - auto external_id = lib::read_binary(stream); - auto internal_id = lib::read_binary(stream); + auto external_id = lib::read_binary(is); + auto internal_id = lib::read_binary(is); translator.insert_translation(external_id, internal_id); } return translator; } + static IDTranslator load(const lib::LoadTable& table) { + IDTranslator::validate(table); + + // Now that we've more-or-less validated the metadata, time to start loading + // the points. + auto resolved = table.resolve_at("filename"); + auto is = lib::open_read(resolved); + return IDTranslator::load(table, is); + } + private: template void check( diff --git a/include/svs/index/flat/dynamic_flat.h b/include/svs/index/flat/dynamic_flat.h index 868054ba1..71d72e10e 100644 --- a/include/svs/index/flat/dynamic_flat.h +++ b/include/svs/index/flat/dynamic_flat.h @@ -36,6 +36,7 @@ #include "svs/lib/invoke.h" #include "svs/lib/misc.h" #include "svs/lib/preprocessor.h" +#include "svs/lib/stream.h" #include "svs/lib/threads.h" namespace svs::index::flat { @@ -403,6 +404,26 @@ template class DynamicFlatIndex { // Save the dataset in the separate data directory lib::save_to_disk(data_, data_directory); } + + void save(std::ostream& os) { + compact(); + + lib::begin_serialization(os); + // Save data structures and translation to config directory + lib::SaveTable save_table = lib::SaveTable( + "dynamic_flat_config", + save_version, + { + {"name", name()}, + {"translation", lib::detail::exit_hook(translator_.save_table())}, + } + ); + lib::save_to_stream(save_table, os); + translator_.save(os); + + lib::save_to_stream(data_, os); + } + constexpr std::string_view name() const { return "dynamic flat index"; } ///// Thread Pool Management @@ -767,4 +788,44 @@ auto auto_dynamic_assemble( ); } +template +auto auto_dynamic_assemble( + std::istream& is, + LazyDataLoader&& data_loader, + Distance distance, + ThreadPoolProto threadpool_proto, + // Set this to `true` to use the identity map for ID translation. + // This allows us to read files generated by the static index construction routines + // to easily benchmark the static versus dynamic implementation. + // + // This is an internal API and should not be considered officially supported nor stable. + bool SVS_UNUSED(debug_load_from_static) = false, + svs::logging::logger_ptr logger = svs::logging::get() +) { + auto table = lib::detail::read_metadata(is); + auto translation = + table.template cast().at("translation").template cast(); + IDTranslator translator = IDTranslator::load(translation, is); + + auto data = data_loader(); + + // Validate the translator + auto translator_size = translator.size(); + auto datasize = data.size(); + if (translator_size != datasize) { + throw ANNEXCEPTION( + "Translator has {} IDs but should have {}", translator_size, datasize + ); + } + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + return DynamicFlatIndex( + std::move(data), + std::move(translator), + std::move(distance), + std::move(threadpool), + std::move(logger) + ); +} + } // namespace svs::index::flat diff --git a/include/svs/index/flat/flat.h b/include/svs/index/flat/flat.h index 187fc7440..925fe68a5 100644 --- a/include/svs/index/flat/flat.h +++ b/include/svs/index/flat/flat.h @@ -522,6 +522,11 @@ class FlatIndex { void save(const std::filesystem::path& data_directory) const { lib::save_to_disk(data_, data_directory); } + + void save(std::ostream& os) const { + lib::begin_serialization(os); + lib::save_to_stream(data_, os); + } }; /// diff --git a/include/svs/index/inverted/clustering.h b/include/svs/index/inverted/clustering.h index db82e1593..de5585c29 100644 --- a/include/svs/index/inverted/clustering.h +++ b/include/svs/index/inverted/clustering.h @@ -571,6 +571,36 @@ template class Clustering { // Saving and Loading. static constexpr lib::Version save_version{0, 0, 0}; static constexpr std::string_view serialization_schema = "clustering"; + + lib::SaveTable metadata() const { + return lib::SaveTable( + serialization_schema, + save_version, + {{"integer_type", lib::save(datatype_v)}, + {"num_clusters", lib::save(size())}} + ); + } + + void save(std::ostream& os) const { + for (const auto& [id, cluster] : *this) { + cluster.serialize(os); + } + } + + static Clustering + load(const lib::ContextFreeLoadTable& table, std::istream& stream) { + auto saved_integer_type = lib::load_at(table, "integer_type"); + if (saved_integer_type != datatype_v) { + throw ANNEXCEPTION("Clustering was saved using {} but we're trying to reload it using {}!", saved_integer_type, datatype_v); + } + auto num_clusters = lib::load_at(table, "num_clusters"); + auto clustering = Clustering(); + for (size_t i = 0; i < num_clusters; ++i) { + clustering.insert(Cluster::deserialize(stream)); + } + return clustering; + } + lib::SaveTable save(const lib::SaveContext& ctx) const { // Serialize all clusters into an auxiliary file. auto fullpath = ctx.generate_name("clustering", "bin"); @@ -582,48 +612,28 @@ template class Clustering { } } - return lib::SaveTable( - serialization_schema, - save_version, - {{"filepath", lib::save(fullpath.filename())}, - SVS_LIST_SAVE(filesize), - {"integer_type", lib::save(datatype_v)}, - {"num_clusters", lib::save(size())}} - ); + auto table = metadata(); + table.insert("filepath", lib::save(fullpath.filename())); + table.insert("filesize", lib::save(filesize)); + return table; + + return table; } static Clustering load(const lib::LoadTable& table) { - // Ensure we have the correct integer type when decoding. - auto saved_integer_type = lib::load_at(table, "integer_type"); - if (saved_integer_type != datatype_v) { - auto type = datatype_v; + auto expected_filesize = lib::load_at(table, "filesize"); + auto file = table.resolve_at("filepath"); + size_t actual_filesize = std::filesystem::file_size(file); + if (actual_filesize != expected_filesize) { throw ANNEXCEPTION( - "Clustering was saved using {} but we're trying to reload it using {}!", - saved_integer_type, - type + "Expected cluster file size to be {}. Instead, it is {}!", + actual_filesize, + expected_filesize ); } - auto num_clusters = lib::load_at(table, "num_clusters"); - auto expected_filesize = lib::load_at(table, "filesize"); - auto clustering = Clustering(); - { - auto file = table.resolve_at("filepath"); - size_t actual_filesize = std::filesystem::file_size(file); - if (actual_filesize != expected_filesize) { - throw ANNEXCEPTION( - "Expected cluster file size to be {}. Instead, it is {}!", - actual_filesize, - expected_filesize - ); - } - - auto io = lib::open_read(file); - for (size_t i = 0; i < num_clusters; ++i) { - clustering.insert(Cluster::deserialize(io)); - } - } - return clustering; + auto io = lib::open_read(file); + return load(table, io); } private: diff --git a/include/svs/index/inverted/memory_based.h b/include/svs/index/inverted/memory_based.h index 3d3fc24c5..8b2e8c8e1 100644 --- a/include/svs/index/inverted/memory_based.h +++ b/include/svs/index/inverted/memory_based.h @@ -497,6 +497,8 @@ template class InvertedIndex { index_.save(index_config, graph, data); } + void save_primary_index(std::ostream& os) const { index_.save(os); } + ///// Accessors /// @brief Getter method for logger svs::logging::logger_ptr get_logger() const { return logger_; } @@ -655,4 +657,46 @@ auto assemble_from_clustering( ); } +template < + typename DataProto, + typename Distance, + StorageStrategy Strategy, + typename ThreadPoolProto> +auto assemble_from_clustering( + std::istream& is, + DataProto data_proto, + Distance distance, + Strategy strategy, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() +) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + auto original = svs::detail::dispatch_load(std::move(data_proto), threadpool); + auto clustering = lib::load_from_stream>(is); + auto ids = clustering.sorted_centroids(); + + // skip magic + svs::lib::detail::Deserializer::build(is); + auto index = index::vamana::auto_assemble( + is, + lib::Lazy([&]() { return GraphLoader::return_type::load(is); }), + lib::Lazy([&]() { + using T = typename std::decay_t::element_type; + constexpr size_t Ext = std::decay_t::extent; + return lib::load_from_stream>(is); + }), + distance, + 1, + logger + ); + + return InvertedIndex( + std::move(index), + strategy(original, clustering, HugepageAllocator()), + std::move(ids), + std::move(threadpool), + std::move(logger) + ); +} + } // namespace svs::index::inverted diff --git a/include/svs/index/ivf/clustering.h b/include/svs/index/ivf/clustering.h index cc9bb09b8..65828525b 100644 --- a/include/svs/index/ivf/clustering.h +++ b/include/svs/index/ivf/clustering.h @@ -410,6 +410,66 @@ class DenseClusteredDataset { static constexpr lib::Version save_version{0, 0, 0}; static constexpr std::string_view serialization_schema = "ivf_dense_clustered_dataset"; + lib::SaveTable metadata() const { + auto num_clusters = size(); + auto dims = dimensions(); + + return lib::SaveTable( + serialization_schema, + save_version, + {{"num_clusters", lib::save(num_clusters)}, + {"dimensions", lib::save(dims)}, + {"prefetch_offset", lib::save(prefetch_offset_)}, + {"index_type", lib::save(datatype_v)}} + ); + } + + void save(std::ostream& os) const { + auto num_clusters = size(); + + // Compute cluster sizes and ID offsets + std::vector cluster_sizes, ids_offsets; + calculate_cluster_sizes_and_offsets(&cluster_sizes, &ids_offsets); + + lib::write_binary(os, cluster_sizes); + for (size_t i = 0; i < num_clusters; ++i) { + lib::save_to_stream(clusters_[i].data_, os); + if (!clusters_[i].ids_.empty()) { + lib::write_binary(os, clusters_[i].ids_); + } + } + } + + template + static DenseClusteredDataset load( + const lib::ContextFreeLoadTable& table, + std::istream& is, + const Allocator& allocator = Allocator{} + ) { + auto num_clusters = lib::load_at(table, "num_clusters"); + [[maybe_unused]] auto dims = lib::load_at(table, "dimensions"); + auto prefetch_offset = lib::load_at(table, "prefetch_offset"); + check_saved_index_type(table); + + DenseClusteredDataset result; + result.prefetch_offset_ = prefetch_offset; + result.clusters_.reserve(num_clusters); + + std::vector cluster_sizes(num_clusters); + lib::read_binary(is, cluster_sizes); + + for (size_t i = 0; i < num_clusters; ++i) { + auto cluster_data = lib::load_from_stream(is, allocator); + size_t cluster_size = cluster_sizes[i]; + std::vector cluster_ids(cluster_size); + if (cluster_size > 0) { + lib::read_binary(is, cluster_ids); + } + result.clusters_.emplace_back(std::move(cluster_data), std::move(cluster_ids)); + } + return result; + } + /// @brief Save the DenseClusteredDataset to disk. /// /// Saves all cluster data using the existing save mechanisms for each data type @@ -427,16 +487,8 @@ class DenseClusteredDataset { auto dims = dimensions(); // Compute cluster sizes and ID offsets - std::vector cluster_sizes(num_clusters); - std::vector ids_offsets(num_clusters + 1); - - size_t ids_offset = 0; - for (size_t i = 0; i < num_clusters; ++i) { - cluster_sizes[i] = clusters_[i].size(); - ids_offsets[i] = ids_offset; - ids_offset += cluster_sizes[i] * sizeof(I); - } - ids_offsets[num_clusters] = ids_offset; + std::vector cluster_sizes, ids_offsets; + calculate_cluster_sizes_and_offsets(&cluster_sizes, &ids_offsets); // Create a temporary directory for cluster data lib::UniqueTempDirectory tempdir{"svs_ivf_clusters_save"}; @@ -495,7 +547,7 @@ class DenseClusteredDataset { {"ids_file", lib::save(std::string("ids.bin"))}, {"cluster_sizes_file", lib::save(std::string("cluster_sizes.bin"))}, {"ids_offsets_file", lib::save(std::string("ids_offsets.bin"))}, - {"total_ids_bytes", lib::save(ids_offset)}} + {"total_ids_bytes", lib::save(ids_offsets[num_clusters])}} ); } @@ -528,17 +580,7 @@ class DenseClusteredDataset { // since each cluster's data type determines its own dimensions [[maybe_unused]] auto dims = lib::load_at(table, "dimensions"); auto prefetch_offset = lib::load_at(table, "prefetch_offset"); - - // Verify index type matches - auto saved_index_type = lib::load_at(table, "index_type"); - if (saved_index_type != datatype_v) { - throw ANNEXCEPTION( - "DenseClusteredDataset was saved using index type {} but we're trying to " - "reload it using {}!", - saved_index_type, - datatype_v - ); - } + check_saved_index_type(table); auto base_dir = table.context().get_directory(); @@ -603,6 +645,35 @@ class DenseClusteredDataset { } private: + void calculate_cluster_sizes_and_offsets( + std::vector* cluster_sizes, std::vector* ids_offsets + ) const { + auto num_clusters = size(); + cluster_sizes->resize(num_clusters); + ids_offsets->resize(num_clusters + 1); + + size_t ids_offset = 0; + for (size_t i = 0; i < num_clusters; ++i) { + (*cluster_sizes)[i] = clusters_[i].size(); + (*ids_offsets)[i] = ids_offset; + ids_offset += clusters_[i].size() * sizeof(I); + } + (*ids_offsets)[num_clusters] = ids_offset; + } + + static void check_saved_index_type(const lib::ContextFreeLoadTable& table) { + // Verify index type matches + auto saved_index_type = lib::load_at(table, "index_type"); + if (saved_index_type != datatype_v) { + throw ANNEXCEPTION( + "DenseClusteredDataset was saved using index type {} but we're trying to " + "reload it using {}!", + saved_index_type, + datatype_v + ); + } + } + std::vector> clusters_; size_t prefetch_offset_ = 8; }; diff --git a/include/svs/index/ivf/dynamic_ivf.h b/include/svs/index/ivf/dynamic_ivf.h index 92fc86ed3..40ad32c3f 100644 --- a/include/svs/index/ivf/dynamic_ivf.h +++ b/include/svs/index/ivf/dynamic_ivf.h @@ -761,6 +761,10 @@ class DynamicIVFIndex { lib::save_to_disk(clusters_, clusters_dir); } + void save(std::ostream& SVS_UNUSED(os)) const { + throw ANNEXCEPTION("Placeholder; not implemented!"); + } + private: ///// Helper Methods ///// diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index 263b18686..00f45735a 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -465,6 +465,7 @@ class IVFIndex { /// Each directory may be created as a side-effect of this method call provided that /// the parent directory exists. /// + void save( const std::filesystem::path& config_directory, const std::filesystem::path& data_directory @@ -504,6 +505,27 @@ class IVFIndex { lib::save_to_disk(cluster_, clusters_dir); } + void save(std::ostream& os) const { + lib::begin_serialization(os); + + // Save config + auto data_type_config = DataTypeTraits::get_config(); + data_type_config.centroid_type = datatype_v; + auto save_table = lib::SaveTable( + serialization_schema, + save_version, + {{"name", lib::save(name())}, {"num_clusters", lib::save(num_clusters())}} + ); + save_table.insert("data_type_config", lib::save(data_type_config)); + lib::save_to_stream(save_table, os); + + // Save centroids + lib::save_to_stream(centroids_, os); + + // Save clusters + lib::save_to_stream(cluster_, os); + } + private: ///// Core Components ///// Centroids centroids_; @@ -964,4 +986,36 @@ auto load_ivf_index( return index; } +template < + typename CentroidType, + typename DataType, + typename Distance, + typename ThreadpoolProto> +auto load_ivf_index( + std::istream& is, + Distance distance, + ThreadpoolProto threadpool_proto, + const size_t intra_query_thread_count = 1, + svs::logging::logger_ptr logger = svs::logging::get() +) { + using centroids_type = data::SimpleData; + using data_type = typename DataType::lib_alloc_data_type; + using cluster_type = DenseClusteredDataset; + + lib::detail::read_metadata(is); + auto centroids = lib::load_from_stream(is); + auto clusters = lib::load_from_stream(is); + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + + return IVFIndex( + std::move(centroids), + std::move(clusters), + std::move(distance), + std::move(threadpool), + intra_query_thread_count, + logger + ); +} + } // namespace svs::index::ivf diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 169be1995..a5b88b4db 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -988,6 +988,18 @@ class MutableVamanaIndex { ///// Saving + VamanaIndexParameters parameters() const { + return { + entry_point_.front(), + {alpha_, + graph_.max_degree(), + get_construction_window_size(), + get_max_candidates(), + prune_to_, + get_full_search_history()}, + get_search_parameters()}; + } + static constexpr lib::Version save_version = lib::Version(0, 0, 0); void save( const std::filesystem::path& config_directory, @@ -1003,22 +1015,12 @@ class MutableVamanaIndex { lib::save_to_disk( lib::SaveOverride([&](const lib::SaveContext& ctx) { // Save the construction parameters. - auto parameters = VamanaIndexParameters{ - entry_point_.front(), - {alpha_, - graph_.max_degree(), - get_construction_window_size(), - get_max_candidates(), - prune_to_, - get_full_search_history()}, - get_search_parameters()}; - return lib::SaveTable( "vamana_dynamic_auxiliary_parameters", save_version, { {"name", lib::save(name())}, - {"parameters", lib::save(parameters, ctx)}, + {"parameters", lib::save(parameters(), ctx)}, {"translation", lib::save(translator_, ctx)}, } ); @@ -1032,6 +1034,31 @@ class MutableVamanaIndex { lib::save_to_disk(graph_, graph_directory); } + void save(std::ostream& os) { + // Post-consolidation, all entries should be "valid". + // Therefore, we don't need to save the slot metadata. + consolidate(); + compact(); + + lib::begin_serialization(os); + auto save_table = lib::SaveTable( + "vamana_dynamic_auxiliary_parameters", + save_version, + { + {"name", lib::save(name())}, + {"parameters", lib::save(parameters())}, + {"translation", lib::detail::exit_hook(translator_.save_table())}, + } + ); + lib::save_to_stream(save_table, os); + translator_.save(os); + + // Save the dataset. + lib::save_to_stream(data_, os); + // Save the graph. + lib::save_to_stream(graph_, os); + } + ///// ///// Calibrate ///// @@ -1429,4 +1456,60 @@ auto auto_dynamic_assemble( std::move(logger)}; } +template < + typename LazyGraphLoader, + typename LazyDataLoader, + typename Distance, + typename ThreadPoolProto> +auto auto_dynamic_assemble( + std::istream& is, + LazyGraphLoader graph_loader, + LazyDataLoader data_loader, + Distance distance, + ThreadPoolProto threadpool_proto, + bool SVS_UNUSED(debug_load_from_static) = false, + svs::logging::logger_ptr logger = svs::logging::get() +) { + // Read the combined TOML (parameters + translation) + // and the translator binary data. + auto table = lib::detail::read_metadata(is); + + auto parameters = lib::load( + table.template cast().at("parameters").template cast() + ); + + auto translation = + table.template cast().at("translation").template cast(); + + auto translator = IDTranslator::load(translation, is); + + auto data = data_loader(); + auto graph = graph_loader(); + + auto datasize = data.size(); + auto graphsize = graph.n_nodes(); + if (datasize != graphsize) { + throw ANNEXCEPTION( + "Reloaded data has {} nodes while the graph has {} nodes!", datasize, graphsize + ); + } + + auto translator_size = translator.size(); + if (translator_size != datasize) { + throw ANNEXCEPTION( + "Translator has {} IDs but should have {}", translator_size, datasize + ); + } + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + return MutableVamanaIndex{ + parameters, + std::move(data), + std::move(graph), + std::move(distance), + std::move(translator), + std::move(threadpool), + std::move(logger)}; +} + } // namespace svs::index::vamana diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index b7c136645..70a921353 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -85,8 +85,7 @@ struct VamanaIndexParameters { static constexpr lib::Version save_version = lib::Version(0, 0, 3); static constexpr std::string_view serialization_schema = "vamana_index_parameters"; - // Save and Reload. - lib::SaveTable save() const { + lib::SaveTable metadata() const { return lib::SaveTable( serialization_schema, save_version, @@ -97,6 +96,9 @@ struct VamanaIndexParameters { ); } + // Save to Table and Reload. + lib::SaveTable save() const { return metadata(); } + static bool check_load_compatibility(std::string_view schema, lib::Version version) { return schema == serialization_schema && version <= save_version; } @@ -814,6 +816,20 @@ class VamanaIndex { lib::save_to_disk(graph_, graph_directory); } + void save(std::ostream& os) const { + // Construct and save runtime parameters. + auto parameters = VamanaIndexParameters{ + entry_point_.front(), build_parameters_, get_search_parameters()}; + + lib::begin_serialization(os); + // Config + lib::save_to_stream(parameters, os); + // Data + lib::save_to_stream(data_, os); + // // Graph + lib::save_to_stream(graph_, os); + } + ///// Calibration // Return the maximum degree of the graph. @@ -1006,6 +1022,37 @@ auto auto_assemble( return index; } +template < + typename LazyGraphLoader, + typename LazyDataLoader, + typename Distance, + typename ThreadPoolProto> +auto auto_assemble( + std::istream& is, + LazyGraphLoader graph_loader, + LazyDataLoader data_loader, + Distance distance, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() +) { + VamanaIndexParameters config = lib::load_from_stream(is); + auto data = data_loader(); + auto graph = graph_loader(); + + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + // Extract the index type of the provided graph. + using I = typename decltype(graph)::index_type; + auto index = VamanaIndex{ + std::move(graph), + std::move(data), + I{}, + std::move(distance), + std::move(threadpool), + std::move(logger)}; + index.apply(config); + return index; +} + /// @brief Verify parameters and set defaults if needed template void verify_and_set_default_index_parameters( diff --git a/include/svs/lib/archiver.h b/include/svs/lib/archiver.h new file mode 100644 index 000000000..ce3ca438f --- /dev/null +++ b/include/svs/lib/archiver.h @@ -0,0 +1,93 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// svs +#include "svs/lib/exception.h" + +// stl +#include +#include +#include +#include +#include + +namespace svs::lib { + +// CRTP +template struct Archiver { + using size_type = uint64_t; + + // TODO: Define CACHELINE_BYTES in a common place + // rather than duplicating it here and in prefetch.h + static constexpr auto CACHELINE_BYTES = 64; + + static size_type write_size(std::ostream& os, size_type size) { + os.write(reinterpret_cast(&size), sizeof(size)); + if (!os) { + throw ANNEXCEPTION("Error writing to stream!"); + } + return sizeof(size); + } + + static size_type read_size(std::istream& is, size_type& size) { + is.read(reinterpret_cast(&size), sizeof(size)); + if (!is) { + throw ANNEXCEPTION("Error reading from stream!"); + } + return sizeof(size); + } + + static size_type write_name(std::ostream& os, const std::string& name) { + auto bytes = write_size(os, name.size()); + os.write(name.data(), name.size()); + if (!os) { + throw ANNEXCEPTION("Error writing to stream!"); + } + return bytes + name.size(); + } + + static size_type read_name(std::istream& is, std::string& name) { + size_type size = 0; + auto bytes = read_size(is, size); + name.resize(size); + is.read(name.data(), size); + if (!is) { + throw ANNEXCEPTION("Error reading from stream!"); + } + return bytes + size; + } + + static void read_from_istream(std::istream& in, std::ostream& out, size_t data_size) { + // Copy the data in chunks. + constexpr size_t buffer_size = 1 << 13; // 8KB buffer + alignas(CACHELINE_BYTES) char buffer[buffer_size]; + + size_t bytes_remaining = data_size; + while (bytes_remaining > 0) { + size_t to_read = std::min(buffer_size, bytes_remaining); + in.read(buffer, to_read); + if (!in) { + throw ANNEXCEPTION("Error reading from stream!"); + } + out.write(buffer, to_read); + bytes_remaining -= to_read; + } + } +}; + +} // namespace svs::lib diff --git a/include/svs/lib/file.h b/include/svs/lib/file.h index 937bc1502..7d949ae4d 100644 --- a/include/svs/lib/file.h +++ b/include/svs/lib/file.h @@ -17,6 +17,7 @@ #pragma once // svs +#include "svs/lib/archiver.h" #include "svs/lib/exception.h" #include "svs/lib/uuid.h" @@ -151,50 +152,9 @@ struct UniqueTempDirectory { // Uses a simple custom binary format. // Not meant to be super efficient, just a simple way to serialize a directory // structure to a stream. -struct DirectoryArchiver { - using size_type = uint64_t; - - // TODO: Define CACHELINE_BYTES in a common place - // rather than duplicating it here and in prefetch.h - static constexpr auto CACHELINE_BYTES = 64; +struct DirectoryArchiver : Archiver { static constexpr size_type magic_number = 0x5e2d58d9f3b4a6c1; - static size_type write_size(std::ostream& os, size_type size) { - os.write(reinterpret_cast(&size), sizeof(size)); - if (!os) { - throw ANNEXCEPTION("Error writing to stream!"); - } - return sizeof(size); - } - - static size_type read_size(std::istream& is, size_type& size) { - is.read(reinterpret_cast(&size), sizeof(size)); - if (!is) { - throw ANNEXCEPTION("Error reading from stream!"); - } - return sizeof(size); - } - - static size_type write_name(std::ostream& os, const std::string& name) { - auto bytes = write_size(os, name.size()); - os.write(name.data(), name.size()); - if (!os) { - throw ANNEXCEPTION("Error writing to stream!"); - } - return bytes + name.size(); - } - - static size_type read_name(std::istream& is, std::string& name) { - size_type size = 0; - auto bytes = read_size(is, size); - name.resize(size); - is.read(name.data(), size); - if (!is) { - throw ANNEXCEPTION("Error reading from stream!"); - } - return bytes + size; - } - static size_type write_file( std::ostream& stream, const std::filesystem::path& path, @@ -262,22 +222,9 @@ struct DirectoryArchiver { throw ANNEXCEPTION("Error opening file {} for writing!", path); } - // Copy the data in chunks. - constexpr size_t buffer_size = 1 << 13; // 8KB buffer - alignas(CACHELINE_BYTES) char buffer[buffer_size]; - - size_t bytes_remaining = filesize; - while (bytes_remaining > 0) { - size_t to_read = std::min(buffer_size, bytes_remaining); - stream.read(buffer, to_read); - if (!stream) { - throw ANNEXCEPTION("Error reading from stream!"); - } - out.write(buffer, to_read); - if (!out) { - throw ANNEXCEPTION("Error writing to file {}!", path); - } - bytes_remaining -= to_read; + read_from_istream(stream, out, filesize); + if (!out) { + throw ANNEXCEPTION("Error writing to file {}!", path); } return header_bytes + filesize; @@ -310,18 +257,16 @@ struct DirectoryArchiver { return total_bytes; } - static size_t unpack(std::istream& stream, const std::filesystem::path& root) { + static size_t + unpack(std::istream& stream, const std::filesystem::path& root, size_type magic) { namespace fs = std::filesystem; - // Read and verify the magic number. - size_type magic = 0; - auto total_bytes = read_size(stream, magic); if (magic != magic_number) { throw ANNEXCEPTION("Invalid magic number in directory unpacking!"); } size_type num_files = 0; - total_bytes += read_size(stream, num_files); + auto total_bytes = read_size(stream, num_files); if (!stream) { throw ANNEXCEPTION("Error reading from stream!"); } @@ -339,5 +284,14 @@ struct DirectoryArchiver { return total_bytes; } + + static size_t unpack(std::istream& stream, const std::filesystem::path& root) { + // Read and verify the magic number. + size_type magic = 0; + auto total_bytes = read_size(stream, magic); + + total_bytes += unpack(stream, root, magic); + return total_bytes; + } }; } // namespace svs::lib diff --git a/include/svs/lib/saveload/load.h b/include/svs/lib/saveload/load.h index 767e02afa..06f718d22 100644 --- a/include/svs/lib/saveload/load.h +++ b/include/svs/lib/saveload/load.h @@ -22,6 +22,11 @@ // stl #include +#include "svs/core/io/native.h" +#include "svs/lib/file.h" +#include "svs/lib/readwrite.h" +#include "svs/lib/stream.h" + namespace svs::lib { /// @@ -828,6 +833,35 @@ inline SerializedObject begin_deserialization(const std::filesystem::path& fullp std::move(table), lib::LoadContext{fullpath.parent_path(), version}}; } +class Deserializer { + lib::StreamArchiver::size_type magic_; + + explicit Deserializer(lib::StreamArchiver::size_type magic) + : magic_(magic) {} + + public: + static Deserializer build(std::istream& stream) { + lib::StreamArchiver::size_type magic = 0; + lib::StreamArchiver::read_size(stream, magic); + + return Deserializer(magic); + } + + auto magic() const { return magic_; } + + bool is_native() const { return magic_ == lib::StreamArchiver::magic_number; } +}; + +inline ContextFreeSerializedObject read_metadata(std::istream& stream) { + if (!stream) { + throw ANNEXCEPTION("Error reading from stream!"); + } + + auto table = lib::StreamArchiver::read_table(stream); + + return ContextFreeSerializedObject{std::move(table)}; +} + } // namespace detail inline SerializedObject begin_deserialization(const std::filesystem::path& path) { @@ -877,6 +911,31 @@ T load_from_disk(const std::filesystem::path& path, Args&&... args) { return lib::load_from_disk(Loader(), path, SVS_FWD(args)...); } +template +concept LoadableFromTable = requires(const T& x +) { T::load(std::declval(), std::declval()...); }; + +///// load_from_stream +template +T load_from_stream(const Loader& loader, std::istream& stream, Args&&... args) + requires LoadableFromTable +{ + // Object is loadable from it's toml::table + return lib::load(loader, detail::read_metadata(stream), SVS_FWD(args)...); +} + +template +T load_from_stream(const Loader& loader, std::istream& stream, Args&&... args) + requires(!LoadableFromTable) +{ + return lib::load(loader, detail::read_metadata(stream), stream, SVS_FWD(args)...); +} + +template +T load_from_stream(std::istream& stream, Args&&... args) { + return lib::load_from_stream(Loader(), stream, SVS_FWD(args)...); +} + ///// load_from_file template diff --git a/include/svs/lib/saveload/save.h b/include/svs/lib/saveload/save.h index 60f556e77..edd627b75 100644 --- a/include/svs/lib/saveload/save.h +++ b/include/svs/lib/saveload/save.h @@ -22,6 +22,7 @@ // svs #include "svs/lib/file.h" #include "svs/lib/readwrite.h" +#include "svs/lib/stream.h" #include "svs/lib/version.h" // stl @@ -319,6 +320,17 @@ void save_node_to_file( auto file = svs::lib::open_write(path, std::ios_base::out); file << top_table << "\n"; } + +template +void save_node_to_stream( + Nodelike&& node, std::ostream& os, const lib::Version& version = CURRENT_SAVE_VERSION +) { + auto top_table = toml::table( + {{config_version_key, version.str()}, {config_object_key, SVS_FWD(exit_hook(node))}} + ); + + StreamArchiver::write_table(os, top_table); +} } // namespace detail /// @@ -365,4 +377,23 @@ template void save_to_file(const T& x, const std::filesystem::path& detail::save_node_to_file(lib::save(x), path); } +inline void begin_serialization(std::ostream& os) { + lib::StreamArchiver::write_size(os, lib::StreamArchiver::magic_number); +} + +inline void save_to_stream(const lib::SaveTable& x, std::ostream& os) { + detail::save_node_to_stream(x, os); +} + +template void save_to_stream(const T& x, std::ostream& os) { + if constexpr (requires { x.metadata(); }) { + save_to_stream(x.metadata(), os); + if constexpr (requires { x.save(os); }) { + x.save(os); + } + } else { + static_assert(sizeof(T) == 0, "Type not stream-serializable"); + } +} + } // namespace svs::lib diff --git a/include/svs/lib/stream.h b/include/svs/lib/stream.h new file mode 100644 index 000000000..9b9325335 --- /dev/null +++ b/include/svs/lib/stream.h @@ -0,0 +1,73 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// saveload +#include "svs/lib/saveload/core.h" + +// svs +#include "svs/lib/archiver.h" +#include "svs/lib/exception.h" + +// stl +#include +#include + +namespace svs::lib::detail { +template auto get_buffer_size(T& ss) { + if constexpr (requires { ss.rdbuf()->view(); }) { + return ss.rdbuf()->view().size(); + } else { + return ss.str().size(); + } +} +} // namespace svs::lib::detail + +namespace svs::lib { + +struct StreamArchiver : Archiver { + // SVS_STRM + static constexpr size_type magic_number = 0x5356535f5354524d; + + static auto read_table(std::istream& is) { + std::uint64_t tablesize = 0; + read_size(is, tablesize); + + std::stringstream ss; + read_from_istream(is, ss, tablesize); + + return toml::parse(ss); + } + + static void write_table(std::ostream& os, const toml::table& table) { + std::stringstream ss; + ss << table << "\n"; + + // The best way to get the table size is a c++20 feature: + // ss.rdbuf()->view().size(), + // but Apple's Clang 15 doesn't support std::stringbuf::view() + lib::StreamArchiver::size_type tablesize = detail::get_buffer_size(ss); + + lib::StreamArchiver::write_size(os, tablesize); + os << ss.rdbuf(); + if (!os) { + throw ANNEXCEPTION("Error writing to stream!"); + } + } +}; + +} // namespace svs::lib diff --git a/include/svs/orchestrators/dynamic_flat.h b/include/svs/orchestrators/dynamic_flat.h index e06efb451..fa7d35855 100644 --- a/include/svs/orchestrators/dynamic_flat.h +++ b/include/svs/orchestrators/dynamic_flat.h @@ -123,13 +123,7 @@ class DynamicFlatImpl // Stream-based save implementation void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_dynflat_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current DynamicFlat backend doesn't support saving!"); } @@ -282,29 +276,47 @@ class DynamicFlat : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_dynflat_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); - - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION( - "Invalid Dynamic Flat index archive: missing config directory!" + auto deserializer = svs::lib::detail::Deserializer::build(stream); + if (deserializer.is_native()) { + return DynamicFlat( + AssembleTag(), + manager::as_typelist(), + index::flat::auto_dynamic_assemble( + stream, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream(stream, SVS_FWD(data_args)...); + }, + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ) ); - } + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_dynflat_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); + + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION( + "Invalid Dynamic Flat index archive: missing config directory!" + ); + } + + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION( + "Invalid Dynamic Flat index archive: missing data directory!" + ); + } - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid Dynamic Flat index archive: missing data directory!" + return assemble( + config_path, + lib::load_from_disk(data_path, SVS_FWD(data_args)...), + distance, + threads::as_threadpool(std::move(threadpool_proto)) ); } - - return assemble( - config_path, - lib::load_from_disk(data_path, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)) - ); } ///// Distance diff --git a/include/svs/orchestrators/dynamic_vamana.h b/include/svs/orchestrators/dynamic_vamana.h index da1af7b19..045077387 100644 --- a/include/svs/orchestrators/dynamic_vamana.h +++ b/include/svs/orchestrators/dynamic_vamana.h @@ -355,33 +355,77 @@ class DynamicVamana : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + if (deserializer.is_native()) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + using GraphType = svs::GraphLoader<>::return_type; + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return make_dynamic_vamana>( + index::vamana::auto_dynamic_assemble( + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(stream); }, + // lazy data loader + [&]() -> Data { + return lib::load_from_stream( + stream, SVS_FWD(data_args)... + ); + }, + distance_function, + std::move(threadpool) + ) + ); + }); + } else { + return make_dynamic_vamana>( + index::vamana::auto_dynamic_assemble( + stream, + // lazy graph loader + [&]() -> GraphType { return GraphType::load(stream); }, + // lazy data loader + [&]() -> Data { + return lib::load_from_stream( + stream, SVS_FWD(data_args)... + ); + }, + distance, + std::move(threadpool) + ) + ); + } + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!"); - } + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!" + ); + } - const auto graph_path = tempdir.get() / "graph"; - if (!fs::is_directory(graph_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!"); - } + const auto graph_path = tempdir.get() / "graph"; + if (!fs::is_directory(graph_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!" + ); + } - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); - } + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); + } - return assemble( - config_path, - svs::GraphLoader{graph_path}, - lib::load_from_disk(data_path, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)), - false - ); + return assemble( + config_path, + svs::GraphLoader{graph_path}, + lib::load_from_disk(data_path, SVS_FWD(data_args)...), + distance, + threads::as_threadpool(std::move(threadpool_proto)), + false + ); + } } /// @copydoc svs::Vamana::batch_iterator diff --git a/include/svs/orchestrators/exhaustive.h b/include/svs/orchestrators/exhaustive.h index b33b6cc4a..e22dab7a4 100644 --- a/include/svs/orchestrators/exhaustive.h +++ b/include/svs/orchestrators/exhaustive.h @@ -24,6 +24,7 @@ #include "svs/core/distance.h" #include "svs/core/graph.h" #include "svs/lib/preprocessor.h" +#include "svs/lib/stream.h" #include "svs/lib/threads.h" #include "svs/orchestrators/manager.h" @@ -87,9 +88,7 @@ class FlatImpl : public manager::ManagerImpl { void save(std::ostream& stream) const override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_flat_save"}; - save(tempdir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current Vamana backend doesn't support saving!"); } @@ -196,14 +195,23 @@ class Flat : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_flat_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); - return assemble( - lib::load_from_disk(tempdir, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)) - ); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + if (deserializer.is_native()) { + return assemble( + lib::load_from_stream(stream, SVS_FWD(data_args)...), + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ); + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_flat_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); + return assemble( + lib::load_from_disk(tempdir, SVS_FWD(data_args)...), + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ); + } } ///// Distance diff --git a/include/svs/orchestrators/inverted.h b/include/svs/orchestrators/inverted.h index 6b6e50470..7fb7e60da 100644 --- a/include/svs/orchestrators/inverted.h +++ b/include/svs/orchestrators/inverted.h @@ -34,6 +34,9 @@ class InvertedInterface { const std::filesystem::path& primary_data, const std::filesystem::path& primary_graph ) = 0; + + ///// Saving + virtual void save_primary_index(std::ostream& os) = 0; }; template @@ -72,6 +75,8 @@ class InvertedImpl : public manager::ManagerImpl { ) override { impl().save_primary_index(primary_config, primary_data, primary_graph); } + + void save_primary_index(std::ostream& os) override { impl().save_primary_index(os); } }; ///// @@ -106,6 +111,8 @@ class Inverted : public manager::IndexManager { impl_->save_primary_index(primary_config, primary_data, primary_graph); } + void save_primary_index(std::ostream& os) { impl_->save_primary_index(os); } + ///// Building template < manager::QueryTypeDefinition QueryTypes, @@ -168,6 +175,30 @@ class Inverted : public manager::IndexManager { std::move(threadpool_proto) )}; } + template < + manager::QueryTypeDefinition QueryTypes, + typename DataProto, + typename Distance, + typename ThreadPoolProto, + typename StorageStrategy = index::inverted::SparseStrategy> + static Inverted assemble_from_clustering( + std::istream& is, + DataProto data_proto, + Distance distance, + ThreadPoolProto threadpool_proto, + StorageStrategy strategy = {} + ) { + return Inverted{ + std::in_place, + manager::as_typelist{}, + index::inverted::assemble_from_clustering( + is, + std::move(data_proto), + std::move(distance), + std::move(strategy), + std::move(threadpool_proto) + )}; + } }; } // namespace svs diff --git a/include/svs/orchestrators/ivf.h b/include/svs/orchestrators/ivf.h index 90c980780..87cb1b229 100644 --- a/include/svs/orchestrators/ivf.h +++ b/include/svs/orchestrators/ivf.h @@ -116,13 +116,7 @@ class IVFImpl : public manager::ManagerImpl { void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_ivf_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current IVF backend doesn't support saving!"); } @@ -380,27 +374,54 @@ class IVF : public manager::IndexManager { ThreadpoolProto threadpool_proto, size_t intra_query_threads = 1 ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_ivf_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + if (deserializer.is_native()) { + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return IVF( + std::in_place, + manager::as_typelist{}, + index::ivf::load_ivf_index( + stream, + std::move(distance_function), + std::move(threadpool_proto), + intra_query_threads + ) + ); + }); + } else { + return IVF( + std::in_place, + manager::as_typelist{}, + index::ivf::load_ivf_index( + stream, distance, std::move(threadpool_proto), intra_query_threads + ) + ); + } + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_ivf_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!"); - } + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!"); + } - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!"); - } + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!"); + } - return assemble( - config_path, - data_path, - distance, - std::move(threadpool_proto), - intra_query_threads - ); + return assemble( + config_path, + data_path, + distance, + std::move(threadpool_proto), + intra_query_threads + ); + } } ///// Building diff --git a/include/svs/orchestrators/vamana.h b/include/svs/orchestrators/vamana.h index 6b698c4f9..ad12fd8c3 100644 --- a/include/svs/orchestrators/vamana.h +++ b/include/svs/orchestrators/vamana.h @@ -177,15 +177,7 @@ class VamanaImpl : public manager::ManagerImpl { void save(std::ostream& stream) override { if constexpr (Impl::supports_saving) { - lib::UniqueTempDirectory tempdir{"svs_vamana_save"}; - const auto config_dir = tempdir.get() / "config"; - const auto graph_dir = tempdir.get() / "graph"; - const auto data_dir = tempdir.get() / "data"; - std::filesystem::create_directories(config_dir); - std::filesystem::create_directories(graph_dir); - std::filesystem::create_directories(data_dir); - save(config_dir, graph_dir, data_dir); - lib::DirectoryArchiver::pack(tempdir, stream); + impl().save(stream); } else { throw ANNEXCEPTION("The current Vamana backend doesn't support saving!"); } @@ -474,32 +466,72 @@ class Vamana : public manager::IndexManager { ThreadPoolProto threadpool_proto, DataLoaderArgs&&... data_args ) { - namespace fs = std::filesystem; - lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; - lib::DirectoryArchiver::unpack(stream, tempdir); + auto deserializer = svs::lib::detail::Deserializer::build(stream); + if (deserializer.is_native()) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + using GraphType = svs::GraphLoader<>::return_type; + if constexpr (std::is_same_v) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + return make_vamana>( + AssembleTag(), + stream, + // lazy-loader + [&]() -> GraphType { return GraphType::load(stream); }, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream( + stream, SVS_FWD(data_args)... + ); + }, + distance_function, + std::move(threadpool) + ); + }); + } else { + return make_vamana>( + AssembleTag(), + stream, + // lazy-loader + [&]() -> GraphType { return GraphType::load(stream); }, + // lazy-loader + [&]() -> Data { + return lib::load_from_stream(stream, SVS_FWD(data_args)...); + }, + distance, + std::move(threadpool) + ); + } + } else { + namespace fs = std::filesystem; + lib::UniqueTempDirectory tempdir{"svs_vamana_load"}; + lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); - const auto config_path = tempdir.get() / "config"; - if (!fs::is_directory(config_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!"); - } + const auto config_path = tempdir.get() / "config"; + if (!fs::is_directory(config_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing config directory!" + ); + } - const auto graph_path = tempdir.get() / "graph"; - if (!fs::is_directory(graph_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!"); - } + const auto graph_path = tempdir.get() / "graph"; + if (!fs::is_directory(graph_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing graph directory!" + ); + } - const auto data_path = tempdir.get() / "data"; - if (!fs::is_directory(data_path)) { - throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); - } + const auto data_path = tempdir.get() / "data"; + if (!fs::is_directory(data_path)) { + throw ANNEXCEPTION("Invalid Vamana index archive: missing data directory!"); + } - return assemble( - config_path, - svs::GraphLoader{graph_path}, - lib::load_from_disk(data_path, SVS_FWD(data_args)...), - distance, - threads::as_threadpool(std::move(threadpool_proto)) - ); + return assemble( + config_path, + svs::GraphLoader{graph_path}, + lib::load_from_disk(data_path, SVS_FWD(data_args)...), + distance, + threads::as_threadpool(std::move(threadpool_proto)) + ); + } } /// diff --git a/include/svs/quantization/scalar/scalar.h b/include/svs/quantization/scalar/scalar.h index a2244d1fd..5998e76a7 100644 --- a/include/svs/quantization/scalar/scalar.h +++ b/include/svs/quantization/scalar/scalar.h @@ -497,6 +497,18 @@ class SQDataset { ); } + lib::SaveTable metadata() const { + return lib::SaveTable( + serialization_schema, + save_version, + {{"data", lib::detail::exit_hook(data_.metadata())}, + {"scale", lib::save(scale_)}, + {"bias", lib::save(bias_)}} + ); + } + + void save(std::ostream& os) const { data_.save(os); } + /// @brief Load dataset from a file. static SQDataset load(const lib::LoadTable& table, const allocator_type& allocator = {}) { @@ -506,6 +518,18 @@ class SQDataset { lib::load_at(table, "bias")}; } + /// @brief Load dataset from a stream. + static SQDataset load( + const lib::ContextFreeLoadTable& table, + std::istream& is, + const allocator_type& allocator = {} + ) { + return SQDataset{ + SVS_LOAD_MEMBER_AT_(table, data, is, allocator), + lib::load_at(table, "scale"), + lib::load_at(table, "bias")}; + } + /// @brief Prefetch data in the dataset. void prefetch(size_t i) const { data_.prefetch(i); } }; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 95555eb15..265b37ef8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -148,7 +148,7 @@ set(TEST_SOURCES # Global scalar quantization ${TEST_DIR}/svs/quantization/scalar/scalar.cpp - # # ${TEST_DIR}/svs/index/vamana/dynamic_index.cpp + ${TEST_DIR}/svs/index/vamana/dynamic_index.cpp ) ##### diff --git a/tests/svs/index/flat/dynamic_flat.cpp b/tests/svs/index/flat/dynamic_flat.cpp index f9f99e702..900118e15 100644 --- a/tests/svs/index/flat/dynamic_flat.cpp +++ b/tests/svs/index/flat/dynamic_flat.cpp @@ -26,6 +26,7 @@ #include "svs/lib/threads.h" #include "svs/lib/timing.h" #include "svs/misc/dynamic_helper.h" +#include "svs/orchestrators/dynamic_flat.h" // tests #include "tests/utils/test_dataset.h" @@ -214,3 +215,125 @@ CATCH_TEST_CASE("Testing Flat Index", "[dynamic_flat]") { test_loop(index, reference, queries, div(reference.size(), modify_fraction), 2, 6); } + +CATCH_TEST_CASE("DynamicFlat Index Save and Load", "[dynamic_flat][index][saveload]") { +#if defined(NDEBUG) + const float initial_fraction = 0.25; + const float modify_fraction = 0.05; +#else + const float initial_fraction = 0.05; + const float modify_fraction = 0.005; +#endif + const size_t num_threads = 10; + + // Load the base dataset and queries. + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto num_points = data.size(); + auto queries = test_dataset::queries(); + + auto reference = svs::misc::ReferenceDataset( + std::move(data), + Distance(), + num_threads, + div(num_points, 0.5 * modify_fraction), + NUM_NEIGHBORS, + queries, + 0x12345678 + ); + + auto num_indices_to_add = div(reference.size(), initial_fraction); + + // Construct a blocked dataset consisting of initial fraction of the base dataset. + auto data_mutable = svs::data::BlockedData(num_indices_to_add, N); + std::vector initial_indices{}; + { + auto [vectors, indices] = reference.generate(num_indices_to_add); + // Copy assign ``initial_indices`` + auto num_points_added = indices.size(); + CATCH_REQUIRE(vectors.size() == num_points_added); + CATCH_REQUIRE(num_points_added <= num_indices_to_add); + CATCH_REQUIRE(num_points_added > num_indices_to_add - reference.bucket_size()); + + initial_indices = indices; + if (vectors.size() != num_indices_to_add || indices.size() != num_indices_to_add) { + throw ANNEXCEPTION("Something when horribly wrong!"); + } + + for (size_t i = 0; i < num_indices_to_add; ++i) { + data_mutable.set_datum(i, vectors.get_datum(i)); + } + } + + using Data_t = svs::data::BlockedData; + using Distance_t = svs::distance::DistanceL2; + using Index_t = svs::index::flat::DynamicFlatIndex; + + Distance_t dist; + auto index = Index_t(std::move(data_mutable), initial_indices, dist, num_threads); + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search(results.view(), queries.cview(), {}); + + reference.configure_extra_checks(true); + CATCH_REQUIRE(reference.extra_checks_enabled()); + + CATCH_SECTION("Load DynamicFlat being serialized natively to stream") { + std::stringstream stream; + index.save(stream); + { + auto loaded_index = + svs::DynamicFlat::assemble(stream, dist, num_threads); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load DynamicFlat being serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_dynflat_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + auto loaded_index = + svs::DynamicFlat::assemble(stream, dist, num_threads); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } +} diff --git a/tests/svs/index/flat/flat.cpp b/tests/svs/index/flat/flat.cpp index 57f821781..a5e1c801f 100644 --- a/tests/svs/index/flat/flat.cpp +++ b/tests/svs/index/flat/flat.cpp @@ -16,6 +16,12 @@ #include "svs/index/flat/flat.h" #include "svs/core/logging.h" +#include "svs/lib/file.h" +#include "svs/lib/saveload/load.h" +#include "svs/orchestrators/exhaustive.h" + +// tests +#include "tests/utils/test_dataset.h" // catch2 #include "catch2/catch_test_macros.hpp" @@ -66,3 +72,76 @@ CATCH_TEST_CASE("FlatIndex Logging Test", "[logging]") { CATCH_REQUIRE(captured_logs.size() == 1); CATCH_REQUIRE(captured_logs[0] == "Test FlatIndex Logging"); } + +CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { + using Data_t = svs::data::SimpleData; + using Distance_t = svs::distance::DistanceL2; + using Index_t = svs::index::flat::FlatIndex; + + // Load test data + auto data = Data_t::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + + // Build index + Distance_t dist; + Index_t index = Index_t(std::move(data), dist, svs::threads::DefaultThreadPool(1)); + + size_t num_neighbors = 10; + auto results = svs::QueryResult(queries.size(), num_neighbors); + index.search(results.view(), queries.cview(), {}); + + CATCH_SECTION("Load Flat being serialized natively to stream") { + std::stringstream ss; + index.save(ss); + + auto loaded_index = svs::Flat::assemble( + ss, dist, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + + CATCH_SECTION("Load Flat being serialized with intermediate files") { + std::stringstream ss; + + svs::lib::UniqueTempDirectory tempdir{"svs_flat_save"}; + index.save(tempdir); + svs::lib::DirectoryArchiver::pack(tempdir, ss); + + auto loaded_index = svs::Flat::assemble( + ss, dist, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } +} diff --git a/tests/svs/index/inverted/memory_based.cpp b/tests/svs/index/inverted/memory_based.cpp index d418dd83d..28ff6b269 100644 --- a/tests/svs/index/inverted/memory_based.cpp +++ b/tests/svs/index/inverted/memory_based.cpp @@ -19,9 +19,11 @@ #include "spdlog/sinks/callback_sink.h" #include "svs-benchmark/datasets.h" #include "svs/lib/timing.h" +#include "svs/orchestrators/inverted.h" #include "tests/utils/inverted_reference.h" #include "tests/utils/test_dataset.h" #include +#include CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") { // Vector to store captured log messages @@ -73,3 +75,76 @@ CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") { CATCH_REQUIRE(captured_logs[0].find("Vamana Build Parameters:") != std::string::npos); CATCH_REQUIRE(captured_logs[1].find("Number of syncs") != std::string::npos); } + +namespace { +constexpr size_t NUM_NEIGHBORS = 10; + +template void test_stream_save_load(Strategy strategy) { + auto distance = svs::DistanceL2(); + constexpr auto distance_type = svs::distance_type_v; + auto expected_results = test_dataset::inverted::expected_build_results( + distance_type, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_parameters = expected_results.build_parameters_.value(); + + // Capture the clustering during build. + svs::index::inverted::Clustering clustering; + auto clustering_op = [&](const auto& c) { clustering = c; }; + + svs::Inverted index = svs::Inverted::build( + build_parameters, + svs::data::SimpleData::load(test_dataset::data_svs_file()), + distance, + 2, + strategy, + svs::index::inverted::PickRandomly{}, + clustering_op + ); + + auto queries = svs::data::SimpleData::load(test_dataset::query_file()); + auto parameters = index.get_search_parameters(); + auto results = index.search(queries, NUM_NEIGHBORS); + + // Serialize to stream. + std::stringstream ss; + svs::lib::save_to_stream(clustering, ss); + index.save_primary_index(ss); + + // Load from stream. + svs::Inverted loaded = svs::Inverted::assemble_from_clustering>( + ss, + svs::data::SimpleData::load(test_dataset::data_svs_file()), + distance, + 2, + strategy + ); + loaded.set_search_parameters(parameters); + + // Compare basic properties. + CATCH_REQUIRE(loaded.size() == index.size()); + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); + + // Compare search results element-wise. + auto loaded_results = loaded.search(queries, NUM_NEIGHBORS); + CATCH_REQUIRE(loaded_results.n_queries() == results.n_queries()); + CATCH_REQUIRE(loaded_results.n_neighbors() == results.n_neighbors()); + for (size_t q = 0; q < results.n_queries(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } +} +} // namespace + +CATCH_TEST_CASE("InvertedIndex Save and Load", "[saveload][inverted][index]") { + CATCH_SECTION("SparseStrategy") { + test_stream_save_load(svs::index::inverted::SparseStrategy()); + } + CATCH_SECTION("DenseStrategy") { + test_stream_save_load(svs::index::inverted::DenseStrategy()); + } +} diff --git a/tests/svs/index/ivf/index.cpp b/tests/svs/index/ivf/index.cpp index 7c78a3add..b15f26040 100644 --- a/tests/svs/index/ivf/index.cpp +++ b/tests/svs/index/ivf/index.cpp @@ -16,6 +16,7 @@ // header under test #include "svs/index/ivf/index.h" +#include "svs/orchestrators/ivf.h" // tests #include "tests/utils/test_dataset.h" @@ -33,6 +34,7 @@ // stl #include +#include CATCH_TEST_CASE("IVF Index Single Search", "[ivf][index][single_search]") { namespace ivf = svs::index::ivf; @@ -276,6 +278,81 @@ CATCH_TEST_CASE("IVF Index Save and Load", "[ivf][index][saveload]") { svs_test::cleanup_temp_directory(); } + CATCH_SECTION("Load IVF Index serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_ivf_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + using DataType = svs::data::SimpleData; + + auto loaded_ivf = svs::IVF::assemble( + stream, + distance, + svs::threads::as_threadpool(num_threads), + num_inner_threads + ); + + CATCH_REQUIRE(loaded_ivf.size() == index.size()); + CATCH_REQUIRE(loaded_ivf.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_ivf.search(loaded_results.view(), batch_queries, search_params); + + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE( + loaded_results.index(q, i) == original_results.index(q, i) + ); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(original_results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load IVF Index serialized natively to stream") { + std::stringstream stream; + index.save(stream); + + { + using DataType = svs::data::SimpleData; + + auto loaded_ivf = svs::IVF::assemble( + stream, + distance, + svs::threads::as_threadpool(num_threads), + num_inner_threads + ); + + CATCH_REQUIRE(loaded_ivf.size() == index.size()); + CATCH_REQUIRE(loaded_ivf.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_ivf.search(loaded_results.view(), batch_queries, search_params); + + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE( + loaded_results.index(q, i) == original_results.index(q, i) + ); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(original_results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + CATCH_SECTION("Save and load DenseClusteredDataset") { // Prepare temp directory auto tempdir = svs_test::prepare_temp_directory_v2(); diff --git a/tests/svs/index/vamana/dynamic_index.cpp b/tests/svs/index/vamana/dynamic_index.cpp index 177258879..798a584c2 100644 --- a/tests/svs/index/vamana/dynamic_index.cpp +++ b/tests/svs/index/vamana/dynamic_index.cpp @@ -23,20 +23,26 @@ #include #include #include +#include +#include #include #include // svs #include "svs/core/recall.h" #include "svs/lib/timing.h" +#include "svs/orchestrators/dynamic_vamana.h" // catch2 #include "catch2/catch_test_macros.hpp" +#include // tests #include "tests/utils/test_dataset.h" #include "tests/utils/utils.h" +// The MutableVamanaIndex "Soft Deletion" test uses outdated API. +#if 0 namespace { template auto copy_dataset(const T& data) { auto copy = svs::data::SimplePolymorphicData{ @@ -246,3 +252,115 @@ CATCH_TEST_CASE("MutableVamanaIndex", "[graph_index]") { << post_add_time << " seconds." << std::endl; } } +#endif + +CATCH_TEST_CASE( + "MutableVamana Index Save and Load", "[graph_index][dynamic_index][saveload]" +) { + const size_t num_threads = 2; + using Distance = svs::distance::DistanceL2; + + auto data = test_dataset::data_blocked_f32(); + std::vector indices(data.size()); + std::iota(indices.begin(), indices.end(), 0); + + svs::index::vamana::VamanaBuildParameters parameters{1.2, 64, 10, 20, 10, true}; + auto index = svs::index::vamana::MutableVamanaIndex( + parameters, std::move(data), indices, Distance(), num_threads + ); + + const size_t num_neighbors = 10; + auto queries = test_dataset::queries(); + auto search_params = svs::index::vamana::VamanaSearchParameters{}; + search_params.buffer_config_ = svs::index::vamana::SearchBufferConfig{num_neighbors}; + auto results = svs::QueryResult(queries.size(), num_neighbors); + index.search(results.view(), queries.cview(), search_params); + + CATCH_SECTION("Load MutableVamana Index being serialized natively to stream") { + std::stringstream stream; + index.save(stream); + { + using Data_t = svs::data::BlockedData; + + auto loaded = svs::DynamicVamana::assemble( + stream, Distance(), num_threads + ); + + CATCH_REQUIRE(loaded.size() == index.size()); + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); + CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha()); + CATCH_REQUIRE(loaded.get_graph_max_degree() == index.get_graph_max_degree()); + CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates()); + CATCH_REQUIRE( + loaded.get_construction_window_size() == + index.get_construction_window_size() + ); + CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to()); + CATCH_REQUIRE( + loaded.get_full_search_history() == index.get_full_search_history() + ); + index.on_ids([&](size_t e) { CATCH_REQUIRE(loaded.has_id(e)); }); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load MutableVamana Index being serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_dynvamana_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto graph_dir = tempdir.get() / "graph"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(graph_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, graph_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + using Data_t = svs::data::BlockedData; + + auto loaded = svs::DynamicVamana::assemble( + stream, Distance(), num_threads + ); + + CATCH_REQUIRE(loaded.size() == index.size()); + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); + CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha()); + CATCH_REQUIRE(loaded.get_graph_max_degree() == index.get_graph_max_degree()); + CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates()); + CATCH_REQUIRE( + loaded.get_construction_window_size() == + index.get_construction_window_size() + ); + CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to()); + CATCH_REQUIRE( + loaded.get_full_search_history() == index.get_full_search_history() + ); + index.on_ids([&](size_t e) { CATCH_REQUIRE(loaded.has_id(e)); }); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } +} diff --git a/tests/svs/index/vamana/index.cpp b/tests/svs/index/vamana/index.cpp index b94b902b8..bbc535c62 100644 --- a/tests/svs/index/vamana/index.cpp +++ b/tests/svs/index/vamana/index.cpp @@ -24,6 +24,7 @@ // svs #include "svs/index/vamana/build_params.h" #include "svs/lib/preprocessor.h" +#include "svs/orchestrators/vamana.h" // catch2 #include "catch2/catch_test_macros.hpp" @@ -37,6 +38,7 @@ // svsbenchmark #include "svs-benchmark/benchmark.h" // stl +#include #include namespace { @@ -181,6 +183,102 @@ CATCH_TEST_CASE("Static VamanaIndex Per-Index Logging", "[logging]") { CATCH_REQUIRE(captured_logs[2].find("Batch Size:") != std::string::npos); } +CATCH_TEST_CASE("Vamana Index Save and Load", "[vamana][index][saveload]") { + const size_t N = 128; + using Eltype = float; + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto graph = svs::graphs::SimpleGraph(data.size(), 64); + svs::distance::DistanceL2 distance_function; + uint32_t entry_point = 0; + auto threadpool = svs::threads::DefaultThreadPool(1); + + // Build the VamanaIndex with the test logger + svs::index::vamana::VamanaBuildParameters buildParams(1.2, 64, 10, 20, 10, true); + svs::index::vamana::VamanaIndex index( + buildParams, + std::move(graph), + std::move(data), + entry_point, + distance_function, + std::move(threadpool) + ); + + const size_t NUM_NEIGHBORS = 10; + auto queries = test_dataset::queries(); + auto search_params = svs::index::vamana::VamanaSearchParameters{}; + search_params.buffer_config_ = svs::index::vamana::SearchBufferConfig{NUM_NEIGHBORS}; + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search(results.view(), queries.cview(), search_params); + + CATCH_SECTION("Load Vamana Index being serialized with intermediate files") { + std::stringstream stream; + { + svs::lib::UniqueTempDirectory tempdir{"svs_vamana_save"}; + const auto config_dir = tempdir.get() / "config"; + const auto graph_dir = tempdir.get() / "graph"; + const auto data_dir = tempdir.get() / "data"; + std::filesystem::create_directories(config_dir); + std::filesystem::create_directories(graph_dir); + std::filesystem::create_directories(data_dir); + index.save(config_dir, graph_dir, data_dir); + svs::lib::DirectoryArchiver::pack(tempdir, stream); + } + { + using Data_t = svs::data::SimpleData; + + auto loaded_index = svs::Vamana::assemble( + stream, distance_function, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } + + CATCH_SECTION("Load Vamana Index being serialized natively to stream") { + std::stringstream stream; + index.save(stream); + + { + using Data_t = svs::data::SimpleData; + + auto loaded_index = svs::Vamana::assemble( + stream, distance_function, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + } + } +} + CATCH_TEST_CASE("Vamana Index Default Parameters", "[long][parameter][vamana]") { using Catch::Approx; std::filesystem::path data_path = test_dataset::data_svs_file();