Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions bindings/c/include/svs/c_api/svs_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum svs_error_code {
SVS_ERROR_NOT_IMPLEMENTED = 5,
SVS_ERROR_UNSUPPORTED_HW = 6,
SVS_ERROR_RUNTIME = 7,
SVS_ERROR_INVALID_OPERATION = 8,
SVS_ERROR_UNKNOWN = 1000
};

Expand Down Expand Up @@ -501,6 +502,32 @@ SVS_API bool svs_index_dynamic_compact(
svs_index_h index, size_t batchsize /*=0*/, svs_error_h out_err /*=NULL*/
);

/// @brief Get number of threads used for search in the index's thread pool
/// @param index The index handle
/// @param out_num_threads Pointer to store the retrieved number of threads
/// @param out_err An optional error handle to capture errors
/// @return true on success, false on failure
SVS_API bool svs_index_get_num_threads(
svs_index_h index, size_t* out_num_threads, svs_error_h out_err /*=NULL*/
);

/// @brief Set number of threads for search in the index's thread pool
/// @param index The index handle
/// @param num_threads The number of threads to set
/// @param out_err An optional error handle to capture errors
/// @return true on success, false on failure
/// @remarks This function is only supported for indices built with threadpool kinds
/// SVS_THREADPOOL_KIND_NATIVE or SVS_THREADPOOL_KIND_OMP. Attempting to call this
/// function on indices built with SVS_THREADPOOL_KIND_CUSTOM or
/// SVS_THREADPOOL_KIND_SINGLE_THREAD will fail and return false.
/// @error On failure, if out_err is provided, it will contain:
/// - SVS_ERROR_INVALID_OPERATION if the index was built with an unsupported threadpool kind
/// - SVS_ERROR_INVALID_ARGUMENT if num_threads is invalid or zero
/// - SVS_ERROR_RUNTIME for other runtime failures
SVS_API bool svs_index_set_num_threads(
svs_index_h index, size_t num_threads, svs_error_h out_err /*=NULL*/
);

#ifdef __cplusplus
}
#endif
8 changes: 8 additions & 0 deletions bindings/c/src/error.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ class not_implemented : public std::logic_error {
using std::logic_error::logic_error;
};

class invalid_operation : public std::logic_error {
public:
using std::logic_error::logic_error;
};

class unsupported_hw : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
Expand All @@ -104,6 +109,9 @@ Result wrap_exceptions(Callable&& func, svs_error_h err, Result err_res = {}) no
} catch (const svs::c_runtime::not_implemented& ex) {
SET_ERROR(err, SVS_ERROR_NOT_IMPLEMENTED, ex.what());
return err_res;
} catch (const svs::c_runtime::invalid_operation& ex) {
SET_ERROR(err, SVS_ERROR_INVALID_OPERATION, ex.what());
return err_res;
} catch (const svs::c_runtime::unsupported_hw& ex) {
SET_ERROR(err, SVS_ERROR_UNSUPPORTED_HW, ex.what());
return err_res;
Expand Down
31 changes: 23 additions & 8 deletions bindings/c/src/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "svs/c_api/svs_c.h"

#include "algorithm.hpp"
#include "threadpool.hpp"

#include <svs/concepts/data.h>
#include <svs/core/distance.h>
Expand All @@ -32,8 +33,10 @@
namespace svs::c_runtime {
struct Index {
svs_algorithm_type algorithm;
Index(svs_algorithm_type algorithm)
: algorithm(algorithm) {}
ThreadPoolBuilder pool_builder;
Index(svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder)
: algorithm(algorithm)
, pool_builder(pool_builder) {}
virtual ~Index() = default;
virtual svs::QueryResult<size_t> search(
svs::data::ConstSimpleDataView<float> queries,
Expand All @@ -45,11 +48,13 @@ struct Index {
virtual float get_distance(size_t id, std::span<const float> query) const = 0;
virtual void
reconstruct_at(svs::data::SimpleDataView<float> dst, std::span<const size_t> ids) = 0;
virtual size_t get_num_threads() { return pool_builder.get_threads_num(); };
virtual void set_num_threads(size_t num_threads) = 0;
Comment on lines +51 to +52
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Index::get_num_threads() returns pool_builder.get_threads_num(), which can diverge from the index’s actual threadpool size (e.g., if the underlying svs::Vamana/svs::DynamicVamana threadpool was changed or if OpenMP thread count changes). To make the C API reflect the real runtime state, make get_num_threads() a const pure-virtual in Index and implement it in each derived index by delegating to the wrapped index.get_num_threads().

Copilot uses AI. Check for mistakes.
};

struct DynamicIndex : public Index {
DynamicIndex(svs_algorithm_type algorithm)
: Index(algorithm) {}
DynamicIndex(svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder)
: Index(algorithm, pool_builder) {}
~DynamicIndex() = default;

virtual size_t add_points(
Expand All @@ -63,8 +68,8 @@ struct DynamicIndex : public Index {

struct IndexVamana : public Index {
svs::Vamana index;
IndexVamana(svs::Vamana&& index)
: Index{SVS_ALGORITHM_TYPE_VAMANA}
IndexVamana(svs::Vamana&& index, ThreadPoolBuilder pool_builder)
: Index{SVS_ALGORITHM_TYPE_VAMANA, pool_builder}
, index(std::move(index)) {}
~IndexVamana() = default;
svs::QueryResult<size_t> search(
Expand Down Expand Up @@ -99,12 +104,17 @@ struct IndexVamana : public Index {
override {
index.reconstruct_at(dst, ids);
}

void set_num_threads(size_t num_threads) override {
pool_builder.resize(num_threads);
index.set_threadpool(pool_builder.build());
}
};

struct DynamicIndexVamana : public DynamicIndex {
svs::DynamicVamana index;
DynamicIndexVamana(svs::DynamicVamana&& index)
: DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA)
DynamicIndexVamana(svs::DynamicVamana&& index, ThreadPoolBuilder pool_builder)
: DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA, pool_builder)
, index(std::move(index)) {}
~DynamicIndexVamana() = default;

Expand Down Expand Up @@ -170,5 +180,10 @@ struct DynamicIndexVamana : public DynamicIndex {
index.compact(batchsize);
}
}

void set_num_threads(size_t num_threads) override {
pool_builder.resize(num_threads);
index.set_threadpool(pool_builder.build());
}
};
} // namespace svs::c_runtime
50 changes: 30 additions & 20 deletions bindings/c/src/index_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,16 @@ struct IndexBuilder {
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);

auto index = std::make_shared<IndexVamana>(dispatch_vamana_index_build(
vamana_algorithm->build_parameters(),
data,
storage.get(),
to_distance_type(distance_metric),
pool_builder.build()
));
auto index = std::make_shared<IndexVamana>(
dispatch_vamana_index_build(
vamana_algorithm->build_parameters(),
data,
storage.get(),
to_distance_type(distance_metric),
pool_builder.build()
),
pool_builder
);

return index;
}
Expand All @@ -86,13 +89,16 @@ struct IndexBuilder {
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);

auto index = std::make_shared<IndexVamana>(dispatch_vamana_index_load(
vamana_algorithm->build_parameters(),
directory,
storage.get(),
to_distance_type(distance_metric),
pool_builder.build()
));
auto index = std::make_shared<IndexVamana>(
dispatch_vamana_index_load(
vamana_algorithm->build_parameters(),
directory,
storage.get(),
to_distance_type(distance_metric),
pool_builder.build()
),
pool_builder
);

return index;
}
Expand All @@ -107,16 +113,18 @@ struct IndexBuilder {
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);

auto index =
std::make_shared<DynamicIndexVamana>(dispatch_dynamic_vamana_index_build(
auto index = std::make_shared<DynamicIndexVamana>(
dispatch_dynamic_vamana_index_build(
vamana_algorithm->build_parameters(),
data,
ids,
storage.get(),
to_distance_type(distance_metric),
pool_builder.build(),
blocksize_bytes
));
),
pool_builder
);

return index;
}
Expand All @@ -128,15 +136,17 @@ struct IndexBuilder {
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);

auto index =
std::make_shared<DynamicIndexVamana>(dispatch_dynamic_vamana_index_load(
auto index = std::make_shared<DynamicIndexVamana>(
dispatch_dynamic_vamana_index_load(
vamana_algorithm->build_parameters(),
directory,
storage.get(),
to_distance_type(distance_metric),
pool_builder.build(),
blocksize_bytes
));
),
pool_builder
);

return index;
}
Expand Down
34 changes: 34 additions & 0 deletions bindings/c/src/svs_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,3 +787,37 @@ svs_index_dynamic_compact(svs_index_h index, size_t batchsize, svs_error_h out_e
false
);
}

extern "C" bool
svs_index_get_num_threads(svs_index_h index, size_t* out_num_threads, svs_error_h out_err) {
using namespace svs::c_runtime;
return wrap_exceptions(
[&]() {
EXPECT_ARG_NOT_NULL(index);
EXPECT_ARG_NOT_NULL(out_num_threads);
auto& index_ptr = index->impl;
INVALID_ARGUMENT_IF(index_ptr == nullptr, "Invalid index handle");
*out_num_threads = index_ptr->get_num_threads();
return true;
},
out_err,
false
);
}

extern "C" bool
svs_index_set_num_threads(svs_index_h index, size_t num_threads, svs_error_h out_err) {
using namespace svs::c_runtime;
return wrap_exceptions(
[&]() {
EXPECT_ARG_NOT_NULL(index);
EXPECT_ARG_GT_THAN(num_threads, 0);
auto& index_ptr = index->impl;
INVALID_ARGUMENT_IF(index_ptr == nullptr, "Invalid index handle");
index_ptr->set_num_threads(num_threads);
return true;
},
out_err,
false
);
}
29 changes: 28 additions & 1 deletion bindings/c/src/threadpool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "svs/c_api/svs_c.h"

#include "error.hpp"
#include "types_support.hpp"

#include <svs/lib/threads.h>
Expand Down Expand Up @@ -74,7 +75,8 @@ class ThreadPoolBuilder {

ThreadPoolBuilder(svs_threadpool_kind kind, size_t num_threads)
: kind(kind)
, num_threads(num_threads) {
, num_threads(kind == SVS_THREADPOOL_KIND_SINGLE_THREAD ? 1 : num_threads)
, user_threadpool(nullptr) {
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
throw std::invalid_argument(
"SVS_THREADPOOL_KIND_CUSTOM cannot be built automatically."
Expand All @@ -91,6 +93,31 @@ class ThreadPoolBuilder {
return std::max(size_t{1}, size_t{std::thread::hardware_concurrency()});
}

svs_threadpool_kind get_kind() const { return kind; }
svs_threadpool_i get_user_threadpool() const { return user_threadpool; }

size_t get_threads_num() const {
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
return user_threadpool->ops.size(user_threadpool->self);
}
return num_threads;
}

void resize(size_t new_num_threads) {
if (new_num_threads == 0) {
throw std::invalid_argument("Number of threads must be greater than zero.");
}
if (kind == SVS_THREADPOOL_KIND_SINGLE_THREAD) {
throw svs::c_runtime::invalid_operation(
"Cannot resize a single-threaded threadpool."
);
}
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
throw svs::c_runtime::invalid_operation("Cannot resize a custom threadpool.");
}
num_threads = new_num_threads;
}

svs::threads::ThreadPoolHandle build() const {
using namespace svs::threads;
switch (kind) {
Expand Down
Loading