Skip to content

Commit a82422c

Browse files
committed
Improved error message for zero k
1 parent d1ed5bb commit a82422c

4 files changed

Lines changed: 36 additions & 6 deletions

File tree

ext/faiss/index_binary.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ void init_index_binary(Rice::Module& m) {
6565
[](Rice::Object rb_self, numo::UInt8 objects, int64_t k) {
6666
auto& self = *Rice::Data_Object<faiss::IndexBinary>{rb_self};
6767
size_t n = check_shape(objects, self.d / 8);
68-
if (k < 0) {
69-
throw Rice::Exception(rb_eArgError, "expected k to be non-negative");
68+
if (k <= 0) {
69+
throw Rice::Exception(rb_eArgError, "expected k to be positive");
7070
}
7171

7272
numo::Int32 distances{{n, static_cast<size_t>(k)}};

ext/faiss/index_rb.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ void init_index(Rice::Module& m) {
156156
[](Rice::Object rb_self, numo::SFloat objects, int64_t k) {
157157
auto& self = *Rice::Data_Object<faiss::Index>{rb_self};
158158
size_t n = check_shape(objects, self.d);
159-
if (k < 0) {
160-
throw Rice::Exception(rb_eArgError, "expected k to be non-negative");
159+
if (k <= 0) {
160+
throw Rice::Exception(rb_eArgError, "expected k to be positive");
161161
}
162162

163163
numo::SFloat distances{{n, static_cast<size_t>(k)}};

test/index_binary_test.rb

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@ def test_remove_ids
5656
assert_equal [1, 0, -1], ids[0, true].to_a
5757
end
5858

59+
def test_search_zero_k
60+
objects = [
61+
[1, 1, 2, 1],
62+
[5, 4, 6, 5],
63+
[1, 2, 1, 2]
64+
]
65+
index = Faiss::IndexBinaryFlat.new(32)
66+
index.add(objects)
67+
68+
error = assert_raises(ArgumentError) do
69+
index.search(objects, 0)
70+
end
71+
assert_equal "expected k to be positive", error.message
72+
end
73+
5974
def test_search_negative_k
6075
objects = [
6176
[1, 1, 2, 1],
@@ -68,7 +83,7 @@ def test_search_negative_k
6883
error = assert_raises(ArgumentError) do
6984
index.search(objects, -1)
7085
end
71-
assert_equal "expected k to be non-negative", error.message
86+
assert_equal "expected k to be positive", error.message
7287
end
7388

7489
def test_add_frozen

test/index_test.rb

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,21 @@ def test_remove_ids_empty
336336
assert_equal 0, index.remove_ids([])
337337
end
338338

339+
def test_search_zero_k
340+
objects = [
341+
[1, 1, 2, 1],
342+
[5, 4, 6, 5],
343+
[1, 2, 1, 2]
344+
]
345+
index = Faiss::IndexFlatL2.new(4)
346+
index.add(objects)
347+
348+
error = assert_raises(ArgumentError) do
349+
index.search(objects, 0)
350+
end
351+
assert_equal "expected k to be positive", error.message
352+
end
353+
339354
def test_search_negative_k
340355
objects = [
341356
[1, 1, 2, 1],
@@ -348,7 +363,7 @@ def test_search_negative_k
348363
error = assert_raises(ArgumentError) do
349364
index.search(objects, -1)
350365
end
351-
assert_equal "expected k to be non-negative", error.message
366+
assert_equal "expected k to be positive", error.message
352367
end
353368

354369
def test_add_frozen

0 commit comments

Comments
 (0)