@@ -129,10 +129,20 @@ void IndexBinaryIVF::search(
129129 t0 = getmillisecs ();
130130 invlists->prefetch_lists (idx.get (), n * nprobe_2);
131131
132- const IVFSearchParameters* params2 = reinterpret_cast <const IVFSearchParameters*>(params);
132+ const IVFSearchParameters* params2 =
133+ reinterpret_cast <const IVFSearchParameters*>(params);
133134 const IDSelector* sel = params2 ? params2->sel : nullptr ;
134135 search_preassigned (
135- n, x, k, idx.get (), coarse_dis.get (), distances, labels, false , params2, sel);
136+ n,
137+ x,
138+ k,
139+ idx.get (),
140+ coarse_dis.get (),
141+ distances,
142+ labels,
143+ false ,
144+ params2,
145+ sel);
136146
137147 indexIVF_stats.search_time += getmillisecs () - t0;
138148}
@@ -221,6 +231,15 @@ void IndexBinaryIVF::reconstruct_from_offset(
221231 memcpy (recons, invlists->get_single_code (list_no, offset), code_size);
222232}
223233
234+ void IndexBinaryIVF::get_lists_for_keys (
235+ idx_t * keys,
236+ size_t n_keys,
237+ idx_t * lists) {
238+ for (int i = 0 ; i < n_keys; i++) {
239+ lists[i] = lo_listno (direct_map.get (keys[i]));
240+ }
241+ }
242+
224243void IndexBinaryIVF::reset () {
225244 direct_map.clear ();
226245 invlists->reset ();
@@ -315,7 +334,10 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
315334 size_t code_size;
316335 bool store_pairs;
317336
318- IVFBinaryScannerL2 (size_t code_size, bool store_pairs, const IDSelector* sel = nullptr )
337+ IVFBinaryScannerL2 (
338+ size_t code_size,
339+ bool store_pairs,
340+ const IDSelector* sel = nullptr )
319341 : BinaryInvertedListScanner(store_pairs, sel),
320342 code_size (code_size),
321343 store_pairs(store_pairs) {}
@@ -333,47 +355,46 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
333355 return hc.hamming (code);
334356 }
335357
336- size_t scan_codes (
337- size_t n,
338- const uint8_t * __restrict codes,
339- const idx_t * __restrict ids,
340- int32_t * __restrict simi,
341- idx_t * __restrict idxi,
342- size_t k) const override {
343- using C = CMax<int32_t , idx_t >;
358+ size_t scan_codes (
359+ size_t n,
360+ const uint8_t * __restrict codes,
361+ const idx_t * __restrict ids,
362+ int32_t * __restrict simi,
363+ idx_t * __restrict idxi,
364+ size_t k) const override {
365+ using C = CMax<int32_t , idx_t >;
344366
345- for (size_t j = 0 ; j < n; j++) {
346- uint32_t dis = hc.hamming (codes);
347- if (dis < simi[0 ]) {
348- idx_t id = store_pairs ? lo_build (list_no, j) : ids[j];
349- // Add selector check
350- if (!sel || sel->is_member (id)) {
351- heap_replace_top<C>(k, simi, idxi, dis, id);
352- }
367+ for (size_t j = 0 ; j < n; j++) {
368+ uint32_t dis = hc.hamming (codes);
369+ if (dis < simi[0 ]) {
370+ idx_t id = store_pairs ? lo_build (list_no, j) : ids[j];
371+ // Add selector check
372+ if (!sel || sel->is_member (id)) {
373+ heap_replace_top<C>(k, simi, idxi, dis, id);
374+ }
375+ }
376+ codes += code_size;
353377 }
354- codes += code_size;
355378 }
356- }
357379
358- void scan_codes_range (
359- size_t n,
360- const uint8_t * __restrict codes,
361- const idx_t * __restrict ids,
362- int radius,
363- RangeQueryResult& result) const override {
364- for (size_t j = 0 ; j < n; j++) {
365- uint32_t dis = hc.hamming (codes);
366- if (dis < radius) {
367- int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
368- // Add selector check
369- if (!sel || sel->is_member (id)) {
370- result.add (dis, id);
380+ void scan_codes_range (
381+ size_t n,
382+ const uint8_t * __restrict codes,
383+ const idx_t * __restrict ids,
384+ int radius,
385+ RangeQueryResult& result) const override {
386+ for (size_t j = 0 ; j < n; j++) {
387+ uint32_t dis = hc.hamming (codes);
388+ if (dis < radius) {
389+ int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
390+ // Add selector check
391+ if (!sel || sel->is_member (id)) {
392+ result.add (dis, id);
393+ }
371394 }
395+ codes += code_size;
372396 }
373- codes += code_size;
374397 }
375- }
376-
377398};
378399
379400void search_knn_hamming_heap (
@@ -399,7 +420,8 @@ void search_knn_hamming_heap(
399420 using HeapForIP = CMin<int32_t , idx_t >;
400421 using HeapForL2 = CMax<int32_t , idx_t >;
401422
402- #pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap) num_threads(num_omp_threads)
423+ #pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap) \
424+ num_threads (num_omp_threads)
403425 {
404426 std::unique_ptr<BinaryInvertedListScanner> scanner (
405427 ivf->get_InvertedListScanner (store_pairs, sel));
@@ -493,17 +515,19 @@ void search_knn_hamming_count(
493515
494516 std::vector<HCounterState<HammingComputer>> cs;
495517 for (size_t i = 0 ; i < nx; ++i) {
496- cs.push_back (HCounterState<HammingComputer>(
497- all_counters.data () + i * nBuckets,
498- all_ids_per_dis.get () + i * nBuckets * k,
499- x + i * ivf->code_size ,
500- ivf->d ,
501- k));
518+ cs.push_back (
519+ HCounterState<HammingComputer>(
520+ all_counters.data () + i * nBuckets,
521+ all_ids_per_dis.get () + i * nBuckets * k,
522+ x + i * ivf->code_size ,
523+ ivf->d ,
524+ k));
502525 }
503526
504527 size_t nlistv = 0 , ndis = 0 ;
505528
506- #pragma omp parallel for reduction(+ : nlistv, ndis) num_threads(num_omp_threads)
529+ #pragma omp parallel for reduction(+ : nlistv, ndis) \
530+ num_threads (num_omp_threads)
507531 for (int64_t i = 0 ; i < nx; i++) {
508532 const idx_t * keysi = keys + i * nprobe;
509533 HCounterState<HammingComputer>& csi = cs[i];
@@ -768,7 +792,8 @@ struct BuildScanner {
768792
769793 template <class HammingComputer >
770794 T f (size_t code_size, bool store_pairs, const IDSelector* sel) {
771- return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs, sel);
795+ return new IVFBinaryScannerL2<HammingComputer>(
796+ code_size, store_pairs, sel);
772797 }
773798};
774799
@@ -779,19 +804,26 @@ BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
779804 const IDSelector* sel) const {
780805 // Choose the appropriate HammingComputer type based on code_size
781806 if (code_size == 4 ) {
782- return new IVFBinaryScannerL2<HammingComputer4>(code_size, store_pairs, sel);
807+ return new IVFBinaryScannerL2<HammingComputer4>(
808+ code_size, store_pairs, sel);
783809 } else if (code_size == 8 ) {
784- return new IVFBinaryScannerL2<HammingComputer8>(code_size, store_pairs, sel);
810+ return new IVFBinaryScannerL2<HammingComputer8>(
811+ code_size, store_pairs, sel);
785812 } else if (code_size == 16 ) {
786- return new IVFBinaryScannerL2<HammingComputer16>(code_size, store_pairs, sel);
813+ return new IVFBinaryScannerL2<HammingComputer16>(
814+ code_size, store_pairs, sel);
787815 } else if (code_size == 20 ) {
788- return new IVFBinaryScannerL2<HammingComputer20>(code_size, store_pairs, sel);
816+ return new IVFBinaryScannerL2<HammingComputer20>(
817+ code_size, store_pairs, sel);
789818 } else if (code_size == 32 ) {
790- return new IVFBinaryScannerL2<HammingComputer32>(code_size, store_pairs, sel);
819+ return new IVFBinaryScannerL2<HammingComputer32>(
820+ code_size, store_pairs, sel);
791821 } else if (code_size == 64 ) {
792- return new IVFBinaryScannerL2<HammingComputer64>(code_size, store_pairs, sel);
822+ return new IVFBinaryScannerL2<HammingComputer64>(
823+ code_size, store_pairs, sel);
793824 } else {
794- return new IVFBinaryScannerL2<HammingComputerDefault>(code_size, store_pairs, sel);
825+ return new IVFBinaryScannerL2<HammingComputerDefault>(
826+ code_size, store_pairs, sel);
795827 }
796828}
797829
@@ -806,7 +838,6 @@ void IndexBinaryIVF::search_preassigned(
806838 bool store_pairs,
807839 const IVFSearchParameters* params,
808840 const IDSelector* sel) const {
809-
810841 if (per_invlist_search) {
811842 Run_search_knn_hamming_per_invlist r;
812843 // clang-format off
@@ -816,7 +847,16 @@ void IndexBinaryIVF::search_preassigned(
816847 // clang-format on
817848 } else if (use_heap) {
818849 search_knn_hamming_heap (
819- this , n, x, k, assign, centroid_dis, distances, labels, store_pairs, params);
850+ this ,
851+ n,
852+ x,
853+ k,
854+ assign,
855+ centroid_dis,
856+ distances,
857+ labels,
858+ store_pairs,
859+ params);
820860 } else if (store_pairs) { // !use_heap && store_pairs
821861 Run_search_knn_hamming_count<true > r;
822862 dispatch_HammingComputer (
0 commit comments