Skip to content

Commit 74bab4d

Browse files
selector for bivf
1 parent aeb45e5 commit 74bab4d

2 files changed

Lines changed: 90 additions & 50 deletions

File tree

faiss/IndexBinaryIVF.cpp

Lines changed: 73 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ void IndexBinaryIVF::search(
115115
int32_t* distances,
116116
idx_t* labels,
117117
const SearchParameters* params) const {
118-
FAISS_THROW_IF_NOT_MSG(
119-
!params, "search params not supported for this index");
120118
FAISS_THROW_IF_NOT(k > 0);
121119
FAISS_THROW_IF_NOT(nprobe > 0);
122120

@@ -131,8 +129,11 @@ void IndexBinaryIVF::search(
131129
t0 = getmillisecs();
132130
invlists->prefetch_lists(idx.get(), n * nprobe_2);
133131

132+
const IVFSearchParameters* params2 = reinterpret_cast<const IVFSearchParameters*>(params);
133+
const IDSelector* sel = params2 ? params2->sel : nullptr;
134134
search_preassigned(
135-
n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
135+
n, x, k, idx.get(), coarse_dis.get(), distances, labels, false, params2, sel);
136+
136137
indexIVF_stats.search_time += getmillisecs() - t0;
137138
}
138139

@@ -314,8 +315,10 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
314315
size_t code_size;
315316
bool store_pairs;
316317

317-
IVFBinaryScannerL2(size_t code_size, bool store_pairs)
318-
: code_size(code_size), store_pairs(store_pairs) {}
318+
IVFBinaryScannerL2(size_t code_size, bool store_pairs, const IDSelector* sel = nullptr)
319+
: BinaryInvertedListScanner(store_pairs, sel),
320+
code_size(code_size),
321+
store_pairs(store_pairs) {}
319322

320323
void set_query(const uint8_t* query_vector) override {
321324
hc.set(query_vector, code_size);
@@ -330,43 +333,47 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
330333
return hc.hamming(code);
331334
}
332335

333-
size_t scan_codes(
334-
size_t n,
335-
const uint8_t* __restrict codes,
336-
const idx_t* __restrict ids,
337-
int32_t* __restrict simi,
338-
idx_t* __restrict idxi,
339-
size_t k) const override {
340-
using C = CMax<int32_t, idx_t>;
336+
size_t scan_codes(
337+
size_t n,
338+
const uint8_t* __restrict codes,
339+
const idx_t* __restrict ids,
340+
int32_t* __restrict simi,
341+
idx_t* __restrict idxi,
342+
size_t k) const override {
343+
using C = CMax<int32_t, idx_t>;
341344

342-
size_t nup = 0;
343-
for (size_t j = 0; j < n; j++) {
344-
uint32_t dis = hc.hamming(codes);
345-
if (dis < simi[0]) {
346-
idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
345+
for (size_t j = 0; j < n; j++) {
346+
uint32_t dis = hc.hamming(codes);
347+
if (dis < simi[0]) {
348+
idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
349+
// Add selector check
350+
if (!sel || sel->is_member(id)) {
347351
heap_replace_top<C>(k, simi, idxi, dis, id);
348-
nup++;
349-
}
350-
codes += code_size;
352+
}
351353
}
352-
return nup;
354+
codes += code_size;
353355
}
356+
}
354357

355-
void scan_codes_range(
356-
size_t n,
357-
const uint8_t* __restrict codes,
358-
const idx_t* __restrict ids,
359-
int radius,
360-
RangeQueryResult& result) const override {
361-
for (size_t j = 0; j < n; j++) {
362-
uint32_t dis = hc.hamming(codes);
363-
if (dis < radius) {
364-
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
358+
void scan_codes_range(
359+
size_t n,
360+
const uint8_t* __restrict codes,
361+
const idx_t* __restrict ids,
362+
int radius,
363+
RangeQueryResult& result) const override {
364+
for (size_t j = 0; j < n; j++) {
365+
uint32_t dis = hc.hamming(codes);
366+
if (dis < radius) {
367+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
368+
// Add selector check
369+
if (!sel || sel->is_member(id)) {
365370
result.add(dis, id);
366371
}
367-
codes += code_size;
368372
}
373+
codes += code_size;
369374
}
375+
}
376+
370377
};
371378

372379
void search_knn_hamming_heap(
@@ -384,6 +391,7 @@ void search_knn_hamming_heap(
384391
nprobe = std::min((idx_t)ivf->nlist, nprobe);
385392
idx_t max_codes = params ? params->max_codes : ivf->max_codes;
386393
MetricType metric_type = ivf->metric_type;
394+
const IDSelector* sel = params ? params->sel : nullptr;
387395

388396
// almost verbatim copy from IndexIVF::search_preassigned
389397

@@ -394,7 +402,7 @@ void search_knn_hamming_heap(
394402
#pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap) num_threads(num_omp_threads)
395403
{
396404
std::unique_ptr<BinaryInvertedListScanner> scanner(
397-
ivf->get_InvertedListScanner(store_pairs));
405+
ivf->get_InvertedListScanner(store_pairs, sel));
398406

399407
#pragma omp for
400408
for (idx_t i = 0; i < n; i++) {
@@ -759,47 +767,64 @@ struct BuildScanner {
759767
using T = BinaryInvertedListScanner*;
760768

761769
template <class HammingComputer>
762-
T f(size_t code_size, bool store_pairs) {
763-
return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs);
770+
T f(size_t code_size, bool store_pairs, const IDSelector* sel) {
771+
return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs, sel);
764772
}
765773
};
766774

767775
} // anonymous namespace
768776

769777
/**
 * Create an inverted-list scanner matching this index's code size.
 *
 * @param store_pairs  forwarded to the scanner
 * @param sel          optional id selector forwarded to the scanner
 *                     (nullptr = no filtering)
 * @return newly allocated scanner; the caller takes ownership
 */
BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
        bool store_pairs,
        const IDSelector* sel) const {
    // Let dispatch_HammingComputer choose the HammingComputer type for
    // code_size instead of repeating its size table by hand; the trailing
    // arguments are handed to BuildScanner::f.
    BuildScanner bs;
    return dispatch_HammingComputer(
            code_size, bs, code_size, store_pairs, sel);
}
774797

775798
void IndexBinaryIVF::search_preassigned(
776799
idx_t n,
777800
const uint8_t* x,
778801
idx_t k,
779-
const idx_t* cidx,
780-
const int32_t* cdis,
781-
int32_t* dis,
782-
idx_t* idx,
802+
const idx_t* assign,
803+
const int32_t* centroid_dis,
804+
int32_t* distances,
805+
idx_t* labels,
783806
bool store_pairs,
784-
const IVFSearchParameters* params) const {
807+
const IVFSearchParameters* params,
808+
const IDSelector* sel) const {
809+
785810
if (per_invlist_search) {
786811
Run_search_knn_hamming_per_invlist r;
787812
// clang-format off
788813
dispatch_HammingComputer(
789814
code_size, r, this, n, x, k,
790-
cidx, cdis, dis, idx, store_pairs, params);
815+
assign, centroid_dis, distances, labels, store_pairs, params);
791816
// clang-format on
792817
} else if (use_heap) {
793818
search_knn_hamming_heap(
794-
this, n, x, k, cidx, cdis, dis, idx, store_pairs, params);
819+
this, n, x, k, assign, centroid_dis, distances, labels, store_pairs, params);
795820
} else if (store_pairs) { // !use_heap && store_pairs
796821
Run_search_knn_hamming_count<true> r;
797822
dispatch_HammingComputer(
798-
code_size, r, this, n, x, cidx, k, dis, idx, params);
823+
code_size, r, this, n, x, assign, k, distances, labels, params);
799824
} else { // !use_heap && !store_pairs
800825
Run_search_knn_hamming_count<false> r;
801826
dispatch_HammingComputer(
802-
code_size, r, this, n, x, cidx, k, dis, idx, params);
827+
code_size, r, this, n, x, assign, k, distances, labels, params);
803828
}
804829
}
805830

faiss/IndexBinaryIVF.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,11 @@ struct IndexBinaryIVF : IndexBinary {
121121
int32_t* distances,
122122
idx_t* labels,
123123
bool store_pairs,
124-
const IVFSearchParameters* params = nullptr) const;
124+
const IVFSearchParameters* params = nullptr,
125+
const IDSelector* sel = nullptr) const;
125126

126127
virtual BinaryInvertedListScanner* get_InvertedListScanner(
127-
bool store_pairs = false) const;
128+
bool store_pairs = false,const IDSelector* sel = nullptr) const;
128129

129130
/** assign the vectors, then call search_preassign */
130131
void search(
@@ -218,6 +219,20 @@ struct IndexBinaryIVF : IndexBinary {
218219
};
219220

220221
struct BinaryInvertedListScanner {
222+
223+
bool store_pairs;
224+
const IDSelector* sel;
225+
idx_t list_no;
226+
const uint8_t* query_vector;
227+
228+
BinaryInvertedListScanner(
229+
bool store_pairs = false,
230+
const IDSelector* sel = nullptr)
231+
: store_pairs(store_pairs),
232+
sel(sel),
233+
list_no(-1),
234+
query_vector(nullptr) {}
235+
221236
/// from now on we handle this query.
222237
virtual void set_query(const uint8_t* query_vector) = 0;
223238

0 commit comments

Comments
 (0)