Skip to content

Commit 1dad258

Browse files
committed
bug fixes
1 parent 2041921 commit 1dad258

2 files changed

Lines changed: 16 additions & 10 deletions

File tree

include/sqlite-vec-cpp/index/hnsw.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,9 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
515515
/// Greedy search (read-only, called under shared lock)
516516
size_t greedy_search_layer_locked(std::span<const float> query, size_t entry_point,
517517
size_t layer, const FilterFn* filter) const {
518+
// Small negative threshold to handle floating-point error in distance calculation
519+
constexpr float kDistanceEpsilon = -1e-5f;
520+
518521
size_t current = entry_point;
519522
float current_dist = distance(query, current);
520523

@@ -533,7 +536,7 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
533536
break;
534537
for (size_t neighbor : current_node->neighbors(layer)) {
535538
float neighbor_dist = distance(query, neighbor);
536-
if (neighbor_dist < 0 || neighbor_dist >= current_dist)
539+
if (neighbor_dist < kDistanceEpsilon || neighbor_dist >= current_dist)
537540
continue;
538541
current = neighbor;
539542
current_dist = neighbor_dist;
@@ -584,8 +587,12 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
584587
return !is_deleted(id) && (!filter || (*filter)(id));
585588
};
586589

587-
if (passes_filter(entry_point) && entry_dist >= 0) {
588-
top_candidates.emplace(entry_dist, entry_point);
590+
// Note: entry_dist can be slightly negative due to floating-point error
591+
// (e.g., cosine distance of identical vectors = 1 - 1.0000001 = -1e-7)
592+
// We allow small negative values to avoid missing exact matches.
593+
constexpr float kDistanceEpsilon = -1e-5f;
594+
if (passes_filter(entry_point) && entry_dist >= kDistanceEpsilon) {
595+
top_candidates.emplace(std::max(0.0f, entry_dist), entry_point);
589596
}
590597

591598
while (!candidates.empty()) {
@@ -607,19 +614,19 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
607614
visited.insert(neighbor);
608615

609616
float neighbor_dist = distance(query, neighbor);
610-
if (neighbor_dist < 0)
617+
if (neighbor_dist < kDistanceEpsilon)
611618
continue;
612619

613620
bool should_explore = top_candidates.empty() || top_candidates.size() < ef ||
614621
neighbor_dist < top_candidates.top().first;
615622

616623
if (should_explore) {
617-
candidates.emplace(neighbor_dist, neighbor);
624+
candidates.emplace(std::max(0.0f, neighbor_dist), neighbor);
618625
}
619626

620627
if (passes_filter(neighbor)) {
621628
if (top_candidates.size() < ef || neighbor_dist < top_candidates.top().first) {
622-
top_candidates.emplace(neighbor_dist, neighbor);
629+
top_candidates.emplace(std::max(0.0f, neighbor_dist), neighbor);
623630
if (top_candidates.size() > ef) {
624631
top_candidates.pop();
625632
}
@@ -633,9 +640,8 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
633640
while (!top_candidates.empty()) {
634641
auto [dist, id] = top_candidates.top();
635642
top_candidates.pop();
636-
if (dist >= 0) {
637-
result.emplace_back(id, dist);
638-
}
643+
// Distances are already clamped to >= 0 when added
644+
result.emplace_back(id, dist);
639645
}
640646

641647
std::sort(result.begin(), result.end(),

include/sqlite-vec-cpp/index/hnsw_threading.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class ThreadLocalRNG {
4949
}
5050
};
5151

52-
thread_local std::mt19937 ThreadLocalRNG::rng_{std::random_device{}()};
52+
inline thread_local std::mt19937 ThreadLocalRNG::rng_{std::random_device{}()};
5353

5454
/// Per-node lock array using pointer-based storage to avoid move issues
5555
class NodeLocks {

0 commit comments

Comments
 (0)