Skip to content

Commit 0be294a

Browse files
authored
Implement compute_distance_to_codes_for_list and compute_distance_table for IndexIVFPQ (#50)
MB-65606 Implement compute_distance_to_codes_for_list for IndexIVFPQ * Implement compute_distance_to_codes_for_list for IndexIVFPQ. Given a query vector x, this function computes distance to provided codes for the input list_no. * Internally the implementation utilizes inverted list scanner to perform the actual computations. MB-65606 Implement compute_distance_table for IndexIVFPQ * Add compute_distance_table to IndexIVFPQ which computes the distance table for a given query vector and returns to the caller. * This allows caller to reuse the distance table for subsequent calls to the function compute_distance_to_codes_for_list. * compute_distance_to_codes_for_list now accepts a precomputed distance table if specified in the input.
1 parent 352484e commit 0be294a

7 files changed

Lines changed: 164 additions & 6 deletions

File tree

c_api/IndexIVF_c_ex.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,23 @@ int faiss_IndexIVF_compute_distance_to_codes_for_list(
8888
const float* x,
8989
idx_t n,
9090
const uint8_t* codes,
91-
float* dists) {
91+
float* dists,
92+
float* dist_table) {
9293
try {
9394
reinterpret_cast<IndexIVF*>(index)->compute_distance_to_codes_for_list(
94-
list_no, x, n, codes, dists);
95+
list_no, x, n, codes, dists, dist_table);
96+
return 0;
97+
}
98+
CATCH_AND_HANDLE
99+
}
100+
101+
int faiss_IndexIVF_compute_distance_table(
102+
FaissIndexIVF* index,
103+
const float* x,
104+
float* dist_table) {
105+
try {
106+
reinterpret_cast<IndexIVF*>(index)->compute_distance_table(
107+
x, dist_table);
95108
return 0;
96109
}
97110
CATCH_AND_HANDLE

c_api/IndexIVF_c_ex.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ int faiss_IndexIVF_search_preassigned_with_params(
8181
@param n - number of codes
8282
@param codes - input codes
8383
@param dists - output computed distances
84+
@param dist_table - input precomputed distance table for PQ
8485
*/
8586

8687
int faiss_IndexIVF_compute_distance_to_codes_for_list(
@@ -89,7 +90,8 @@ int faiss_IndexIVF_compute_distance_to_codes_for_list(
8990
const float* x,
9091
idx_t n,
9192
const uint8_t* codes,
92-
float* dists);
93+
float* dists,
94+
float* dist_table);
9395

9496
/*
9597
Given multiple vector IDs, retrieve the corresponding list (cluster) IDs
@@ -108,6 +110,20 @@ int faiss_get_lists_for_keys(
108110
size_t n_keys,
109111
idx_t* lists);
110112

113+
/*
114+
Given a query vector x, compute distance table and
115+
return to the caller.
116+
117+
@param x - input query vector
118+
@param dist_table - output precomputed distance table for PQ
119+
120+
*/
121+
122+
int faiss_IndexIVF_compute_distance_table(
123+
FaissIndexIVF* index,
124+
const float* x,
125+
float* dist_table);
126+
111127
#ifdef __cplusplus
112128
}
113129
#endif

faiss/IndexIVF.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,14 +459,28 @@ struct IndexIVF : Index, IndexIVFInterface {
459459
* @param n - number of codes
460460
* @param codes - input codes
461461
* @param dists - output computed distances
462+
* @param dist_table - input precomputed distance table for PQ
462463
*/
463464

464465
virtual void compute_distance_to_codes_for_list(
465466
const idx_t list_no,
466467
const float* x,
467468
idx_t n,
468469
const uint8_t* codes,
469-
float* dists) const {};
470+
float* dists,
471+
float* dist_table) const {};
472+
473+
/** Given a query vector x, compute distance table and
474+
* return to the caller.
475+
*
476+
* @param x - input query vector
477+
* @param dist_table - output precomputed distance table for PQ
478+
*
479+
*/
480+
481+
virtual void compute_distance_table(
482+
const float* x,
483+
float* dist_table) const {};
470484

471485

472486
IndexIVF();

faiss/IndexIVFPQ.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <cstdio>
1717

1818
#include <algorithm>
19+
#include <numeric>
1920

2021
#include <faiss/utils/Heap.h>
2122
#include <faiss/utils/distances.h>
@@ -565,6 +566,7 @@ struct QueryTables {
565566
}
566567
}
567568

569+
568570
/*****************************************************
569571
* When inverted list is known: prepare computations
570572
*****************************************************/
@@ -748,6 +750,33 @@ struct QueryTables {
748750

749751
return dis0;
750752
}
753+
754+
755+
void init_sim_table(const float* qi, const float* table) {
756+
this->qi = qi;
757+
758+
if (metric_type == METRIC_INNER_PRODUCT) {
759+
memcpy(sim_table, table, pq.ksub * pq.M * sizeof(float));
760+
} else {
761+
if (!by_residual) {
762+
memcpy(sim_table, table, pq.ksub * pq.M * sizeof(float));
763+
} else {
764+
memcpy(sim_table_2, table, pq.ksub * pq.M * sizeof(float));
765+
}
766+
}
767+
}
768+
769+
void copy_sim_table(float* table) const {
770+
if (metric_type == METRIC_INNER_PRODUCT) {
771+
memcpy(table, sim_table, pq.ksub * pq.M * sizeof(float));
772+
} else {
773+
if (!by_residual) {
774+
memcpy(table, sim_table, pq.ksub * pq.M * sizeof(float));
775+
} else {
776+
memcpy(table, sim_table_2, pq.ksub * pq.M * sizeof(float));
777+
}
778+
}
779+
}
751780
};
752781

753782
// This way of handling the selector is not optimal since all distances
@@ -1387,4 +1416,76 @@ size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const {
13871416
return ngroup;
13881417
}
13891418

1419+
void IndexIVFPQ::compute_distance_to_codes_for_list(
1420+
const idx_t list_no,
1421+
const float* x,
1422+
idx_t n,
1423+
const uint8_t* codes,
1424+
float* dists,
1425+
float* dist_table) const {
1426+
1427+
std::unique_ptr<InvertedListScanner> scanner(
1428+
get_InvertedListScanner(true, nullptr));
1429+
1430+
1431+
if (dist_table) {
1432+
if (auto* pqscanner = dynamic_cast<QueryTables*>(scanner.get())) {
1433+
pqscanner->init_sim_table(x, dist_table);
1434+
}
1435+
} else {
1436+
scanner->set_query(x);
1437+
}
1438+
1439+
// Initialize distances with default values
1440+
std::vector<float> dist_out(n, metric_type == METRIC_L2 ? HUGE_VAL : -HUGE_VAL);
1441+
1442+
1443+
//find the centroid corresponding to the input list_no
1444+
//and compute its distance from the query vector
1445+
std::vector<float> centroid(d);
1446+
quantizer->reconstruct(list_no, centroid.data());
1447+
1448+
float coarse_dis = quantizer->metric_type == faiss::METRIC_L2
1449+
? faiss::fvec_L2sqr(x, centroid.data(), d)
1450+
: faiss::fvec_inner_product(x, centroid.data(), d);
1451+
1452+
1453+
scanner->set_list(list_no, coarse_dis);
1454+
1455+
// Initialize ids_in as sequential numbers to allow mapping with output distances.
1456+
std::vector<idx_t> ids_in(n);
1457+
std::iota(ids_in.begin(), ids_in.end(), 0);
1458+
1459+
//ids_out contain the order of distances in dist_out after scan_codes returns.
1460+
std::vector<idx_t> ids_out(n, 0);
1461+
1462+
scanner->scan_codes(n, codes, ids_in.data(), dist_out.data(), ids_out.data(), n);
1463+
1464+
// Reorder the returned distances in dist_out based on ids_out.
1465+
// This function needs to return the distances in the same order as input codes.
1466+
for (int j = 0; j < n; j++) {
1467+
int k = ids_out[j];
1468+
dists[k] = dist_out[j];
1469+
}
1470+
1471+
return;
1472+
}
1473+
1474+
//This function computes the distance table for the input vector x and returns it in dtable.
1475+
void IndexIVFPQ::compute_distance_table(
1476+
const float* x,
1477+
float* dist_table) const {
1478+
1479+
std::unique_ptr<InvertedListScanner> scanner(
1480+
get_InvertedListScanner(true, nullptr));
1481+
1482+
scanner->set_query(x);
1483+
1484+
if (auto* pqscanner = dynamic_cast<QueryTables*>(scanner.get())) {
1485+
pqscanner->copy_sim_table(dist_table);
1486+
}
1487+
1488+
return;
1489+
}
1490+
13901491
} // namespace faiss

faiss/IndexIVFPQ.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ struct IndexIVFPQ : IndexIVF {
139139
/// build precomputed table
140140
void precompute_table();
141141

142+
void compute_distance_to_codes_for_list(
143+
const idx_t list_no,
144+
const float* x,
145+
idx_t n,
146+
const uint8_t* codes,
147+
float* dists,
148+
float* dist_table) const override;
149+
150+
void compute_distance_table(
151+
const float* x,
152+
float* dist_table) const override;
153+
142154
IndexIVFPQ();
143155
};
144156

faiss/IndexScalarQuantizer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,8 @@ void IndexIVFScalarQuantizer::compute_distance_to_codes_for_list(
287287
const float* x,
288288
idx_t n,
289289
const uint8_t* codes,
290-
float* dists) const {
290+
float* dists,
291+
float* dist_table) const {
291292

292293
std::unique_ptr<ScalarQuantizer::SQDistanceComputer> dc(
293294
sq.get_distance_computer(metric_type));

faiss/IndexScalarQuantizer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
110110
const float* x,
111111
idx_t n,
112112
const uint8_t* codes,
113-
float* dists) const override;
113+
float* dists,
114+
float* dist_table) const override;
114115

115116
};
116117

0 commit comments

Comments
 (0)