Skip to content

Commit 719bea6

Browse files
added more bivf utils
1 parent 40db079 commit 719bea6

4 files changed

Lines changed: 187 additions & 58 deletions

File tree

c_api/IndexBinaryIVF_c.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,63 @@ int faiss_IndexBinaryIVF_set_direct_map(
99
}
1010
CATCH_AND_HANDLE
1111
}
12+
13+
/// C entry point forwarding to faiss::IndexBinaryIVF::get_lists_for_keys.
///
/// @param index   opaque handle, reinterpreted as a faiss::IndexBinaryIVF
/// @param keys    array of n_keys ids to look up
/// @param n_keys  number of entries in keys (and in lists)
/// @param lists   output array, one entry written per key
/// @return whatever CATCH_AND_HANDLE yields (macro defined elsewhere;
///         presumably 0 on success and an error code on exception -- confirm)
int faiss_get_lists_for_keys_binary(
        FaissIndexBinaryIVF* index,
        idx_t* keys,
        size_t n_keys,
        idx_t* lists) {
    try {
        auto* bivf = reinterpret_cast<faiss::IndexBinaryIVF*>(index);
        bivf->get_lists_for_keys(keys, n_keys, lists);
    }
    CATCH_AND_HANDLE
}
24+
25+
/// C entry point that runs the quantizer of a binary IVF index on a batch of
/// queries, i.e. it calls index->quantizer->search(...) with the arguments
/// given here.
///
/// @param index               opaque handle, reinterpreted as
///                            faiss::IndexBinaryIVF
/// @param n                   number of query vectors
/// @param query               query codes, forwarded unchanged
/// @param k                   number of results requested per query
/// @param centroid_distances  output distances, filled by the quantizer
/// @param centroid_ids        output ids, filled by the quantizer
/// @param params              optional search parameters, reinterpreted as
///                            faiss::SearchParameters and forwarded
/// @return whatever CATCH_AND_HANDLE yields (macro defined elsewhere)
int faiss_Search_closest_eligible_centroids_binary(
        FaissIndexBinaryIVF* index,
        idx_t n,
        const uint8_t* query,
        idx_t k,
        int32_t* centroid_distances,
        idx_t* centroid_ids,
        const FaissSearchParameters* params) {
    try {
        auto* bivf = reinterpret_cast<faiss::IndexBinaryIVF*>(index);
        assert(bivf);

        const auto* quantizer_params =
                reinterpret_cast<const faiss::SearchParameters*>(params);
        bivf->quantizer->search(
                n,
                query,
                k,
                centroid_distances,
                centroid_ids,
                quantizer_params);
    }
    CATCH_AND_HANDLE
}
47+
48+
int faiss_IndexBinaryIVF_search_preassigned_with_params(
49+
const FaissIndexBinaryIVF* index,
50+
idx_t n,
51+
const uint8_t* x,
52+
idx_t k,
53+
const idx_t* assign,
54+
const int32_t* centroid_dis,
55+
int32_t* distances,
56+
idx_t* labels,
57+
int store_pairs,
58+
const FaissSearchParametersIVF* params) {
59+
try {
60+
faiss::IndexBinaryIVF* index_bivf = reinterpret_cast<IndexBinaryIVF*>(index);
61+
assert(index_bivf);
62+
63+
index_bivf->search_preassigned(n, x, k, assign, centroid_dis, distances,
64+
labels, store_pairs, reinterpret_cast<const faiss::SearchParameters*>(params));
65+
}
66+
CATCH_AND_HANDLE
67+
}
68+
69+
70+
1271
}

c_api/IndexBinaryIVF_c.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@ int faiss_IndexBinaryIVF_set_direct_map(
1313
FaissIndexBinaryIVF* index,
1414
int direct_map_type);
1515

16+
int faiss_get_lists_for_keys_binary(
17+
FaissIndexBinaryIVF* index,
18+
idx_t* keys,
19+
size_t n_keys,
20+
idx_t* lists) ;
21+
22+
23+
int faiss_Search_closest_eligible_centroids_binary(
24+
FaissIndexBinaryIVF* index,
25+
idx_t n,
26+
const uint8_t* query,
27+
idx_t k,
28+
int32_t* centroid_distances,
29+
idx_t* centroid_ids,
30+
const FaissSearchParameters* params);
31+
32+
int faiss_IndexBinaryIVF_search_preassigned_with_params(
33+
const FaissIndexBinaryIVF* index,
34+
idx_t n,
35+
const uint8_t* x,
36+
idx_t k,
37+
const idx_t* assign,
38+
const int32_t* centroid_dis,
39+
int32_t* distances,
40+
idx_t* labels,
41+
int store_pairs,
42+
const FaissSearchParametersIVF* params);
43+
1644
DEFINE_GETTER_PERMISSIVE(IndexBinaryIVF, FaissIndexBinary*, quantizer);
1745
#ifdef __cplusplus
1846
}

faiss/IndexBinaryIVF.cpp

Lines changed: 95 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,20 @@ void IndexBinaryIVF::search(
129129
t0 = getmillisecs();
130130
invlists->prefetch_lists(idx.get(), n * nprobe_2);
131131

132-
const IVFSearchParameters* params2 = reinterpret_cast<const IVFSearchParameters*>(params);
132+
const IVFSearchParameters* params2 =
133+
reinterpret_cast<const IVFSearchParameters*>(params);
133134
const IDSelector* sel = params2 ? params2->sel : nullptr;
134135
search_preassigned(
135-
n, x, k, idx.get(), coarse_dis.get(), distances, labels, false, params2, sel);
136+
n,
137+
x,
138+
k,
139+
idx.get(),
140+
coarse_dis.get(),
141+
distances,
142+
labels,
143+
false,
144+
params2,
145+
sel);
136146

137147
indexIVF_stats.search_time += getmillisecs() - t0;
138148
}
@@ -221,6 +231,15 @@ void IndexBinaryIVF::reconstruct_from_offset(
221231
memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
222232
}
223233

234+
/// For each of the n_keys ids in keys, looks the id up in direct_map and
/// stores the list number extracted with lo_listno() into lists[i].
///
/// @param keys    ids to look up (read only here)
/// @param n_keys  number of entries in keys and lists
/// @param lists   output array of size n_keys
///
/// NOTE(review): behaviour for an id that is absent from direct_map is
/// determined by DirectMap::get, which is not visible here -- confirm.
void IndexBinaryIVF::get_lists_for_keys(
        idx_t* keys,
        size_t n_keys,
        idx_t* lists) {
    // size_t index: matches the type of n_keys, avoiding a signed/unsigned
    // comparison and int overflow for very large batches.
    for (size_t i = 0; i < n_keys; i++) {
        lists[i] = lo_listno(direct_map.get(keys[i]));
    }
}
242+
224243
void IndexBinaryIVF::reset() {
225244
direct_map.clear();
226245
invlists->reset();
@@ -315,7 +334,10 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
315334
size_t code_size;
316335
bool store_pairs;
317336

318-
IVFBinaryScannerL2(size_t code_size, bool store_pairs, const IDSelector* sel = nullptr)
337+
IVFBinaryScannerL2(
338+
size_t code_size,
339+
bool store_pairs,
340+
const IDSelector* sel = nullptr)
319341
: BinaryInvertedListScanner(store_pairs, sel),
320342
code_size(code_size),
321343
store_pairs(store_pairs) {}
@@ -333,47 +355,46 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
333355
return hc.hamming(code);
334356
}
335357

336-
size_t scan_codes(
337-
size_t n,
338-
const uint8_t* __restrict codes,
339-
const idx_t* __restrict ids,
340-
int32_t* __restrict simi,
341-
idx_t* __restrict idxi,
342-
size_t k) const override {
343-
using C = CMax<int32_t, idx_t>;
358+
size_t scan_codes(
359+
size_t n,
360+
const uint8_t* __restrict codes,
361+
const idx_t* __restrict ids,
362+
int32_t* __restrict simi,
363+
idx_t* __restrict idxi,
364+
size_t k) const override {
365+
using C = CMax<int32_t, idx_t>;
344366

345-
for (size_t j = 0; j < n; j++) {
346-
uint32_t dis = hc.hamming(codes);
347-
if (dis < simi[0]) {
348-
idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
349-
// Add selector check
350-
if (!sel || sel->is_member(id)) {
351-
heap_replace_top<C>(k, simi, idxi, dis, id);
352-
}
367+
for (size_t j = 0; j < n; j++) {
368+
uint32_t dis = hc.hamming(codes);
369+
if (dis < simi[0]) {
370+
idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
371+
// Add selector check
372+
if (!sel || sel->is_member(id)) {
373+
heap_replace_top<C>(k, simi, idxi, dis, id);
374+
}
375+
}
376+
codes += code_size;
353377
}
354-
codes += code_size;
355378
}
356-
}
357379

358-
void scan_codes_range(
359-
size_t n,
360-
const uint8_t* __restrict codes,
361-
const idx_t* __restrict ids,
362-
int radius,
363-
RangeQueryResult& result) const override {
364-
for (size_t j = 0; j < n; j++) {
365-
uint32_t dis = hc.hamming(codes);
366-
if (dis < radius) {
367-
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
368-
// Add selector check
369-
if (!sel || sel->is_member(id)) {
370-
result.add(dis, id);
380+
void scan_codes_range(
381+
size_t n,
382+
const uint8_t* __restrict codes,
383+
const idx_t* __restrict ids,
384+
int radius,
385+
RangeQueryResult& result) const override {
386+
for (size_t j = 0; j < n; j++) {
387+
uint32_t dis = hc.hamming(codes);
388+
if (dis < radius) {
389+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
390+
// Add selector check
391+
if (!sel || sel->is_member(id)) {
392+
result.add(dis, id);
393+
}
371394
}
395+
codes += code_size;
372396
}
373-
codes += code_size;
374397
}
375-
}
376-
377398
};
378399

379400
void search_knn_hamming_heap(
@@ -399,7 +420,8 @@ void search_knn_hamming_heap(
399420
using HeapForIP = CMin<int32_t, idx_t>;
400421
using HeapForL2 = CMax<int32_t, idx_t>;
401422

402-
#pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap) num_threads(num_omp_threads)
423+
#pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap) \
424+
num_threads(num_omp_threads)
403425
{
404426
std::unique_ptr<BinaryInvertedListScanner> scanner(
405427
ivf->get_InvertedListScanner(store_pairs, sel));
@@ -493,17 +515,19 @@ void search_knn_hamming_count(
493515

494516
std::vector<HCounterState<HammingComputer>> cs;
495517
for (size_t i = 0; i < nx; ++i) {
496-
cs.push_back(HCounterState<HammingComputer>(
497-
all_counters.data() + i * nBuckets,
498-
all_ids_per_dis.get() + i * nBuckets * k,
499-
x + i * ivf->code_size,
500-
ivf->d,
501-
k));
518+
cs.push_back(
519+
HCounterState<HammingComputer>(
520+
all_counters.data() + i * nBuckets,
521+
all_ids_per_dis.get() + i * nBuckets * k,
522+
x + i * ivf->code_size,
523+
ivf->d,
524+
k));
502525
}
503526

504527
size_t nlistv = 0, ndis = 0;
505528

506-
#pragma omp parallel for reduction(+ : nlistv, ndis) num_threads(num_omp_threads)
529+
#pragma omp parallel for reduction(+ : nlistv, ndis) \
530+
num_threads(num_omp_threads)
507531
for (int64_t i = 0; i < nx; i++) {
508532
const idx_t* keysi = keys + i * nprobe;
509533
HCounterState<HammingComputer>& csi = cs[i];
@@ -768,7 +792,8 @@ struct BuildScanner {
768792

769793
template <class HammingComputer>
770794
T f(size_t code_size, bool store_pairs, const IDSelector* sel) {
771-
return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs, sel);
795+
return new IVFBinaryScannerL2<HammingComputer>(
796+
code_size, store_pairs, sel);
772797
}
773798
};
774799

@@ -779,19 +804,26 @@ BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
779804
const IDSelector* sel) const {
780805
// Choose the appropriate HammingComputer type based on code_size
781806
if (code_size == 4) {
782-
return new IVFBinaryScannerL2<HammingComputer4>(code_size, store_pairs, sel);
807+
return new IVFBinaryScannerL2<HammingComputer4>(
808+
code_size, store_pairs, sel);
783809
} else if (code_size == 8) {
784-
return new IVFBinaryScannerL2<HammingComputer8>(code_size, store_pairs, sel);
810+
return new IVFBinaryScannerL2<HammingComputer8>(
811+
code_size, store_pairs, sel);
785812
} else if (code_size == 16) {
786-
return new IVFBinaryScannerL2<HammingComputer16>(code_size, store_pairs, sel);
813+
return new IVFBinaryScannerL2<HammingComputer16>(
814+
code_size, store_pairs, sel);
787815
} else if (code_size == 20) {
788-
return new IVFBinaryScannerL2<HammingComputer20>(code_size, store_pairs, sel);
816+
return new IVFBinaryScannerL2<HammingComputer20>(
817+
code_size, store_pairs, sel);
789818
} else if (code_size == 32) {
790-
return new IVFBinaryScannerL2<HammingComputer32>(code_size, store_pairs, sel);
819+
return new IVFBinaryScannerL2<HammingComputer32>(
820+
code_size, store_pairs, sel);
791821
} else if (code_size == 64) {
792-
return new IVFBinaryScannerL2<HammingComputer64>(code_size, store_pairs, sel);
822+
return new IVFBinaryScannerL2<HammingComputer64>(
823+
code_size, store_pairs, sel);
793824
} else {
794-
return new IVFBinaryScannerL2<HammingComputerDefault>(code_size, store_pairs, sel);
825+
return new IVFBinaryScannerL2<HammingComputerDefault>(
826+
code_size, store_pairs, sel);
795827
}
796828
}
797829

@@ -806,7 +838,6 @@ void IndexBinaryIVF::search_preassigned(
806838
bool store_pairs,
807839
const IVFSearchParameters* params,
808840
const IDSelector* sel) const {
809-
810841
if (per_invlist_search) {
811842
Run_search_knn_hamming_per_invlist r;
812843
// clang-format off
@@ -816,7 +847,16 @@ void IndexBinaryIVF::search_preassigned(
816847
// clang-format on
817848
} else if (use_heap) {
818849
search_knn_hamming_heap(
819-
this, n, x, k, assign, centroid_dis, distances, labels, store_pairs, params);
850+
this,
851+
n,
852+
x,
853+
k,
854+
assign,
855+
centroid_dis,
856+
distances,
857+
labels,
858+
store_pairs,
859+
params);
820860
} else if (store_pairs) { // !use_heap && store_pairs
821861
Run_search_knn_hamming_count<true> r;
822862
dispatch_HammingComputer(

faiss/IndexBinaryIVF.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,10 @@ struct IndexBinaryIVF : IndexBinary {
125125
const IDSelector* sel = nullptr) const;
126126

127127
virtual BinaryInvertedListScanner* get_InvertedListScanner(
128-
bool store_pairs = false,const IDSelector* sel = nullptr) const;
128+
bool store_pairs = false,
129+
const IDSelector* sel = nullptr) const;
130+
131+
void get_lists_for_keys(idx_t* keys, size_t n_keys, idx_t* lists);
129132

130133
/** assign the vectors, then call search_preassign */
131134
void search(
@@ -219,7 +222,6 @@ struct IndexBinaryIVF : IndexBinary {
219222
};
220223

221224
struct BinaryInvertedListScanner {
222-
223225
bool store_pairs;
224226
const IDSelector* sel;
225227
idx_t list_no;
@@ -232,7 +234,7 @@ struct BinaryInvertedListScanner {
232234
sel(sel),
233235
list_no(-1),
234236
query_vector(nullptr) {}
235-
237+
236238
/// from now on we handle this query.
237239
virtual void set_query(const uint8_t* query_vector) = 0;
238240

0 commit comments

Comments
 (0)