Skip to content

Commit 6b4e7a0

Browse files
committed
adding clamp logic
1 parent 7d784b4 commit 6b4e7a0

2 files changed

Lines changed: 30 additions & 9 deletions

File tree

include/sqlite-vec-cpp/index/hnsw.hpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
5656
size_t ef_construction = 200; ///< Exploration factor during construction (100-500)
5757
float ml_factor = 1.0f / std::log(2.0f); ///< Layer selection multiplier (1/ln(2))
5858
MetricT metric{}; ///< Distance metric (operates on float spans)
59+
bool clamp_negative_distances = true; ///< Clamp negative distances to 0 (safe for L2/cosine)
5960

6061
/// Create config optimized for high recall on large corpora
6162
/// @param corpus_size Expected number of vectors
@@ -812,7 +813,8 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
812813
size_t greedy_search_layer_locked(std::span<const float> query, size_t entry_point,
813814
size_t layer, const FilterFn* filter) const {
814815
// Small negative threshold to handle floating-point error in distance calculation
815-
constexpr float kDistanceEpsilon = -1e-5f;
816+
const float kDistanceEpsilon =
817+
config_.clamp_negative_distances ? -1e-5f : std::numeric_limits<float>::lowest();
816818

817819
size_t current = entry_point;
818820
float current_dist = distance(query, current);
@@ -903,9 +905,12 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
903905
// Note: entry_dist can be slightly negative due to floating-point error
904906
// (e.g., cosine distance of identical vectors = 1 - 1.0000001 = -1e-7)
905907
// We allow small negative values to avoid missing exact matches.
906-
constexpr float kDistanceEpsilon = -1e-5f;
908+
const float kDistanceEpsilon =
909+
config_.clamp_negative_distances ? -1e-5f : std::numeric_limits<float>::lowest();
907910
if (passes_filter(entry_point) && entry_dist >= kDistanceEpsilon) {
908-
top_candidates.emplace(std::max(0.0f, entry_dist), entry_point);
911+
float entry_score =
912+
config_.clamp_negative_distances ? std::max(0.0f, entry_dist) : entry_dist;
913+
top_candidates.emplace(entry_score, entry_point);
909914
}
910915

911916
while (!candidates.empty()) {
@@ -958,12 +963,18 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
958963
neighbor_dist < top_candidates.top().first;
959964

960965
if (should_explore) {
961-
candidates.emplace(std::max(0.0f, neighbor_dist), neighbor);
966+
float candidate_score = config_.clamp_negative_distances
967+
? std::max(0.0f, neighbor_dist)
968+
: neighbor_dist;
969+
candidates.emplace(candidate_score, neighbor);
962970
}
963971

964972
if (passes_filter(neighbor)) {
965973
if (top_candidates.size() < ef || neighbor_dist < top_candidates.top().first) {
966-
top_candidates.emplace(std::max(0.0f, neighbor_dist), neighbor);
974+
float top_score = config_.clamp_negative_distances
975+
? std::max(0.0f, neighbor_dist)
976+
: neighbor_dist;
977+
top_candidates.emplace(top_score, neighbor);
967978
if (top_candidates.size() > ef) {
968979
top_candidates.pop();
969980
}

include/sqlite-vec-cpp/index/hnsw_persistence.hpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ std::vector<uint8_t> serialize_hnsw_config(const typename HNSWIndex<T, Metric>::
2424
blob.reserve(64);
2525

2626
// Version marker (for future compatibility)
27-
constexpr uint32_t version = 1;
27+
constexpr uint32_t version = 2;
2828
auto write_u32 = [&](uint32_t val) {
2929
for (int i = 0; i < 4; ++i) {
3030
blob.push_back((val >> (i * 8)) & 0xFF);
@@ -47,15 +47,17 @@ std::vector<uint8_t> serialize_hnsw_config(const typename HNSWIndex<T, Metric>::
4747
write_u64(config.M_max_0);
4848
write_u64(config.ef_construction);
4949
write_f32(config.ml_factor);
50+
write_u32(config.clamp_negative_distances ? 1U : 0U);
5051

5152
return blob;
5253
}
5354

5455
/// Deserialize HNSW index configuration from blob
5556
template <typename T, typename Metric>
5657
typename HNSWIndex<T, Metric>::Config deserialize_hnsw_config(const void* blob, size_t size) {
57-
// Config blob: version(4) + M(8) + M_max(8) + M_max_0(8) + ef_construction(8) + ml_factor(4) =
58-
// 40 bytes
58+
// Config blob v1: version(4) + M(8) + M_max(8) + M_max_0(8) + ef_construction(8) + ml_factor(4)
59+
// = 40 bytes
60+
// Config blob v2: v1 + clamp_negative_distances(4) = 44 bytes
5961
if (size < 40) {
6062
throw std::runtime_error("Invalid HNSW config blob: too small");
6163
}
@@ -85,7 +87,7 @@ typename HNSWIndex<T, Metric>::Config deserialize_hnsw_config(const void* blob,
8587
};
8688

8789
uint32_t version = read_u32();
88-
if (version != 1) {
90+
if (version != 1 && version != 2) {
8991
throw std::runtime_error("Unsupported HNSW config version");
9092
}
9193

@@ -95,6 +97,14 @@ typename HNSWIndex<T, Metric>::Config deserialize_hnsw_config(const void* blob,
9597
config.M_max_0 = read_u64();
9698
config.ef_construction = read_u64();
9799
config.ml_factor = read_f32();
100+
if (version >= 2) {
101+
if (size < 44) {
102+
throw std::runtime_error("Invalid HNSW config blob: missing clamp flag");
103+
}
104+
config.clamp_negative_distances = (read_u32() != 0);
105+
} else {
106+
config.clamp_negative_distances = true;
107+
}
98108

99109
return config;
100110
}

0 commit comments

Comments
 (0)