Skip to content

Commit 40407c6

Browse files
meiravgrigithub-actions[bot]
authored andcommitted
MOD-14470 Add VecSimParams_GetQueryBlobSize API for safe query vector allocatio (#915)
* api for VecSimParams_GetQueryBlobSize * fix dtor * move index-> null to the test start * format (cherry picked from commit 8d37cf1)
1 parent 8e80775 commit 40407c6

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

src/VecSim/vec_sim.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,18 @@ extern "C" void VecSim_Normalize(void *blob, size_t dim, VecSimType type) {
207207
}
208208
}
209209

210+
extern "C" size_t VecSimParams_GetQueryBlobSize(VecSimType type, size_t dim, VecSimMetric metric) {
211+
// Assert all supported types are covered
212+
assert(type == VecSimType_FLOAT32 || type == VecSimType_FLOAT64 ||
213+
type == VecSimType_BFLOAT16 || type == VecSimType_FLOAT16 || type == VecSimType_INT8 ||
214+
type == VecSimType_UINT8);
215+
size_t blobSize = VecSimType_sizeof(type) * dim;
216+
if (metric == VecSimMetric_Cosine && (type == VecSimType_INT8 || type == VecSimType_UINT8)) {
217+
blobSize += sizeof(float); // For the norm
218+
}
219+
return blobSize;
220+
}
221+
210222
extern "C" size_t VecSimIndex_IndexSize(VecSimIndex *index) { return index->indexSize(); }
211223

212224
extern "C" VecSimResolveCode VecSimIndex_ResolveParams(VecSimIndex *index, VecSimRawParam *rparams,

src/VecSim/vec_sim.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,19 @@ double VecSimIndex_GetDistanceFrom_Unsafe(VecSimIndex *index, size_t label, cons
9898
*/
9999
void VecSim_Normalize(void *blob, size_t dim, VecSimType type);
100100

101+
/**
102+
* @brief Returns the required blob size for a query vector that will be normalized.
103+
*
104+
* For INT8/UINT8 vectors with Cosine metric, VecSim_Normalize appends the norm (a float)
105+
* at the end of the blob, so the required size is larger than just dim * sizeof(type).
106+
*
107+
* @param type vector element type.
108+
* @param dim vector dimension.
109+
* @param metric distance metric.
110+
* @return required blob size in bytes.
111+
*/
112+
size_t VecSimParams_GetQueryBlobSize(VecSimType type, size_t dim, VecSimMetric metric);
113+
101114
/**
102115
* @brief Return the number of vectors in the index.
103116
* @param index the index whose size is requested.

tests/unit/test_common.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,11 @@ class CommonTypeMetricTests : public testing::TestWithParam<std::tuple<VecSimTyp
836836
template <typename algo_params>
837837
void test_initial_size_estimation();
838838

839-
virtual void TearDown() { VecSimIndex_Free(index); }
839+
virtual void TearDown() {
840+
if (index) {
841+
VecSimIndex_Free(index);
842+
}
843+
}
840844

841845
VecSimIndex *index;
842846
};
@@ -880,6 +884,26 @@ TEST_P(CommonTypeMetricTests, TestInitialSizeEstimationHNSW) {
880884
this->test_initial_size_estimation<HNSWParams>();
881885
}
882886

887+
TEST_P(CommonTypeMetricTests, TestGetQueryBlobSize) {
888+
// We don't need to create an index for this test, set to nullptr to avoid cleanup issues
889+
this->index = nullptr;
890+
891+
size_t dim = 4;
892+
VecSimType type = std::get<0>(GetParam());
893+
VecSimMetric metric = std::get<1>(GetParam());
894+
895+
// Call the API function
896+
size_t actual = VecSimParams_GetQueryBlobSize(type, dim, metric);
897+
898+
// Calculate expected blob size
899+
size_t expected = dim * VecSimType_sizeof(type);
900+
if (metric == VecSimMetric_Cosine && (type == VecSimType_INT8 || type == VecSimType_UINT8)) {
901+
expected += sizeof(float); // For the norm
902+
}
903+
904+
ASSERT_EQ(actual, expected);
905+
}
906+
883907
class CommonTypeMetricTieredTests : public CommonTypeMetricTests {
884908
protected:
885909
virtual void TearDown() override {}

0 commit comments

Comments
 (0)