Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 93 additions & 22 deletions include/svs/index/ivf/clustering.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,66 @@ class DenseClusteredDataset {
static constexpr lib::Version save_version{0, 0, 0};
static constexpr std::string_view serialization_schema = "ivf_dense_clustered_dataset";

lib::SaveTable metadata() const {
auto num_clusters = size();
auto dims = dimensions();

return lib::SaveTable(
serialization_schema,
save_version,
{{"num_clusters", lib::save(num_clusters)},
{"dimensions", lib::save(dims)},
{"prefetch_offset", lib::save(prefetch_offset_)},
{"index_type", lib::save(datatype_v<I>)}}
);
}

void save(std::ostream& os) const {
auto num_clusters = size();

// Compute cluster sizes and ID offsets
std::vector<size_t> cluster_sizes, ids_offsets;
calculate_cluster_sizes_and_offsets(&cluster_sizes, &ids_offsets);

lib::write_binary(os, cluster_sizes);
for (size_t i = 0; i < num_clusters; ++i) {
lib::save_to_stream(clusters_[i].data_, os);
if (!clusters_[i].ids_.empty()) {
lib::write_binary(os, clusters_[i].ids_);
}
}
}

template <typename Allocator = typename Data::allocator_type>
static DenseClusteredDataset load(
const lib::ContextFreeLoadTable& table,
std::istream& is,
const Allocator& allocator = Allocator{}
) {
auto num_clusters = lib::load_at<size_t>(table, "num_clusters");
[[maybe_unused]] auto dims = lib::load_at<size_t>(table, "dimensions");
auto prefetch_offset = lib::load_at<size_t>(table, "prefetch_offset");
check_saved_index_type(table);

DenseClusteredDataset result;
result.prefetch_offset_ = prefetch_offset;
result.clusters_.reserve(num_clusters);

std::vector<size_t> cluster_sizes(num_clusters);
lib::read_binary(is, cluster_sizes);

for (size_t i = 0; i < num_clusters; ++i) {
auto cluster_data = lib::load_from_stream<Data>(is, allocator);
size_t cluster_size = cluster_sizes[i];
std::vector<I> cluster_ids(cluster_size);
if (cluster_size > 0) {
lib::read_binary(is, cluster_ids);
}
result.clusters_.emplace_back(std::move(cluster_data), std::move(cluster_ids));
}
return result;
}

/// @brief Save the DenseClusteredDataset to disk.
///
/// Saves all cluster data using the existing save mechanisms for each data type
Expand All @@ -427,16 +487,8 @@ class DenseClusteredDataset {
auto dims = dimensions();

// Compute cluster sizes and ID offsets
std::vector<size_t> cluster_sizes(num_clusters);
std::vector<size_t> ids_offsets(num_clusters + 1);

size_t ids_offset = 0;
for (size_t i = 0; i < num_clusters; ++i) {
cluster_sizes[i] = clusters_[i].size();
ids_offsets[i] = ids_offset;
ids_offset += cluster_sizes[i] * sizeof(I);
}
ids_offsets[num_clusters] = ids_offset;
std::vector<size_t> cluster_sizes, ids_offsets;
calculate_cluster_sizes_and_offsets(&cluster_sizes, &ids_offsets);

// Create a temporary directory for cluster data
lib::UniqueTempDirectory tempdir{"svs_ivf_clusters_save"};
Expand Down Expand Up @@ -495,7 +547,7 @@ class DenseClusteredDataset {
{"ids_file", lib::save(std::string("ids.bin"))},
{"cluster_sizes_file", lib::save(std::string("cluster_sizes.bin"))},
{"ids_offsets_file", lib::save(std::string("ids_offsets.bin"))},
{"total_ids_bytes", lib::save(ids_offset)}}
{"total_ids_bytes", lib::save(ids_offsets[num_clusters])}}
);
}

Expand Down Expand Up @@ -528,17 +580,7 @@ class DenseClusteredDataset {
// since each cluster's data type determines its own dimensions
[[maybe_unused]] auto dims = lib::load_at<size_t>(table, "dimensions");
auto prefetch_offset = lib::load_at<size_t>(table, "prefetch_offset");

// Verify index type matches
auto saved_index_type = lib::load_at<DataType>(table, "index_type");
if (saved_index_type != datatype_v<I>) {
throw ANNEXCEPTION(
"DenseClusteredDataset was saved using index type {} but we're trying to "
"reload it using {}!",
saved_index_type,
datatype_v<I>
);
}
check_saved_index_type(table);

auto base_dir = table.context().get_directory();

Expand Down Expand Up @@ -603,6 +645,35 @@ class DenseClusteredDataset {
}

private:
void calculate_cluster_sizes_and_offsets(
std::vector<size_t>* cluster_sizes, std::vector<size_t>* ids_offsets
) const {
auto num_clusters = size();
cluster_sizes->resize(num_clusters);
ids_offsets->resize(num_clusters + 1);

size_t ids_offset = 0;
for (size_t i = 0; i < num_clusters; ++i) {
(*cluster_sizes)[i] = clusters_[i].size();
(*ids_offsets)[i] = ids_offset;
ids_offset += clusters_[i].size() * sizeof(I);
}
(*ids_offsets)[num_clusters] = ids_offset;
}

static void check_saved_index_type(const lib::ContextFreeLoadTable& table) {
// Verify index type matches
auto saved_index_type = lib::load_at<DataType>(table, "index_type");
if (saved_index_type != datatype_v<I>) {
throw ANNEXCEPTION(
"DenseClusteredDataset was saved using index type {} but we're trying to "
"reload it using {}!",
saved_index_type,
datatype_v<I>
);
}
}

std::vector<DenseCluster<Data, I>> clusters_;
size_t prefetch_offset_ = 8;
};
Expand Down
4 changes: 4 additions & 0 deletions include/svs/index/ivf/dynamic_ivf.h
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,10 @@ class DynamicIVFIndex {
lib::save_to_disk(clusters_, clusters_dir);
}

void save(std::ostream& SVS_UNUSED(os)) const {
throw ANNEXCEPTION("Placeholder; not implemented!");
}

private:
///// Helper Methods /////

Expand Down
54 changes: 54 additions & 0 deletions include/svs/index/ivf/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ class IVFIndex {
/// Each directory may be created as a side-effect of this method call provided that
/// the parent directory exists.
///

void save(
const std::filesystem::path& config_directory,
const std::filesystem::path& data_directory
Expand Down Expand Up @@ -504,6 +505,27 @@ class IVFIndex {
lib::save_to_disk(cluster_, clusters_dir);
}

void save(std::ostream& os) const {
lib::begin_serialization(os);

// Save config
auto data_type_config = DataTypeTraits<Data>::get_config();
data_type_config.centroid_type = datatype_v<typename Centroids::element_type>;
auto save_table = lib::SaveTable(
serialization_schema,
save_version,
{{"name", lib::save(name())}, {"num_clusters", lib::save(num_clusters())}}
);
save_table.insert("data_type_config", lib::save(data_type_config));
lib::save_to_stream(save_table, os);

// Save centroids
lib::save_to_stream(centroids_, os);

// Save clusters
lib::save_to_stream(cluster_, os);
}

private:
///// Core Components /////
Centroids centroids_;
Expand Down Expand Up @@ -964,4 +986,36 @@ auto load_ivf_index(
return index;
}

template <
typename CentroidType,
typename DataType,
typename Distance,
typename ThreadpoolProto>
auto load_ivf_index(
std::istream& is,
Distance distance,
ThreadpoolProto threadpool_proto,
const size_t intra_query_thread_count = 1,
svs::logging::logger_ptr logger = svs::logging::get()
) {
using centroids_type = data::SimpleData<CentroidType>;
using data_type = typename DataType::lib_alloc_data_type;
using cluster_type = DenseClusteredDataset<centroids_type, uint32_t, data_type>;

lib::detail::read_metadata(is);
auto centroids = lib::load_from_stream<centroids_type>(is);
auto clusters = lib::load_from_stream<cluster_type>(is);

auto threadpool = threads::as_threadpool(std::move(threadpool_proto));

return IVFIndex(
std::move(centroids),
std::move(clusters),
std::move(distance),
std::move(threadpool),
intra_query_thread_count,
logger
);
}

} // namespace svs::index::ivf
71 changes: 46 additions & 25 deletions include/svs/orchestrators/ivf.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,7 @@ class IVFImpl : public manager::ManagerImpl<QueryTypes, Impl, IFace> {

void save(std::ostream& stream) override {
if constexpr (Impl::supports_saving) {
lib::UniqueTempDirectory tempdir{"svs_ivf_save"};
const auto config_dir = tempdir.get() / "config";
const auto data_dir = tempdir.get() / "data";
std::filesystem::create_directories(config_dir);
std::filesystem::create_directories(data_dir);
save(config_dir, data_dir);
lib::DirectoryArchiver::pack(tempdir, stream);
impl().save(stream);
} else {
throw ANNEXCEPTION("The current IVF backend doesn't support saving!");
}
Expand Down Expand Up @@ -380,27 +374,54 @@ class IVF : public manager::IndexManager<IVFInterface> {
ThreadpoolProto threadpool_proto,
size_t intra_query_threads = 1
) {
namespace fs = std::filesystem;
lib::UniqueTempDirectory tempdir{"svs_ivf_load"};
lib::DirectoryArchiver::unpack(stream, tempdir);
auto deserializer = svs::lib::detail::Deserializer::build(stream);
if (deserializer.is_native()) {
if constexpr (std::is_same_v<std::decay_t<Distance>, DistanceType>) {
auto dispatcher = DistanceDispatcher(distance);
return dispatcher([&](auto distance_function) {
return IVF(
std::in_place,
manager::as_typelist<QueryTypes>{},
index::ivf::load_ivf_index<CentroidType, DataType>(
stream,
std::move(distance_function),
std::move(threadpool_proto),
intra_query_threads
)
);
});
} else {
return IVF(
std::in_place,
manager::as_typelist<QueryTypes>{},
index::ivf::load_ivf_index<CentroidType, DataType>(
stream, distance, std::move(threadpool_proto), intra_query_threads
)
);
}
} else {
namespace fs = std::filesystem;
lib::UniqueTempDirectory tempdir{"svs_ivf_load"};
lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic());

const auto config_path = tempdir.get() / "config";
if (!fs::is_directory(config_path)) {
throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!");
}
const auto config_path = tempdir.get() / "config";
if (!fs::is_directory(config_path)) {
throw ANNEXCEPTION("Invalid IVF index archive: missing config directory!");
}

const auto data_path = tempdir.get() / "data";
if (!fs::is_directory(data_path)) {
throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!");
}
const auto data_path = tempdir.get() / "data";
if (!fs::is_directory(data_path)) {
throw ANNEXCEPTION("Invalid IVF index archive: missing data directory!");
}

return assemble<QueryTypes, CentroidType, DataType>(
config_path,
data_path,
distance,
std::move(threadpool_proto),
intra_query_threads
);
return assemble<QueryTypes, CentroidType, DataType>(
config_path,
data_path,
distance,
std::move(threadpool_proto),
intra_query_threads
);
}
}

///// Building
Expand Down
Loading
Loading