Skip to content

Commit d1ed5bb

Browse files
committed
Improved error message for negative k
1 parent c6d38d1 commit d1ed5bb

4 files changed

Lines changed: 46 additions & 10 deletions

File tree

ext/faiss/index_binary.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,22 @@ void init_index_binary(Rice::Module& m) {
6262
})
6363
.define_method(
6464
"search",
65-
[](Rice::Object rb_self, numo::UInt8 objects, size_t k) {
65+
[](Rice::Object rb_self, numo::UInt8 objects, int64_t k) {
6666
auto& self = *Rice::Data_Object<faiss::IndexBinary>{rb_self};
6767
size_t n = check_shape(objects, self.d / 8);
68+
if (k < 0) {
69+
throw Rice::Exception(rb_eArgError, "expected k to be non-negative");
70+
}
6871

69-
numo::Int32 distances{{n, k}};
70-
numo::Int64 labels{{n, k}};
72+
numo::Int32 distances{{n, static_cast<size_t>(k)}};
73+
numo::Int64 labels{{n, static_cast<size_t>(k)}};
7174

7275
if (rb_self.is_frozen()) {
7376
// Don't mess with Ruby-owned memory while the GVL is released
7477
const auto* objects_ptr = objects.read_ptr();
7578
std::vector<uint8_t> objects_vec(objects_ptr, objects_ptr + n * (self.d / 8));
76-
std::vector<int32_t> distances_vec(n * k);
77-
std::vector<int64_t> labels_vec(n * k);
79+
std::vector<int32_t> distances_vec(n * static_cast<size_t>(k));
80+
std::vector<int64_t> labels_vec(n * static_cast<size_t>(k));
7881

7982
Rice::detail::no_gvl([&] {
8083
self.search(n, objects_vec.data(), k, distances_vec.data(), labels_vec.data());

ext/faiss/index_rb.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,19 +153,22 @@ void init_index(Rice::Module& m) {
153153
})
154154
.define_method(
155155
"search",
156-
[](Rice::Object rb_self, numo::SFloat objects, size_t k) {
156+
[](Rice::Object rb_self, numo::SFloat objects, int64_t k) {
157157
auto& self = *Rice::Data_Object<faiss::Index>{rb_self};
158158
size_t n = check_shape(objects, self.d);
159+
if (k < 0) {
160+
throw Rice::Exception(rb_eArgError, "expected k to be non-negative");
161+
}
159162

160-
numo::SFloat distances{{n, k}};
161-
numo::Int64 labels{{n, k}};
163+
numo::SFloat distances{{n, static_cast<size_t>(k)}};
164+
numo::Int64 labels{{n, static_cast<size_t>(k)}};
162165

163166
if (rb_self.is_frozen()) {
164167
// Don't mess with Ruby-owned memory while the GVL is released
165168
const auto* objects_ptr = objects.read_ptr();
166169
std::vector<float> objects_vec(objects_ptr, objects_ptr + n * self.d);
167-
std::vector<float> distances_vec(n * k);
168-
std::vector<int64_t> labels_vec(n * k);
170+
std::vector<float> distances_vec(n * static_cast<size_t>(k));
171+
std::vector<int64_t> labels_vec(n * static_cast<size_t>(k));
169172

170173
Rice::detail::no_gvl([&] {
171174
self.search(n, objects_vec.data(), k, distances_vec.data(), labels_vec.data());

test/index_binary_test.rb

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@ def test_remove_ids
5656
assert_equal [1, 0, -1], ids[0, true].to_a
5757
end
5858

59+
def test_search_negative_k
60+
objects = [
61+
[1, 1, 2, 1],
62+
[5, 4, 6, 5],
63+
[1, 2, 1, 2]
64+
]
65+
index = Faiss::IndexBinaryFlat.new(32)
66+
index.add(objects)
67+
68+
error = assert_raises(ArgumentError) do
69+
index.search(objects, -1)
70+
end
71+
assert_equal "expected k to be non-negative", error.message
72+
end
73+
5974
def test_add_frozen
6075
index = Faiss::IndexBinaryFlat.new(32)
6176
index.freeze

test/index_test.rb

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,21 @@ def test_remove_ids_empty
336336
assert_equal 0, index.remove_ids([])
337337
end
338338

339+
def test_search_negative_k
340+
objects = [
341+
[1, 1, 2, 1],
342+
[5, 4, 6, 5],
343+
[1, 2, 1, 2]
344+
]
345+
index = Faiss::IndexFlatL2.new(4)
346+
index.add(objects)
347+
348+
error = assert_raises(ArgumentError) do
349+
index.search(objects, -1)
350+
end
351+
assert_equal "expected k to be non-negative", error.message
352+
end
353+
339354
def test_add_frozen
340355
index = Faiss::IndexFlatL2.new(4)
341356
index.freeze

0 commit comments

Comments
 (0)