-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgenerate_ground_truth.py
More file actions
152 lines (140 loc) · 4.91 KB
/
generate_ground_truth.py
File metadata and controls
152 lines (140 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Generate ground truth ivecs file."""
import argparse
import logging
import sys
from pathlib import Path
import numpy as np
import svs
from . import consts
from . import utils
# Module-level logger; main() passes it to utils.configure_logger, which
# returns the path of the log file it will write to.
# NOTE(review): the logger is named after __file__ (a filesystem path) rather
# than the conventional __name__ — confirm the path-based name is intended.
logger = logging.getLogger(__file__)
def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line arguments of this script.

    Args:
        argv: Arguments to parse (without the program name). ``None`` lets
            argparse fall back to ``sys.argv[1:]``.

    Returns:
        The parsed namespace. Besides the options declared here it contains
        whatever ``utils.add_common_arguments`` registers on the parser.
    """
    # Use this module's docstring for --help. The previous
    # ``__file__.__doc__`` is the docstring of the built-in ``str`` type
    # (``__file__`` is a str), not of this script.
    parser = argparse.ArgumentParser(description=__doc__)
    utils.add_common_arguments(parser)
    parser.add_argument("--vecs_file", help="Vectors *vecs file", type=Path)
    parser.add_argument("--query_file", help="Query vectors file", type=Path)
    parser.add_argument("--out_file", help="Output file", type=Path)
    parser.add_argument(
        "--query_out_file",
        help="Output file for query vectors generated when num_queries given",
        type=Path,
    )
    parser.add_argument(
        "--distance",
        help="Distance",
        choices=tuple(consts.STR_TO_DISTANCE.keys()),
        default="mip",
    )
    parser.add_argument(
        "-k", help="Number of neighbors", type=int, default=100
    )
    parser.add_argument("--num_vectors", help="Number of vectors", type=int)
    parser.add_argument(
        "--num_query_vectors",
        help="Number of query vectors."
        " If given, query vectors will be shuffled."
        " If more than in the query file, the query vectors will be shuffled"
        " and repeated as needed.",
        type=int,
    )
    parser.add_argument(
        "--shuffle", help="Shuffle order of vectors", action="store_true"
    )
    return parser.parse_args(argv)
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: parse arguments, set up logging, run.

    Args:
        argv: Argument list to parse (without the program name). ``None``
            means argparse parses ``sys.argv[1:]``.
    """
    args = _read_args(argv)
    # NOTE(review): log_dir, out_dir, max_threads and seed are not declared
    # in _read_args itself; presumably utils.add_common_arguments adds them.
    log_file = utils.configure_logger(
        logger, args.log_dir if args.log_dir is not None else args.out_dir
    )
    print("Logging to", log_file, sep="\n")
    # Log what was actually parsed: an explicit argv list (even an empty one)
    # takes precedence over the process command line.
    logger.info({"argv": argv if argv is not None else sys.argv})
    generate_ground_truth(
        vecs_path=args.vecs_file,
        query_file=args.query_file,
        distance=consts.STR_TO_DISTANCE[args.distance],
        num_vectors=args.num_vectors,
        k=args.k,
        num_threads=args.max_threads,
        out_file=args.out_file,
        query_out_path=args.query_out_file,
        shuffle=args.shuffle,
        seed=args.seed,
        num_query_vectors=args.num_query_vectors,
    )
def _resample_indices(
    num_available: int, num_samples: int, seed: int
) -> np.ndarray:
    """Return ``num_samples`` row indices drawn from ``range(num_available)``.

    Indices come from successive random permutations of
    ``range(num_available)``. The first pass contributes up to
    ``num_available`` distinct indices. Each further pass (a fresh
    permutation) is appended until ``num_samples`` indices exist. The
    permutations come from ``np.random.default_rng(seed)`` in a fixed order,
    so the same arguments always give the same indices.

    Args:
        num_available: Number of rows to sample from.
        num_samples: Number of indices to return.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        1-D integer array of length ``num_samples``.

    Raises:
        ValueError: If ``num_samples > 0`` but ``num_available <= 0``. The
            loop below could never make progress in that case.
    """
    if num_samples > 0 and num_available <= 0:
        raise ValueError("cannot sample indices from an empty set")
    rng = np.random.default_rng(seed)
    order = np.empty(num_samples, dtype=np.intp)
    cursor = 0
    while cursor < num_samples:
        permutation = rng.permutation(num_available)
        batch_size = min(num_samples - cursor, num_available)
        order[cursor : cursor + batch_size] = permutation[:batch_size]
        cursor += batch_size
    return order


def generate_ground_truth(
    *,
    vecs_path: Path,
    query_file: Path,
    distance: svs.DistanceType,
    num_vectors: int | None = None,
    k: int = 100,
    num_threads: int = 1,
    out_file: Path | None = None,
    query_out_path: Path | None = None,
    shuffle: bool = False,
    seed: int = 42,
    num_query_vectors: int | None = None,
) -> None:
    """Compute exact k-nearest-neighbor ids with ``svs.Flat`` and save them.

    Args:
        vecs_path: Dataset vectors file, read with ``svs.read_vecs``.
        query_file: Query vectors file, read with ``svs.read_vecs``.
        distance: Distance passed to ``svs.Flat``.
        num_vectors: Keep only the first ``num_vectors`` dataset vectors.
            ``None``, or a value larger than the dataset, keeps all of them.
        k: Number of neighbors returned per query.
        num_threads: Thread count passed to ``svs.Flat``.
        out_file: Output path. Must end in ``.ivecs``. If ``None``, it is
            derived with ``utils.ground_truth_path``.
        query_out_path: Where resampled queries are written when
            ``num_query_vectors`` is given. Must share ``query_file``'s
            suffix. If ``None``, it defaults to
            ``<query stem>-<num_query_vectors>_<seed><suffix>`` next to
            ``query_file``.
        shuffle: Permute the (truncated) dataset vectors in memory with
            ``seed`` before indexing. The permuted order is not written to
            disk. The ground-truth ids refer to rows of the array given to
            ``svs.Flat`` (presumably positional ids; confirm against svs).
        seed: Seed for both the dataset shuffle and query resampling.
        num_query_vectors: If given, emit this many queries, drawn from
            repeated random permutations of the query file, together with
            the matching ground-truth rows.

    Raises:
        SystemExit: On invalid output suffixes, a negative
            ``num_query_vectors``, or resampling from an empty query file.
    """
    if out_file is not None and out_file.suffix != ".ivecs":
        raise SystemExit("Error: --out_file must end in .ivecs")
    if (
        query_out_path is not None
        and query_out_path.suffix != query_file.suffix
    ):
        # Name the CLI flag the user actually passes (--query_out_file).
        raise SystemExit(
            "Error: --query_out_file must have the same suffix as --query_file"
        )
    if num_query_vectors is not None and num_query_vectors < 0:
        raise SystemExit("Error: --num_query_vectors must be non-negative")

    queries = svs.read_vecs(str(query_file))
    # Fail before the expensive search instead of looping forever later.
    if (
        num_query_vectors is not None
        and num_query_vectors > 0
        and len(queries) == 0
    ):
        raise SystemExit(
            "Error: --num_query_vectors given but --query_file has no vectors"
        )
    vectors = svs.read_vecs(str(vecs_path))
    # If num_vectors is None or larger than the number of vectors,
    # slicing will return the whole array.
    vectors = vectors[:num_vectors]
    if shuffle:
        np.random.default_rng(seed).shuffle(vectors)
    index = svs.Flat(vectors, distance=distance, num_threads=num_threads)
    idxs, _ = index.search(queries, k)

    if num_query_vectors is not None:
        # Same permutation sequence as before for a given seed, so previously
        # generated query/ground-truth files are reproduced exactly.
        order = _resample_indices(len(queries), num_query_vectors, seed)
        queries_all = queries[order]
        ground_truth_all = idxs[order]
        if query_out_path is None:
            query_out_path = (
                query_file.parent
                / f"{query_file.stem}-{num_query_vectors}_{seed}"
                f"{query_file.suffix}"
            )
        svs.write_vecs(queries_all, str(query_out_path))
        queries_path = query_out_path
    else:
        queries_path = query_file
        ground_truth_all = idxs

    if out_file is None:
        out_file = utils.ground_truth_path(
            vecs_path,
            queries_path,
            distance,
            num_vectors,
            seed if shuffle else None,
        )
    svs.write_vecs(ground_truth_all.astype(np.uint32), str(out_file))
    logger.info({"ground_truth_saved": out_file})
# Script entry point. Because this file uses relative imports
# (``from . import consts`` / ``utils``), it must be run as a module, e.g.
# ``python -m <package>.generate_ground_truth`` — package name to confirm.
if __name__ == "__main__":
    main()