Skip to content

Commit a1f05ea

Browse files
Add RemappedRandomAccessVectorValues to fix BuildScoreProvider::randomAccessScoreProvider (#555)
1 parent 32ba41b commit a1f05ea

3 files changed

Lines changed: 92 additions & 9 deletions

File tree

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.graph;
18+
19+
import io.github.jbellis.jvector.vector.types.VectorFloat;
20+
21+
import java.util.Arrays;
22+
23+
public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues {
24+
private final RandomAccessVectorValues ravv;
25+
private final int[] graphToRavvOrdMap;
26+
27+
/**
28+
* Remaps a RAVV to a different set of ordinals. This is useful when the ordinals used by the graph
29+
* do not match the ordinals used by the RAVV.
30+
*
31+
* @param ravv the RAVV to remap
32+
* @param graphToRavvOrdMap a mapping from the graph's ordinals to the RAVV's ordinals where
33+
* graphToRavvOrdMap[i] is the RAVV ordinal corresponding to graph ordinal i.
34+
*/
35+
public RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap) {
36+
this.ravv = ravv;
37+
this.graphToRavvOrdMap = graphToRavvOrdMap;
38+
}
39+
40+
@Override
41+
public int size() {
42+
return graphToRavvOrdMap.length;
43+
}
44+
45+
@Override
46+
public int dimension() {
47+
return ravv.dimension();
48+
}
49+
50+
@Override
51+
public VectorFloat<?> getVector(int node) {
52+
return ravv.getVector(graphToRavvOrdMap[node]);
53+
}
54+
55+
@Override
56+
public boolean isValueShared() {
57+
return ravv.isValueShared();
58+
}
59+
60+
@Override
61+
public RandomAccessVectorValues copy() {
62+
return new RemappedRandomAccessVectorValues(ravv.copy(), Arrays.copyOf(graphToRavvOrdMap, graphToRavvOrdMap.length));
63+
}
64+
65+
@Override
66+
public void getVectorInto(int node, VectorFloat<?> result, int offset) {
67+
ravv.getVectorInto(graphToRavvOrdMap[node], result, offset);
68+
}
69+
}

jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package io.github.jbellis.jvector.graph.similarity;
1818

1919
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
20+
import io.github.jbellis.jvector.graph.RemappedRandomAccessVectorValues;
2021
import io.github.jbellis.jvector.quantization.BQVectors;
2122
import io.github.jbellis.jvector.quantization.PQVectors;
2223
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
@@ -25,8 +26,6 @@
2526
import io.github.jbellis.jvector.vector.types.VectorFloat;
2627
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
2728

28-
import java.util.stream.IntStream;
29-
3029
/**
3130
* Encapsulates comparing node distances for GraphIndexBuilder.
3231
*/
@@ -88,15 +87,15 @@ public interface BuildScoreProvider {
8887
*
8988
* Helper method for the case where graph node IDs differ from ravv ordinals: graphToRavvOrdMap[i] is the ravv ordinal for graph node i.
9089
*/
91-
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) {
92-
return randomAccessScoreProvider(ravv, IntStream.range(0, ravv.size()).toArray(), similarityFunction);
90+
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) {
91+
return randomAccessScoreProvider(new RemappedRandomAccessVectorValues(ravv, graphToRavvOrdMap), similarityFunction);
9392
}
9493

9594
/**
9695
* Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction.
9796
* Graph node IDs are used directly as ravv ordinals; wrap the ravv in a RemappedRandomAccessVectorValues first if they differ.
9897
*/
99-
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) {
98+
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) {
10099
// We need two sources of vectors in order to perform diversity check comparisons without
101100
// colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared.
102101
var vectors = ravv.threadLocalSupplier();
@@ -125,22 +124,22 @@ public VectorFloat<?> approximateCentroid() {
125124
@Override
126125
public SearchScoreProvider searchProviderFor(VectorFloat<?> vector) {
127126
var vc = vectorsCopy.get();
128-
return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc);
127+
return DefaultSearchScoreProvider.exact(vector, similarityFunction, vc);
129128
}
130129

131130
@Override
132131
public SearchScoreProvider searchProviderFor(int node1) {
133132
RandomAccessVectorValues randomAccessVectorValues = vectors.get();
134-
var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
133+
var v = randomAccessVectorValues.getVector(node1);
135134
return searchProviderFor(v);
136135
}
137136

138137
@Override
139138
public SearchScoreProvider diversityProviderFor(int node1) {
140139
RandomAccessVectorValues randomAccessVectorValues = vectors.get();
141-
var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
140+
var v = randomAccessVectorValues.getVector(node1);
142141
var vc = vectorsCopy.get();
143-
return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc);
142+
return DefaultSearchScoreProvider.exact(v, similarityFunction, vc);
144143
}
145144
};
146145
}

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,19 @@ public void testSaveAndLoad() throws IOException {
156156
}
157157
assertGraphEquals(graph, builder.graph);
158158
}
159+
160+
// A RandomAccessVectorValues may be backed by a vector source that keeps growing after the builder is
161+
// created (here, a mutable list), so we need to ensure that GraphIndexBuilder handles vectors added one at a time.
162+
@Test
163+
public void testAddNodesToVectorValuesIteratively() throws IOException {
164+
int dimension = randomIntBetween(2, 32);
165+
var mutableVectors = new ArrayList<VectorFloat<?>>();
166+
RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(mutableVectors, dimension);
167+
try (var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f, true)) {
168+
for (int i = 0; i < 10; i++) {
169+
mutableVectors.add(TestUtil.randomVector(random(), dimension));
170+
builder.addGraphNode(i, ravv.getVector(i));
171+
}
172+
}
173+
}
159174
}

0 commit comments

Comments
 (0)