@@ -62,19 +62,22 @@ void init_index_binary(Rice::Module& m) {
6262 })
6363 .define_method (
6464 " search" ,
65- [](Rice::Object rb_self, numo::UInt8 objects, size_t k) {
65+ [](Rice::Object rb_self, numo::UInt8 objects, int64_t k) {
6666 auto & self = *Rice::Data_Object<faiss::IndexBinary>{rb_self};
6767 size_t n = check_shape (objects, self.d / 8 );
68+ if (k < 0 ) {
69+ throw Rice::Exception (rb_eArgError, " expected k to be non-negative" );
70+ }
6871
69- numo::Int32 distances{{n, k }};
70- numo::Int64 labels{{n, k }};
72+ numo::Int32 distances{{n, static_cast < size_t >(k) }};
73+ numo::Int64 labels{{n, static_cast < size_t >(k) }};
7174
7275 if (rb_self.is_frozen ()) {
7376 // Don't mess with Ruby-owned memory while the GVL is released
7477 const auto * objects_ptr = objects.read_ptr ();
7578 std::vector<uint8_t > objects_vec (objects_ptr, objects_ptr + n * (self.d / 8 ));
76- std::vector<int32_t > distances_vec (n * k );
77- std::vector<int64_t > labels_vec (n * k );
79+ std::vector<int32_t > distances_vec (n * static_cast < size_t >(k) );
80+ std::vector<int64_t > labels_vec (n * static_cast < size_t >(k) );
7881
7982 Rice::detail::no_gvl ([&] {
8083 self.search (n, objects_vec.data (), k, distances_vec.data (), labels_vec.data ());
0 commit comments