|
16 | 16 | #include <cstdio> |
17 | 17 |
|
18 | 18 | #include <algorithm> |
| 19 | +#include <numeric> |
19 | 20 |
|
20 | 21 | #include <faiss/utils/Heap.h> |
21 | 22 | #include <faiss/utils/distances.h> |
@@ -565,6 +566,7 @@ struct QueryTables { |
565 | 566 | } |
566 | 567 | } |
567 | 568 |
|
| 569 | + |
568 | 570 | /***************************************************** |
569 | 571 | * When inverted list is known: prepare computations |
570 | 572 | *****************************************************/ |
@@ -748,6 +750,33 @@ struct QueryTables { |
748 | 750 |
|
749 | 751 | return dis0; |
750 | 752 | } |
| 753 | + |
| 754 | + |
| 755 | + void init_sim_table(const float* qi, const float* table) { |
| 756 | + this->qi = qi; |
| 757 | + |
| 758 | + if (metric_type == METRIC_INNER_PRODUCT) { |
| 759 | + memcpy(sim_table, table, pq.ksub * pq.M * sizeof(float)); |
| 760 | + } else { |
| 761 | + if (!by_residual) { |
| 762 | + memcpy(sim_table, table, pq.ksub * pq.M * sizeof(float)); |
| 763 | + } else { |
| 764 | + memcpy(sim_table_2, table, pq.ksub * pq.M * sizeof(float)); |
| 765 | + } |
| 766 | + } |
| 767 | + } |
| 768 | + |
| 769 | + void copy_sim_table(float* table) const { |
| 770 | + if (metric_type == METRIC_INNER_PRODUCT) { |
| 771 | + memcpy(table, sim_table, pq.ksub * pq.M * sizeof(float)); |
| 772 | + } else { |
| 773 | + if (!by_residual) { |
| 774 | + memcpy(table, sim_table, pq.ksub * pq.M * sizeof(float)); |
| 775 | + } else { |
| 776 | + memcpy(table, sim_table_2, pq.ksub * pq.M * sizeof(float)); |
| 777 | + } |
| 778 | + } |
| 779 | + } |
751 | 780 | }; |
752 | 781 |
|
753 | 782 | // This way of handling the selector is not optimal since all distances |
@@ -1387,4 +1416,76 @@ size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const { |
1387 | 1416 | return ngroup; |
1388 | 1417 | } |
1389 | 1418 |
|
| 1419 | +void IndexIVFPQ::compute_distance_to_codes_for_list( |
| 1420 | + const idx_t list_no, |
| 1421 | + const float* x, |
| 1422 | + idx_t n, |
| 1423 | + const uint8_t* codes, |
| 1424 | + float* dists, |
| 1425 | + float* dist_table) const { |
| 1426 | + |
| 1427 | + std::unique_ptr<InvertedListScanner> scanner( |
| 1428 | + get_InvertedListScanner(true, nullptr)); |
| 1429 | + |
| 1430 | + |
| 1431 | + if (dist_table) { |
| 1432 | + if (auto* pqscanner = dynamic_cast<QueryTables*>(scanner.get())) { |
| 1433 | + pqscanner->init_sim_table(x, dist_table); |
| 1434 | + } |
| 1435 | + } else { |
| 1436 | + scanner->set_query(x); |
| 1437 | + } |
| 1438 | + |
| 1439 | + // Initialize distances with default values |
| 1440 | + std::vector<float> dist_out(n, metric_type == METRIC_L2 ? HUGE_VAL : -HUGE_VAL); |
| 1441 | + |
| 1442 | + |
| 1443 | + //find the centroid corresponding to the input list_no |
| 1444 | + //and compute its distance from the query vector |
| 1445 | + std::vector<float> centroid(d); |
| 1446 | + quantizer->reconstruct(list_no, centroid.data()); |
| 1447 | + |
| 1448 | + float coarse_dis = quantizer->metric_type == faiss::METRIC_L2 |
| 1449 | + ? faiss::fvec_L2sqr(x, centroid.data(), d) |
| 1450 | + : faiss::fvec_inner_product(x, centroid.data(), d); |
| 1451 | + |
| 1452 | + |
| 1453 | + scanner->set_list(list_no, coarse_dis); |
| 1454 | + |
| 1455 | + // Initialize ids_in as sequential numbers to allow mapping with output distances. |
| 1456 | + std::vector<idx_t> ids_in(n); |
| 1457 | + std::iota(ids_in.begin(), ids_in.end(), 0); |
| 1458 | + |
| 1459 | + //ids_out contain the order of distances in dist_out after scan_codes returns. |
| 1460 | + std::vector<idx_t> ids_out(n, 0); |
| 1461 | + |
| 1462 | + scanner->scan_codes(n, codes, ids_in.data(), dist_out.data(), ids_out.data(), n); |
| 1463 | + |
| 1464 | + // Reorder the returned distances in dist_out based on ids_out. |
| 1465 | + // This function needs to return the distances in the same order as input codes. |
| 1466 | + for (int j = 0; j < n; j++) { |
| 1467 | + int k = ids_out[j]; |
| 1468 | + dists[k] = dist_out[j]; |
| 1469 | + } |
| 1470 | + |
| 1471 | + return; |
| 1472 | +} |
| 1473 | + |
| 1474 | +//This function computes the distance table for the input vector x and returns it in dtable. |
| 1475 | +void IndexIVFPQ::compute_distance_table( |
| 1476 | + const float* x, |
| 1477 | + float* dist_table) const { |
| 1478 | + |
| 1479 | + std::unique_ptr<InvertedListScanner> scanner( |
| 1480 | + get_InvertedListScanner(true, nullptr)); |
| 1481 | + |
| 1482 | + scanner->set_query(x); |
| 1483 | + |
| 1484 | + if (auto* pqscanner = dynamic_cast<QueryTables*>(scanner.get())) { |
| 1485 | + pqscanner->copy_sim_table(dist_table); |
| 1486 | + } |
| 1487 | + |
| 1488 | + return; |
| 1489 | +} |
| 1490 | + |
1390 | 1491 | } // namespace faiss |
0 commit comments