diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/DenseIntMapConcurrentBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/DenseIntMapConcurrentBenchmark.java
new file mode 100644
index 000000000..bae99d237
--- /dev/null
+++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/DenseIntMapConcurrentBenchmark.java
@@ -0,0 +1,200 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package io.github.jbellis.jvector.bench;
+
+import io.github.jbellis.jvector.util.DenseIntMap;
+import io.github.jbellis.jvector.util.IntMap;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Group;
+import org.openjdk.jmh.annotations.GroupThreads;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Threads;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntFunction;
+
+/**
+ * Measures the throughput of concurrent {@link DenseIntMap} operations — the map that sits on
+ * the hot path of {@code ConcurrentNeighborMap} and was identified as the top lock-contention
+ * hotspot in a herddb indexing profile.
+ *
+ * Parameters:
+ *
+ * - {@code impl} — {@code legacy} (previous RW-lock + single {@code AtomicReferenceArray})
+ * vs {@code segmented} (current lock-free spine-of-segments). Both implementations run
+ * in the same JVM under identical conditions so results are directly comparable.
+ * - {@code initialCapacity} — {@code 1024} (default) vs. {@code totalKeys} (pre-sized hint).
+ * The hinted case isolates the "no resize required" scenario that best-case deployments
+ * (known shard size) can hit.
+ * - {@code totalKeys} — size of the working set.
+ *
+ * Thread counts are expressed via {@code @Threads} on each benchmark method: 1 and 8.
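+ *
+ * Example invocation (illustrative; the jar path is an assumption, substitute this
+ * module's actual JMH artifact):
+ * {@code java -jar benchmarks-jmh.jar DenseIntMapConcurrentBenchmark -p impl=segmented}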
+ */
+@BenchmarkMode(Mode.Throughput)
+@OutputTimeUnit(TimeUnit.SECONDS)
+@Fork(value = 1, jvmArgsAppend = {"--enable-preview"})
+@Warmup(iterations = 3, time = 2)
+@Measurement(iterations = 5, time = 3)
+@State(Scope.Benchmark)
+public class DenseIntMapConcurrentBenchmark {
+
+ public enum Impl {
+ legacy(LegacyDenseIntMap::new),
+ segmented(DenseIntMap::new);
+
+ final IntFunction<IntMap<Integer>> factory;
+
+ Impl(IntFunction<IntMap<Integer>> factory) {
+ this.factory = factory;
+ }
+ }
+
+ @Param
+ public Impl impl;
+
+ @Param({"1024", "1000000"})
+ public int initialCapacity;
+
+ @Param({"1000000"})
+ public int totalKeys;
+
+ /** Shared map used by the "mixed" (pre-populated) benchmarks. */
+ private IntMap<Integer> prepopulated;
+
+ /** Monotonic counter used by insert benchmarks so threads never collide on keys. */
+ private final AtomicInteger insertCursor = new AtomicInteger();
+
+ /** The map written to by insert benchmarks. Replaced in each trial's setup. */
+ private IntMap<Integer> insertMap;
+
+ @Setup
+ public void setup() {
+ this.prepopulated = impl.factory.apply(initialCapacity);
+ for (int i = 0; i < totalKeys; i++) {
+ prepopulated.compareAndPut(i, null, i);
+ }
+ this.insertMap = impl.factory.apply(initialCapacity);
+ this.insertCursor.set(0);
+ }
+
+ /**
+ * Models the {@code addNode} insertion pressure during graph build: many threads inserting
+ * disjoint dense keys. Under the legacy design this path contended on the read lock whenever
+ * any writer happened to be resizing the backing array.
+ */
+ @Benchmark
+ @Threads(8)
+ public boolean insertDense8(Blackhole bh) {
+ return doInsert(bh);
+ }
+
+ @Benchmark
+ @Threads(1)
+ public boolean insertDense1(Blackhole bh) {
+ return doInsert(bh);
+ }
+
+ private boolean doInsert(Blackhole bh) {
+ int key = insertCursor.getAndIncrement();
+ if (key >= totalKeys) {
+ // Bounded workload; replace the map once we've filled it to avoid unbounded memory growth.
+ synchronized (this) {
+ if (insertCursor.get() >= totalKeys) {
+ insertMap = impl.factory.apply(initialCapacity);
+ insertCursor.set(0);
+ }
+ }
+ key = insertCursor.getAndIncrement();
+ }
+ boolean ok = insertMap.compareAndPut(key, null, key);
+ bh.consume(ok);
+ return ok;
+ }
+
+ /**
+ * Models the steady-state {@code insertEdge}/{@code insertDiverse} CAS-update pattern on an
+ * already-built base layer: each thread reads then CAS-updates a random pre-populated key.
+ */
+ @Benchmark
+ @Threads(8)
+ public boolean casUpdate8(Blackhole bh) {
+ return doCasUpdate(bh);
+ }
+
+ @Benchmark
+ @Threads(1)
+ public boolean casUpdate1(Blackhole bh) {
+ return doCasUpdate(bh);
+ }
+
+ private boolean doCasUpdate(Blackhole bh) {
+ int key = ThreadLocalRandom.current().nextInt(totalKeys);
+ Integer current = prepopulated.get(key);
+ boolean ok = prepopulated.compareAndPut(key, current, key + 1);
+ bh.consume(ok);
+ return ok;
+ }
+
+ /**
+ * Pure {@code get()} throughput under heavy read load — sanity check that the lock-free
+ * read path remains as fast as before (and ideally faster, since there is no RW-lock
+ * machinery to traverse).
+ */
+ @Benchmark
+ @Threads(8)
+ public Integer getHot8() {
+ int key = ThreadLocalRandom.current().nextInt(totalKeys);
+ return prepopulated.get(key);
+ }
+
+ @Benchmark
+ @Threads(1)
+ public Integer getHot1() {
+ int key = ThreadLocalRandom.current().nextInt(totalKeys);
+ return prepopulated.get(key);
+ }
+
+ /**
+ * Mixed read/write workload approximating the graph-build steady state: 7 readers for each
+ * writer doing a CAS update. Uses JMH groups so both run against the same shared map.
+ */
+ @Benchmark
+ @Group("mixed")
+ @GroupThreads(7)
+ public Integer mixedRead() {
+ int key = ThreadLocalRandom.current().nextInt(totalKeys);
+ return prepopulated.get(key);
+ }
+
+ @Benchmark
+ @Group("mixed")
+ @GroupThreads(1)
+ public boolean mixedWrite() {
+ int key = ThreadLocalRandom.current().nextInt(totalKeys);
+ Integer current = prepopulated.get(key);
+ return prepopulated.compareAndPut(key, current, key + 1);
+ }
+}
diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacyDenseIntMap.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacyDenseIntMap.java
new file mode 100644
index 000000000..c65f62df6
--- /dev/null
+++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacyDenseIntMap.java
@@ -0,0 +1,130 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package io.github.jbellis.jvector.bench;
+
+import io.github.jbellis.jvector.util.ArrayUtil;
+import io.github.jbellis.jvector.util.IntMap;
+import io.github.jbellis.jvector.util.RamUsageEstimator;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReferenceArray;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+/**
+ * Verbatim copy of the previous {@code DenseIntMap} implementation (volatile
+ * {@link AtomicReferenceArray} + {@link ReentrantReadWriteLock} with Lucene-style
+ * {@code ArrayUtil.oversize} growth). Kept in the benchmarks module so
+ * {@link DenseIntMapConcurrentBenchmark} can measure the new segmented impl against
+ * the old one under identical conditions, without needing a separate checkout.
+ */
+public class LegacyDenseIntMap<T> implements IntMap<T> {
+ private final ReadWriteLock rwl = new ReentrantReadWriteLock();
+ private volatile AtomicReferenceArray<T> objects;
+ private final AtomicInteger size;
+
+ public LegacyDenseIntMap(int initialCapacity) {
+ objects = new AtomicReferenceArray<>(initialCapacity);
+ size = new AtomicInteger();
+ }
+
+ @Override
+ public boolean compareAndPut(int key, T existing, T value) {
+ if (value == null) {
+ throw new IllegalArgumentException("compareAndPut() value cannot be null -- use remove() instead");
+ }
+
+ ensureCapacity(key);
+ rwl.readLock().lock();
+ try {
+ var success = objects.compareAndSet(key, existing, value);
+ var isInsert = success && existing == null;
+ if (isInsert) {
+ size.incrementAndGet();
+ }
+ return success;
+ } finally {
+ rwl.readLock().unlock();
+ }
+ }
+
+ @Override
+ public int size() {
+ return size.get();
+ }
+
+ @Override
+ public T get(int key) {
+ if (key >= objects.length()) {
+ return null;
+ }
+ return objects.get(key);
+ }
+
+ private void ensureCapacity(int node) {
+ if (node < objects.length()) {
+ return;
+ }
+
+ rwl.writeLock().lock();
+ try {
+ var oldArray = objects;
+ if (node >= oldArray.length()) {
+ int newSize = ArrayUtil.oversize(node + 1, RamUsageEstimator.NUM_BYTES_OBJECT_REF);
+ var newArray = new AtomicReferenceArray<T>(newSize);
+ for (int i = 0; i < oldArray.length(); i++) {
+ newArray.set(i, oldArray.get(i));
+ }
+ objects = newArray;
+ }
+ } finally {
+ rwl.writeLock().unlock();
+ }
+ }
+
+ @Override
+ public T remove(int key) {
+ if (key >= objects.length()) {
+ return null;
+ }
+ var old = objects.get(key);
+ if (old == null) {
+ return null;
+ }
+
+ rwl.readLock().lock();
+ try {
+ if (objects.compareAndSet(key, old, null)) {
+ size.decrementAndGet();
+ return old;
+ } else {
+ return null;
+ }
+ } finally {
+ rwl.readLock().unlock();
+ }
+ }
+
+ @Override
+ public boolean containsKey(int key) {
+ return get(key) != null;
+ }
+
+ @Override
+ public void forEach(IntBiConsumer<T> consumer) {
+ var ref = objects;
+ for (int i = 0; i < ref.length(); i++) {
+ var value = get(i);
+ if (value != null) {
+ consumer.consume(i, value);
+ }
+ }
+ }
+}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java
index 891fda756..24be6eaa4 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java
@@ -42,7 +42,11 @@ public class ConcurrentNeighborMap {
public final int maxOverflowDegree;
public ConcurrentNeighborMap(DiversityProvider diversityProvider, int maxDegree, int maxOverflowDegree) {
- this(new DenseIntMap<>(1024), diversityProvider, maxDegree, maxOverflowDegree);
+ this(diversityProvider, maxDegree, maxOverflowDegree, 1024);
+ }
+
+ public ConcurrentNeighborMap(DiversityProvider diversityProvider, int maxDegree, int maxOverflowDegree, int initialCapacity) {
+ this(new DenseIntMap<>(initialCapacity), diversityProvider, maxDegree, maxOverflowDegree);
}
public ConcurrentNeighborMap(IntMap<Neighbors> neighbors, DiversityProvider diversityProvider, int maxDegree, int maxOverflowDegree)
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
index 9d8af192e..04670cfff 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
@@ -301,6 +301,52 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
boolean refineFinalGraph,
ForkJoinPool simdExecutor,
ForkJoinPool parallelExecutor)
+ {
+ this(scoreProvider, dimension, maxDegrees, beamWidth, neighborOverflow, alpha, addHierarchy, refineFinalGraph, simdExecutor, parallelExecutor, 1024);
+ }
+
+ /**
+ * Reads all the vectors from vector values, builds a graph connecting them by their dense
+ * ordinals, using the given hyperparameter settings, and returns the resulting graph.
+ *
+ * @param scoreProvider describes how to determine the similarities between vectors
+ * @param dimension the dimension of the vectors that will be added to the graph
+ * @param maxDegrees the maximum number of connections a node can have in each layer; if fewer entries
+ * are specified than the number of layers, the last entry is used for all remaining layers.
+ * @param beamWidth the size of the beam search to use when finding nearest neighbors.
+ * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a
+ * node. Larger values will build more efficiently, but use more memory.
+ * @param alpha how aggressive pruning diverse neighbors should be. Set alpha > 1.0 to
+ * allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of
+ * an HNSW graph will be created, which is usually not what you want.
+ * @param addHierarchy whether we want to add an HNSW-style hierarchy on top of the Vamana index.
+ * @param refineFinalGraph whether we do a second pass over each node in the graph to refine its connections
+ * @param simdExecutor ForkJoinPool instance for SIMD operations, best is to use a pool with the size of
+ * the number of physical cores.
+ * @param parallelExecutor ForkJoinPool instance for parallel stream operations
+ * @param initialCapacity initial capacity hint for the dense base layer (layer 0). When the upper bound on
+ * node count is known in advance (e.g. a fixed dataset size), passing it here lets the
+ * base-layer map skip resizes during concurrent build, eliminating contention on the
+ * internal resize lock. Use {@code 1024} (the default) when the size is unknown.
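+ *
+ * <pre>{@code
+ * // Illustrative sketch; scoreProvider, dim, simdPool, parallelPool and datasetSize
+ * // are placeholders for the caller's own values.
+ * var builder = new GraphIndexBuilder(scoreProvider, dim, List.of(32), 100, 1.2f, 1.2f,
+ *                                     true, true, simdPool, parallelPool, datasetSize);
+ * }</pre>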
+ */
+ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
+ int dimension,
+ List<Integer> maxDegrees,
+ int beamWidth,
+ float neighborOverflow,
+ float alpha,
+ boolean addHierarchy,
+ boolean refineFinalGraph,
+ ForkJoinPool simdExecutor,
+ ForkJoinPool parallelExecutor,
+ int initialCapacity)
{
if (maxDegrees.stream().anyMatch(i -> i <= 0)) {
throw new IllegalArgumentException("layer degrees must be positive");
@@ -317,6 +363,9 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
if (alpha <= 0) {
throw new IllegalArgumentException("alpha must be positive");
}
+ if (initialCapacity <= 0) {
+ throw new IllegalArgumentException("initialCapacity must be positive");
+ }
this.scoreProvider = scoreProvider;
this.dimension = dimension;
@@ -328,7 +377,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
this.simdExecutor = simdExecutor;
this.parallelExecutor = parallelExecutor;
- this.graph = new OnHeapGraphIndex(maxDegrees, dimension, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha), addHierarchy);
+ this.graph = new OnHeapGraphIndex(maxDegrees, dimension, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha), addHierarchy, initialCapacity);
this.searchers = ExplicitThreadLocal.withInitial(() -> {
var gs = new GraphSearcher(graph);
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java
index 9ed1a92dd..b3a45ff90 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java
@@ -85,15 +85,22 @@ public class OnHeapGraphIndex implements MutableGraphIndex {
private final boolean isHierarchical;
OnHeapGraphIndex(List<Integer> maxDegrees, int dimension, double overflowRatio, DiversityProvider diversityProvider, boolean isHierarchical) {
+ this(maxDegrees, dimension, overflowRatio, diversityProvider, isHierarchical, 1024);
+ }
+
+ OnHeapGraphIndex(List<Integer> maxDegrees, int dimension, double overflowRatio, DiversityProvider diversityProvider, boolean isHierarchical, int baseLayerInitialCapacity) {
+ if (baseLayerInitialCapacity <= 0) {
+ throw new IllegalArgumentException("baseLayerInitialCapacity must be positive, got " + baseLayerInitialCapacity);
+ }
this.overflowRatio = overflowRatio;
this.maxDegrees = new IntArrayList();
this.dimension = dimension;
setDegrees(maxDegrees);
entryPoint = new AtomicReference<>();
- this.completions = new CompletionTracker(1024);
+ this.completions = new CompletionTracker(baseLayerInitialCapacity);
// Initialize the base layer (layer 0) with a dense map.
this.layers.add(new ConcurrentNeighborMap(
- new DenseIntMap<>(1024),
+ new DenseIntMap<>(baseLayerInitialCapacity),
diversityProvider,
getDegree(0),
(int) (getDegree(0) * overflowRatio))
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java
index 683cfb5dc..6f1132538 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java
@@ -15,34 +15,82 @@
*/
package io.github.jbellis.jvector.util;
-import io.github.jbellis.jvector.graph.NodesIterator;
-
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;
-import java.util.concurrent.locks.ReadWriteLock;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
-import java.util.stream.IntStream;
/**
- * A map (but not a Map) of int -> T where the int keys are dense-ish and start at zero,
- * but the size of the map is not known in advance. This provides fast, concurrent
- * updates and minimizes contention when the map is resized.
+ * A map (but not a Map) of int -> T where the int keys are dense-ish and start at zero,
+ * but the size of the map is not known in advance. This provides fast, concurrent
+ * updates with zero lock contention on the common path.
*
* "Dense-ish" means that space is allocated for all keys from 0 to the highest key, but
- * it is valid to have gaps in the keys. The value associated with "gap" keys is null.
+ * it is valid to have gaps in the keys. The value associated with "gap" keys is null.
+ *
+ * Storage layout. The map uses a two-tier structure:
+ *
+ * - A base {@link AtomicReferenceArray} sized from the constructor's
+ * {@code initialCapacity}. The base is immortal — allocated once, never
+ * resized, never copied — so reads and writes for keys below {@code initialCapacity}
+ * are a single volatile load + slot access (and, for writes, a CAS + an
+ * {@code AtomicInteger} increment). This matches the legacy implementation's
+ * read throughput exactly, and beats its write throughput because no
+ * {@link java.util.concurrent.locks.ReentrantReadWriteLock} traversal is required.
+ * - A lazily-allocated overflow tier of fixed-size segments (1024 slots each)
+ * for keys at or beyond {@code initialCapacity}. Once a segment is installed it is
+ * never reallocated, so writes through the overflow are also lock-free on the steady
+ * state. Only spine growth + first-time segment install share a single
+ * {@code synchronized} block — the hot path never takes it.
+ *
+ * Callers that know their upper bound on node count (e.g. a fixed dataset size) should
+ * pass it as {@code initialCapacity} so all operations stay on the fast base path.
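+ *
+ * <pre>{@code
+ * // Illustrative usage sketch (expectedNodes is a placeholder, not part of this API):
+ * // pre-size to a known node count so every key stays on the immortal base array.
+ * DenseIntMap<String> map = new DenseIntMap<>(expectedNodes);
+ * map.compareAndPut(0, null, "first"); // lock-free CAS insert
+ * String v = map.get(0);               // single volatile read + slot access
+ * }</pre>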
*/
public class DenseIntMap<T> implements IntMap<T> {
- // locking strategy:
- // - writelock to resize the array
- // - readlock to update the array with put or remove
- // - no lock to read the array, volatile is enough
- private final ReadWriteLock rwl = new ReentrantReadWriteLock();
- private volatile AtomicReferenceArray<T> objects;
- private final AtomicInteger size;
+ private static final int OVERFLOW_SEGMENT_BITS = 10;
+ private static final int OVERFLOW_SEGMENT_SIZE = 1 << OVERFLOW_SEGMENT_BITS;
+ private static final int OVERFLOW_SEGMENT_MASK = OVERFLOW_SEGMENT_SIZE - 1;
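+ // Key routing example: with OVERFLOW_SEGMENT_BITS = 10, an overflow-relative key of 2500
+ // lands in segment 2500 >>> 10 = 2, at slot 2500 & 1023 = 452.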
+
+ /** Immortal base array. Sized at construction and never reassigned. */
+ private final AtomicReferenceArray<T> base;
+ /** Cached {@code base.length()} so the hot path avoids the extra volatile read. */
+ private final int baseLen;
+
+ /** Lazily-installed segmented overflow for keys at or beyond {@link #baseLen}. */
+ private volatile Overflow overflow;
+ private final Object overflowInitLock = new Object();
+
+ private final AtomicInteger size = new AtomicInteger();
public DenseIntMap(int initialCapacity) {
- objects = new AtomicReferenceArray<>(initialCapacity);
- size = new AtomicInteger();
+ if (initialCapacity <= 0) {
+ throw new IllegalArgumentException("initialCapacity must be positive, got " + initialCapacity);
+ }
+ this.base = new AtomicReferenceArray<>(initialCapacity);
+ this.baseLen = initialCapacity;
+ }
+
+ @Override
+ public T get(int key) {
+ if (key < 0) {
+ return null;
+ }
+ if (key < baseLen) {
+ return base.get(key);
+ }
+ Overflow<T> o = overflow;
+ if (o == null) {
+ return null;
+ }
+ return o.get(key - baseLen);
}
@Override
@@ -50,19 +98,17 @@ public boolean compareAndPut(int key, T existing, T value) {
if (value == null) {
throw new IllegalArgumentException("compareAndPut() value cannot be null -- use remove() instead");
}
-
- ensureCapacity(key);
- rwl.readLock().lock();
- try {
- var success = objects.compareAndSet(key, existing, value);
- var isInsert = success && existing == null;
- if (isInsert) {
+ if (key < 0) {
+ throw new IndexOutOfBoundsException("key must be non-negative, got " + key);
+ }
+ if (key < baseLen) {
+ boolean success = base.compareAndSet(key, existing, value);
+ if (success && existing == null) {
size.incrementAndGet();
}
return success;
- } finally {
- rwl.readLock().unlock();
}
+ return overflowForWrite().compareAndPut(key - baseLen, existing, value, size);
}
@Override
@@ -71,70 +117,161 @@ public int size() {
}
@Override
- public T get(int key) {
- if (key >= objects.length()) {
+ public T remove(int key) {
+ if (key < 0) {
return null;
}
+ if (key < baseLen) {
+ T old = base.get(key);
+ if (old == null) {
+ return null;
+ }
+ if (base.compareAndSet(key, old, null)) {
+ size.decrementAndGet();
+ return old;
+ }
+ return null;
+ }
+ Overflow<T> o = overflow;
+ if (o == null) {
+ return null;
+ }
+ return o.remove(key - baseLen, size);
+ }
- return objects.get(key);
+ @Override
+ public boolean containsKey(int key) {
+ return get(key) != null;
}
- private void ensureCapacity(int node) {
- if (node < objects.length()) {
- return;
+ @Override
+ public void forEach(IntBiConsumer<T> consumer) {
+ for (int i = 0; i < baseLen; i++) {
+ T value = base.get(i);
+ if (value != null) {
+ consumer.consume(i, value);
+ }
+ }
+ Overflow<T> o = overflow;
+ if (o != null) {
+ o.forEach(baseLen, consumer);
}
+ }
- rwl.writeLock().lock();
- try {
- var oldArray = objects;
- if (node >= oldArray.length()) {
- int newSize = ArrayUtil.oversize(node + 1, RamUsageEstimator.NUM_BYTES_OBJECT_REF);
- var newArray = new AtomicReferenceArray<T>(newSize);
- for (int i = 0; i < oldArray.length(); i++) {
- newArray.set(i, oldArray.get(i));
- }
- objects = newArray;
+ private Overflow<T> overflowForWrite() {
+ Overflow<T> o = overflow;
+ if (o != null) {
+ return o;
+ }
+ synchronized (overflowInitLock) {
+ if (overflow == null) {
+ overflow = new Overflow<>();
}
- } finally {
- rwl.writeLock().unlock();
+ return overflow;
}
}
- @Override
- public T remove(int key) {
- if (key >= objects.length()) {
- return null;
+ /**
+ * Segmented overflow tier for keys whose slot is not in {@link DenseIntMap#base}.
+ * Keys are stored at relative offsets from {@code baseLen} (so the first overflow slot is
+ * relKey 0). Segments are fixed-size and immortal once installed; the spine itself grows
+ * under {@link #spineLock} but the hot path never touches it.
+ */
+ private static final class Overflow<T> {
+ private volatile AtomicReferenceArray<AtomicReferenceArray<T>> spine;
+ private final Object spineLock = new Object();
+
+ Overflow() {
+ this.spine = new AtomicReferenceArray<>(1);
}
- var old = objects.get(key);
- if (old == null) {
- return null;
+
+ T get(int relKey) {
+ AtomicReferenceArray<AtomicReferenceArray<T>> spineSnap = spine;
+ int segIdx = relKey >>> OVERFLOW_SEGMENT_BITS;
+ if (segIdx >= spineSnap.length()) {
+ return null;
+ }
+ AtomicReferenceArray<T> seg = spineSnap.get(segIdx);
+ if (seg == null) {
+ return null;
+ }
+ return seg.get(relKey & OVERFLOW_SEGMENT_MASK);
}
- rwl.readLock().lock();
- try {
- if (objects.compareAndSet(key, old, null)) {
+ boolean compareAndPut(int relKey, T existing, T value, AtomicInteger size) {
+ AtomicReferenceArray<T> seg = segmentFor(relKey);
+ boolean success = seg.compareAndSet(relKey & OVERFLOW_SEGMENT_MASK, existing, value);
+ if (success && existing == null) {
+ size.incrementAndGet();
+ }
+ return success;
+ }
+
+ T remove(int relKey, AtomicInteger size) {
+ AtomicReferenceArray<AtomicReferenceArray<T>> spineSnap = spine;
+ int segIdx = relKey >>> OVERFLOW_SEGMENT_BITS;
+ if (segIdx >= spineSnap.length()) {
+ return null;
+ }
+ AtomicReferenceArray<T> seg = spineSnap.get(segIdx);
+ if (seg == null) {
+ return null;
+ }
+ int slot = relKey & OVERFLOW_SEGMENT_MASK;
+ T old = seg.get(slot);
+ if (old == null) {
+ return null;
+ }
+ if (seg.compareAndSet(slot, old, null)) {
size.decrementAndGet();
return old;
- } else {
- return null;
}
- } finally {
- rwl.readLock().unlock();
+ return null;
}
- }
- @Override
- public boolean containsKey(int key) {
- return get(key) != null;
- }
+ void forEach(int baseOffset, IntBiConsumer<T> consumer) {
+ AtomicReferenceArray<AtomicReferenceArray<T>> spineSnap = spine;
+ for (int s = 0; s < spineSnap.length(); s++) {
+ AtomicReferenceArray<T> seg = spineSnap.get(s);
+ if (seg == null) {
+ continue;
+ }
+ int segBase = s << OVERFLOW_SEGMENT_BITS;
+ for (int i = 0; i < OVERFLOW_SEGMENT_SIZE; i++) {
+ T value = seg.get(i);
+ if (value != null) {
+ consumer.consume(baseOffset + segBase + i, value);
+ }
+ }
+ }
+ }
- @Override
- public void forEach(IntBiConsumer<T> consumer) {
- var ref = objects;
- for (int i = 0; i < ref.length(); i++) {
- var value = get(i);
- if (value != null) {
- consumer.consume(i, value);
+ private AtomicReferenceArray<T> segmentFor(int relKey) {
+ int segIdx = relKey >>> OVERFLOW_SEGMENT_BITS;
+ AtomicReferenceArray<AtomicReferenceArray<T>> spineSnap = spine;
+ if (segIdx < spineSnap.length()) {
+ AtomicReferenceArray<T> seg = spineSnap.get(segIdx);
+ if (seg != null) {
+ return seg;
+ }
+ }
+ synchronized (spineLock) {
+ spineSnap = spine;
+ if (segIdx >= spineSnap.length()) {
+ int newLen = Math.max(spineSnap.length() * 2, segIdx + 1);
+ AtomicReferenceArray<AtomicReferenceArray<T>> next = new AtomicReferenceArray<>(newLen);
+ for (int i = 0; i < spineSnap.length(); i++) {
+ next.set(i, spineSnap.get(i));
+ }
+ spineSnap = next;
+ spine = next;
+ }
+ AtomicReferenceArray<T> seg = spineSnap.get(segIdx);
+ if (seg == null) {
+ seg = new AtomicReferenceArray<>(OVERFLOW_SEGMENT_SIZE);
+ spineSnap.set(segIdx, seg);
+ }
+ return seg;
}
}
}
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java
index 59b248584..c8f15ea87 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java
@@ -127,6 +127,56 @@ public void testRescore(boolean addHierarchy) {
}
+ @Test
+ public void testInitialCapacityHintProducesEquivalentGraph() {
+ int dimension = 8;
+ int size = 200;
+ var ravv = MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, getRandom()));
+ var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.COSINE);
+
+ var builderDefault = new GraphIndexBuilder(bsp, dimension, 8, 10, 1.2f, 1.0f, true, false);
+ var graphDefault = TestUtil.buildSequentially(builderDefault, ravv);
+
+ // Same build, but the caller passes an initialCapacity hint equal to the known vector count.
+ // This should produce an equivalent graph — the hint only changes internal sizing.
+ var builderHinted = new GraphIndexBuilder(bsp,
+ dimension,
+ java.util.List.of(8),
+ 10,
+ 1.2f,
+ 1.0f,
+ true,
+ false,
+ java.util.concurrent.ForkJoinPool.commonPool(),
+ java.util.concurrent.ForkJoinPool.commonPool(),
+ size);
+ var graphHinted = TestUtil.buildSequentially(builderHinted, ravv);
+
+ assertEquals(graphDefault.size(0), graphHinted.size(0));
+ for (int i = 0; i < ravv.size(); i++) {
+ assertTrue(graphDefault.containsNode(i));
+ assertTrue(graphHinted.containsNode(i));
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testInitialCapacityHintRejectsNonPositive() {
+ int dimension = 4;
+ var ravv = MockVectorValues.fromValues(createRandomFloatVectors(8, dimension, getRandom()));
+ var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.COSINE);
+ new GraphIndexBuilder(bsp,
+ dimension,
+ java.util.List.of(4),
+ 10,
+ 1.0f,
+ 1.0f,
+ false,
+ false,
+ java.util.concurrent.ForkJoinPool.commonPool(),
+ java.util.concurrent.ForkJoinPool.commonPool(),
+ 0);
+ }
+
@Test
public void testSaveAndLoad() throws IOException {
int dimension = randomIntBetween(2, 32);
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMapSegmented.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMapSegmented.java
new file mode 100644
index 000000000..396623378
--- /dev/null
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMapSegmented.java
@@ -0,0 +1,225 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package io.github.jbellis.jvector.util;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Extra coverage for the segmented DenseIntMap implementation, exercising:
+ * - cross-segment-boundary writes (keys around multiples of 1024)
+ * - concurrent inserts that force the spine to grow from a small initial capacity
+ * - concurrent inserts + removes on the same dense key range
+ * - forEach ascending-key iteration and visibility of prior writes
+ *
+ * Complements the existing {@link TestIntMap} tests which cover both Dense and Sparse
+ * implementations against a small key range.
+ */
+@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
+public class TestDenseIntMapSegmented extends RandomizedTest {
+
+ @Test
+ public void testCrossSegmentBoundary() {
+ // OVERFLOW_SEGMENT_SIZE is 1024 internally; pick keys straddling the first few boundaries
+ // so a successful round trip proves routing and ascending iteration work across segments.
+ DenseIntMap<String> map = new DenseIntMap<>(1);
+ int[] keys = {0, 1023, 1024, 2047, 2048, 4096, 10_000};
+ for (int k : keys) {
+ Assert.assertTrue(map.compareAndPut(k, null, "v" + k));
+ }
+ Assert.assertEquals(keys.length, map.size());
+ for (int k : keys) {
+ Assert.assertEquals("v" + k, map.get(k));
+ Assert.assertTrue(map.containsKey(k));
+ }
+
+ // forEach visits in ascending key order
+ List<Integer> visited = new ArrayList<>();
+ map.forEach((key, value) -> {
+ visited.add(key);
+ Assert.assertEquals("v" + key, value);
+ });
+ Assert.assertEquals(keys.length, visited.size());
+ for (int i = 1; i < visited.size(); i++) {
+ Assert.assertTrue("forEach must iterate ascending", visited.get(i) > visited.get(i - 1));
+ }
+ }
+
+ /**
+ * Under a tiny initial capacity, many concurrent writers across a wide key range force
+ * the spine to grow repeatedly. Every inserted key must be visible afterwards, and the
+ * {@code size()} counter must match the number of successful inserts. Regressions here
+ * would indicate a lost write during spine growth.
+ */
+ @Test
+ public void testConcurrentInsertForcesSpineGrowth() throws InterruptedException {
+ DenseIntMap<Integer> map = new DenseIntMap<>(1); // single-slot base: nearly every key routes through the overflow tier
+ int nThreads = 16;
+ int perThread = 5_000;
+ int totalKeys = nThreads * perThread;
+
+ CountDownLatch start = new CountDownLatch(1);
+ CountDownLatch done = new CountDownLatch(nThreads);
+ AtomicInteger successes = new AtomicInteger();
+
+ for (int t = 0; t < nThreads; t++) {
+ final int threadId = t;
+ new Thread(() -> {
+ try {
+ start.await();
+ for (int i = 0; i < perThread; i++) {
+ int key = threadId * perThread + i;
+ if (map.compareAndPut(key, null, key)) {
+ successes.incrementAndGet();
+ }
+ }
+ } catch (InterruptedException ie) {
+ Thread.currentThread().interrupt();
+ } finally {
+ done.countDown();
+ }
+ }, "inserter-" + t).start();
+ }
+ start.countDown();
+ done.await();
+
+ Assert.assertEquals(totalKeys, successes.get());
+ Assert.assertEquals(totalKeys, map.size());
+ for (int k = 0; k < totalKeys; k++) {
+ Assert.assertEquals("missing key " + k, Integer.valueOf(k), map.get(k));
+ }
+ }
+
+ /**
+ * Concurrent {@code compareAndPut(null, v)} races on the same key must give exactly one
+ * winner. A bug where two threads both observe the slot as null and both increment
+ * {@code size} would show up as an inflated size.
+ */
+ @Test
+ public void testConcurrentInsertSameKey() throws InterruptedException {
+ for (int trial = 0; trial < 20; trial++) {
+ DenseIntMap<String> map = new DenseIntMap<>(1);
+ int nThreads = 8;
+ int sameKey = 12_345; // far enough to require spine growth and segment install
+ CountDownLatch start = new CountDownLatch(1);
+ CountDownLatch done = new CountDownLatch(nThreads);
+ AtomicInteger winners = new AtomicInteger();
+
+ for (int t = 0; t < nThreads; t++) {
+ final int tid = t;
+ new Thread(() -> {
+ try {
+ start.await();
+ if (map.compareAndPut(sameKey, null, "tid=" + tid)) {
+ winners.incrementAndGet();
+ }
+ } catch (InterruptedException ie) {
+ Thread.currentThread().interrupt();
+ } finally {
+ done.countDown();
+ }
+ }).start();
+ }
+ start.countDown();
+ done.await();
+
+ Assert.assertEquals("exactly one thread must win the initial CAS", 1, winners.get());
+ Assert.assertEquals(1, map.size());
+ Assert.assertNotNull(map.get(sameKey));
+ }
+ }
+
+ /**
+ * Concurrent remove + re-insert on disjoint key slices of a pre-populated map: each thread
+ * removes and then re-inserts every key in its own stride. The final state is deterministic:
+ * every key is present with its original value and {@code size()} is unchanged.
+ */
+ @Test
+ public void testConcurrentInsertAndRemove() throws InterruptedException {
+ DenseIntMap<Integer> map = new DenseIntMap<>(1);
+ int keys = 20_000;
+ // First populate
+ for (int k = 0; k < keys; k++) {
+ Assert.assertTrue(map.compareAndPut(k, null, k));
+ }
+ Assert.assertEquals(keys, map.size());
+
+ int nThreads = 8;
+ CountDownLatch start = new CountDownLatch(1);
+ CountDownLatch done = new CountDownLatch(nThreads);
+ for (int t = 0; t < nThreads; t++) {
+ final int tid = t;
+ new Thread(() -> {
+ try {
+ start.await();
+ // Each thread works its own stride: remove each of its keys, then re-insert it.
+ for (int k = tid; k < keys; k += nThreads) {
+ Integer removed = map.remove(k);
+ Assert.assertEquals(Integer.valueOf(k), removed);
+ Assert.assertTrue(map.compareAndPut(k, null, k));
+ }
+ } catch (InterruptedException ie) {
+ Thread.currentThread().interrupt();
+ } finally {
+ done.countDown();
+ }
+ }).start();
+ }
+ start.countDown();
+ done.await();
+
+ Assert.assertEquals(keys, map.size());
+ for (int k = 0; k < keys; k++) {
+ Assert.assertEquals(Integer.valueOf(k), map.get(k));
+ }
+ }
+
+ /**
+ * The initialCapacity hint must actually size the base array so that keys up to the hint
+ * are served without ever touching the overflow tier. The contract is semantic (the map works at
+ * that size), not structural — we don't expose internal counters — but this test
+ * serves as a regression guard against the hint being silently ignored.
+ */
+ @Test
+ public void testInitialCapacityHintSupportsInsertsUpToHint() {
+ int capacity = 4 * 1024 + 17; // more than one segment's worth, with a remainder
+ DenseIntMap<Integer> map = new DenseIntMap<>(capacity);
+ for (int k = 0; k < capacity; k++) {
+ Assert.assertTrue("insert failed at k=" + k, map.compareAndPut(k, null, k));
+ }
+ Assert.assertEquals(capacity, map.size());
+ for (int k = 0; k < capacity; k++) {
+ Assert.assertEquals(Integer.valueOf(k), map.get(k));
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testRejectsZeroInitialCapacity() {
+ new DenseIntMap<>(0);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testRejectsNegativeInitialCapacity() {
+ new DenseIntMap<>(-1);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testRejectsNullValue() {
+ DenseIntMap<Integer> map = new DenseIntMap<>(16);
+ map.compareAndPut(0, null, null);
+ }
+}