@@ -115,8 +115,6 @@ void IndexBinaryIVF::search(
115115 int32_t * distances,
116116 idx_t * labels,
117117 const SearchParameters* params) const {
118- FAISS_THROW_IF_NOT_MSG (
119- !params, " search params not supported for this index" );
120118 FAISS_THROW_IF_NOT (k > 0 );
121119 FAISS_THROW_IF_NOT (nprobe > 0 );
122120
@@ -131,8 +129,11 @@ void IndexBinaryIVF::search(
131129 t0 = getmillisecs ();
132130 invlists->prefetch_lists (idx.get (), n * nprobe_2);
133131
132+ const IVFSearchParameters* params2 = reinterpret_cast <const IVFSearchParameters*>(params);
133+ const IDSelector* sel = params2 ? params2->sel : nullptr ;
134134 search_preassigned (
135- n, x, k, idx.get (), coarse_dis.get (), distances, labels, false );
135+ n, x, k, idx.get (), coarse_dis.get (), distances, labels, false , params2, sel);
136+
136137 indexIVF_stats.search_time += getmillisecs () - t0;
137138}
138139
@@ -314,8 +315,10 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
314315 size_t code_size;
315316 bool store_pairs;
316317
317- IVFBinaryScannerL2 (size_t code_size, bool store_pairs)
318- : code_size(code_size), store_pairs(store_pairs) {}
318+ IVFBinaryScannerL2 (size_t code_size, bool store_pairs, const IDSelector* sel = nullptr )
319+ : BinaryInvertedListScanner(store_pairs, sel),
320+ code_size (code_size),
321+ store_pairs(store_pairs) {}
319322
320323 void set_query (const uint8_t * query_vector) override {
321324 hc.set (query_vector, code_size);
@@ -330,43 +333,47 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
330333 return hc.hamming (code);
331334 }
332335
333- size_t scan_codes (
334- size_t n,
335- const uint8_t * __restrict codes,
336- const idx_t * __restrict ids,
337- int32_t * __restrict simi,
338- idx_t * __restrict idxi,
339- size_t k) const override {
340- using C = CMax<int32_t , idx_t >;
336+ size_t scan_codes (
337+ size_t n,
338+ const uint8_t * __restrict codes,
339+ const idx_t * __restrict ids,
340+ int32_t * __restrict simi,
341+ idx_t * __restrict idxi,
342+ size_t k) const override {
343+ using C = CMax<int32_t , idx_t >;
341344
342- size_t nup = 0 ;
343- for (size_t j = 0 ; j < n; j++) {
344- uint32_t dis = hc.hamming (codes);
345- if (dis < simi[0 ]) {
346- idx_t id = store_pairs ? lo_build (list_no, j) : ids[j];
345+ for (size_t j = 0 ; j < n; j++) {
346+ uint32_t dis = hc.hamming (codes);
347+ if (dis < simi[0 ]) {
348+ idx_t id = store_pairs ? lo_build (list_no, j) : ids[j];
349+ // Add selector check
350+ if (!sel || sel->is_member (id)) {
347351 heap_replace_top<C>(k, simi, idxi, dis, id);
348- nup++;
349- }
350- codes += code_size;
352+ }
351353 }
352- return nup ;
354+ codes += code_size ;
353355 }
356+ }
354357
355- void scan_codes_range (
356- size_t n,
357- const uint8_t * __restrict codes,
358- const idx_t * __restrict ids,
359- int radius,
360- RangeQueryResult& result) const override {
361- for (size_t j = 0 ; j < n; j++) {
362- uint32_t dis = hc.hamming (codes);
363- if (dis < radius) {
364- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
358+ void scan_codes_range (
359+ size_t n,
360+ const uint8_t * __restrict codes,
361+ const idx_t * __restrict ids,
362+ int radius,
363+ RangeQueryResult& result) const override {
364+ for (size_t j = 0 ; j < n; j++) {
365+ uint32_t dis = hc.hamming (codes);
366+ if (dis < radius) {
367+ int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
368+ // Add selector check
369+ if (!sel || sel->is_member (id)) {
365370 result.add (dis, id);
366371 }
367- codes += code_size;
368372 }
373+ codes += code_size;
369374 }
375+ }
376+
370377};
371378
372379void search_knn_hamming_heap (
@@ -384,6 +391,7 @@ void search_knn_hamming_heap(
384391 nprobe = std::min ((idx_t )ivf->nlist , nprobe);
385392 idx_t max_codes = params ? params->max_codes : ivf->max_codes ;
386393 MetricType metric_type = ivf->metric_type ;
394+ const IDSelector* sel = params ? params->sel : nullptr ;
387395
388396 // almost verbatim copy from IndexIVF::search_preassigned
389397
@@ -394,7 +402,7 @@ void search_knn_hamming_heap(
394402#pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap) num_threads(num_omp_threads)
395403 {
396404 std::unique_ptr<BinaryInvertedListScanner> scanner (
397- ivf->get_InvertedListScanner (store_pairs));
405+ ivf->get_InvertedListScanner (store_pairs, sel ));
398406
399407#pragma omp for
400408 for (idx_t i = 0 ; i < n; i++) {
@@ -759,47 +767,64 @@ struct BuildScanner {
759767 using T = BinaryInvertedListScanner*;
760768
761769 template <class HammingComputer >
762- T f (size_t code_size, bool store_pairs) {
763- return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs);
770+ T f (size_t code_size, bool store_pairs, const IDSelector* sel ) {
771+ return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs, sel );
764772 }
765773};
766774
767775} // anonymous namespace
768776
769777BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner (
770- bool store_pairs) const {
771- BuildScanner bs;
772- return dispatch_HammingComputer (code_size, bs, code_size, store_pairs);
778+ bool store_pairs,
779+ const IDSelector* sel) const {
780+ // Choose the appropriate HammingComputer type based on code_size
781+ if (code_size == 4 ) {
782+ return new IVFBinaryScannerL2<HammingComputer4>(code_size, store_pairs, sel);
783+ } else if (code_size == 8 ) {
784+ return new IVFBinaryScannerL2<HammingComputer8>(code_size, store_pairs, sel);
785+ } else if (code_size == 16 ) {
786+ return new IVFBinaryScannerL2<HammingComputer16>(code_size, store_pairs, sel);
787+ } else if (code_size == 20 ) {
788+ return new IVFBinaryScannerL2<HammingComputer20>(code_size, store_pairs, sel);
789+ } else if (code_size == 32 ) {
790+ return new IVFBinaryScannerL2<HammingComputer32>(code_size, store_pairs, sel);
791+ } else if (code_size == 64 ) {
792+ return new IVFBinaryScannerL2<HammingComputer64>(code_size, store_pairs, sel);
793+ } else {
794+ return new IVFBinaryScannerL2<HammingComputerDefault>(code_size, store_pairs, sel);
795+ }
773796}
774797
775798void IndexBinaryIVF::search_preassigned (
776799 idx_t n,
777800 const uint8_t * x,
778801 idx_t k,
779- const idx_t * cidx ,
780- const int32_t * cdis ,
781- int32_t * dis ,
782- idx_t * idx ,
802+ const idx_t * assign ,
803+ const int32_t * centroid_dis ,
804+ int32_t * distances ,
805+ idx_t * labels ,
783806 bool store_pairs,
784- const IVFSearchParameters* params) const {
807+ const IVFSearchParameters* params,
808+ const IDSelector* sel) const {
809+
785810 if (per_invlist_search) {
786811 Run_search_knn_hamming_per_invlist r;
787812 // clang-format off
788813 dispatch_HammingComputer (
789814 code_size, r, this , n, x, k,
790- cidx, cdis, dis, idx , store_pairs, params);
815+ assign, centroid_dis, distances, labels , store_pairs, params);
791816 // clang-format on
792817 } else if (use_heap) {
793818 search_knn_hamming_heap (
794- this , n, x, k, cidx, cdis, dis, idx , store_pairs, params);
819+ this , n, x, k, assign, centroid_dis, distances, labels , store_pairs, params);
795820 } else if (store_pairs) { // !use_heap && store_pairs
796821 Run_search_knn_hamming_count<true > r;
797822 dispatch_HammingComputer (
798- code_size, r, this , n, x, cidx , k, dis, idx , params);
823+ code_size, r, this , n, x, assign , k, distances, labels , params);
799824 } else { // !use_heap && !store_pairs
800825 Run_search_knn_hamming_count<false > r;
801826 dispatch_HammingComputer (
802- code_size, r, this , n, x, cidx , k, dis, idx , params);
827+ code_size, r, this , n, x, assign , k, distances, labels , params);
803828 }
804829}
805830
0 commit comments