Skip to content

Commit 32ba41b

Browse files
Add getDimension to ImmutableGraphIndex and all implementations
Add getDimension to ImmutableGraphIndex and all implementations
1 parent d95bc61 commit 32ba41b

6 files changed

Lines changed: 33 additions & 6 deletions

File tree

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
325325
this.simdExecutor = simdExecutor;
326326
this.parallelExecutor = parallelExecutor;
327327

328-
this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha));
328+
this.graph = new OnHeapGraphIndex(maxDegrees, dimension, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha));
329329

330330
this.searchers = ExplicitThreadLocal.withInitial(() -> {
331331
var gs = new GraphSearcher(graph);
@@ -1001,7 +1001,7 @@ public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in,
10011001

10021002
var diversityProvider = new VamanaDiversityProvider(buildScoreProvider, alpha);
10031003

1004-
try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, overflowRatio, diversityProvider);) {
1004+
try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, newVectors.dimension(), overflowRatio, diversityProvider);) {
10051005

10061006
GraphIndexBuilder builder = new GraphIndexBuilder(
10071007
buildScoreProvider,

jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ default int size() {
8080

8181
List<Integer> maxDegrees();
8282

83+
/**
84+
* @return the dimension of the vectors in the graph
85+
*/
86+
int getDimension();
87+
8388
/**
8489
* @return the first ordinal greater than all node ids in the graph. Equal to size() in simple cases;
8590
* May be different from size() if nodes are being added concurrently, or if nodes have been

jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,17 @@ public class OnHeapGraphIndex implements MutableGraphIndex {
7373

7474
// Maximum number of neighbors (edges) per node per layer
7575
final List<Integer> maxDegrees;
76+
private final int dimension;
7677
// The ratio by which we can overflow the neighborhood of a node during construction. Since it is a multiplicative
7778
// ratio, i.e., the maximum allowable degree if maxDegree * overflowRatio, it should be higher than 1.
7879
private final double overflowRatio;
7980

8081
private volatile boolean allMutationsCompleted = false;
8182

82-
OnHeapGraphIndex(List<Integer> maxDegrees, double overflowRatio, DiversityProvider diversityProvider) {
83+
OnHeapGraphIndex(List<Integer> maxDegrees, int dimension, double overflowRatio, DiversityProvider diversityProvider) {
8384
this.overflowRatio = overflowRatio;
8485
this.maxDegrees = new IntArrayList();
86+
this.dimension = dimension;
8587
setDegrees(maxDegrees);
8688
entryPoint = new AtomicReference<>();
8789
this.completions = new CompletionTracker(1024);
@@ -369,6 +371,11 @@ public void setDegrees(List<Integer> layerDegrees) {
369371
maxDegrees.addAll(layerDegrees);
370372
}
371373

374+
@Override
375+
public int getDimension() {
376+
return dimension;
377+
}
378+
372379
@Override
373380
public void setAllMutationsCompleted() {
374381
allMutationsCompleted = true;
@@ -541,7 +548,7 @@ public void save(DataOutput out) throws IOException {
541548
*/
542549
@Experimental
543550
@Deprecated
544-
public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, DiversityProvider diversityProvider) throws IOException {
551+
public static OnHeapGraphIndex load(RandomAccessReader in, int dimension, double overflowRatio, DiversityProvider diversityProvider) throws IOException {
545552
int magic = in.readInt(); // the magic number
546553
if (magic != OnHeapGraphIndex.MAGIC) {
547554
throw new IOException("Unsupported magic number: " + magic);
@@ -561,7 +568,7 @@ public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio,
561568

562569
int entryNode = in.readInt();
563570

564-
var graph = new OnHeapGraphIndex(layerDegrees, overflowRatio, diversityProvider);
571+
var graph = new OnHeapGraphIndex(layerDegrees, dimension, overflowRatio, diversityProvider);
565572

566573
Map<Integer, Integer> nodeLevelMap = new HashMap<>();
567574

jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ public Set<FeatureId> getFeatureSet() {
227227
return features.keySet();
228228
}
229229

230+
@Override
231+
public int getDimension() {
232+
return dimension;
233+
}
234+
230235
@Override
231236
public int size(int level) {
232237
return layerInfo.get(level).size;

jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,11 @@ public List<Integer> maxDegrees() {
264264
throw new NotImplementedException();
265265
}
266266

267+
@Override
268+
public int getDimension() {
269+
throw new NotImplementedException();
270+
}
271+
267272
@Override
268273
public int getIdUpperBound() {
269274
return ImmutableGraphIndex.super.getIdUpperBound();
@@ -424,6 +429,11 @@ public List<Integer> maxDegrees() {
424429
throw new NotImplementedException();
425430
}
426431

432+
@Override
433+
public int getDimension() {
434+
throw new NotImplementedException();
435+
}
436+
427437
@Override
428438
public int getIdUpperBound() {
429439
return ImmutableGraphIndex.super.getIdUpperBound();

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException {
151151
log.info("Reading on-heap graph from {}", heapGraphOutputPath);
152152
MutableGraphIndex reconstructedOnHeapGraphIndex;
153153
try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) {
154-
reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(baseBuildScoreProvider, ALPHA));
154+
reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), baseVectorsRavv.dimension(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(baseBuildScoreProvider, ALPHA));
155155
}
156156

157157
try (var readerSupplier = new SimpleMappedReader.Supplier(graphOutputPath.toAbsolutePath());

0 commit comments

Comments
 (0)