diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/DenseIntMapConcurrentBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/DenseIntMapConcurrentBenchmark.java new file mode 100644 index 000000000..bae99d237 --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/DenseIntMapConcurrentBenchmark.java @@ -0,0 +1,196 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.util.DenseIntMap; +import io.github.jbellis.jvector.util.IntMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Group; +import org.openjdk.jmh.annotations.GroupThreads; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntFunction; + +/** + * Measures the throughput of concurrent {@link DenseIntMap} operations — the map that sits on + * the hot path of {@code ConcurrentNeighborMap} and was identified as the top lock-contention + * hotspot in a herddb indexing profile. + *

+ * Parameters: + *

+ * Thread counts are expressed via {@code @Threads} on each benchmark method: 1 and 8. + */ +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +@Fork(value = 1, jvmArgsAppend = {"--enable-preview"}) +@Warmup(iterations = 3, time = 2) +@Measurement(iterations = 5, time = 3) +@State(Scope.Benchmark) +public class DenseIntMapConcurrentBenchmark { + + public enum Impl { + legacy(LegacyDenseIntMap::new), + segmented(DenseIntMap::new); + + final IntFunction> factory; + + Impl(IntFunction> factory) { + this.factory = factory; + } + } + + @Param + public Impl impl; + + @Param({"1024", "1000000"}) + public int initialCapacity; + + @Param({"1000000"}) + public int totalKeys; + + /** Shared map used by the "mixed" (pre-populated) benchmarks. */ + private IntMap prepopulated; + + /** Monotonic counter used by insert benchmarks so threads never collide on keys. */ + private final AtomicInteger insertCursor = new AtomicInteger(); + + /** The map written to by insert benchmarks. Replaced in each trial's setup. */ + private IntMap insertMap; + + @Setup + public void setup() { + this.prepopulated = impl.factory.apply(initialCapacity); + for (int i = 0; i < totalKeys; i++) { + prepopulated.compareAndPut(i, null, i); + } + this.insertMap = impl.factory.apply(initialCapacity); + this.insertCursor.set(0); + } + + /** + * Models the {@code addNode} insertion pressure during graph build: many threads inserting + * disjoint dense keys. Under the legacy design this path contended on the read lock whenever + * any writer happened to be resizing the backing array. + */ + @Benchmark + @Threads(8) + public boolean insertDense8(Blackhole bh) { + return doInsert(bh); + } + + @Benchmark + @Threads(1) + public boolean insertDense1(Blackhole bh) { + return doInsert(bh); + } + + private boolean doInsert(Blackhole bh) { + int key = insertCursor.getAndIncrement(); + if (key >= totalKeys) { + // Bounded workload; replace the map once we've filled it to avoid unbounded memory growth. + synchronized (this) { + if (insertCursor.get() >= totalKeys) { + insertMap = impl.factory.apply(initialCapacity); + insertCursor.set(0); + } + } + key = insertCursor.getAndIncrement(); + } + boolean ok = insertMap.compareAndPut(key, null, key); + bh.consume(ok); + return ok; + } + + /** + * Models the steady-state {@code insertEdge}/{@code insertDiverse} CAS-update pattern on an + * already-built base layer: each thread reads then CAS-updates a random pre-populated key. + */ + @Benchmark + @Threads(8) + public boolean casUpdate8(Blackhole bh) { + return doCasUpdate(bh); + } + + @Benchmark + @Threads(1) + public boolean casUpdate1(Blackhole bh) { + return doCasUpdate(bh); + } + + private boolean doCasUpdate(Blackhole bh) { + int key = ThreadLocalRandom.current().nextInt(totalKeys); + Integer current = prepopulated.get(key); + boolean ok = prepopulated.compareAndPut(key, current, key + 1); + bh.consume(ok); + return ok; + } + + /** + * Pure {@code get()} throughput under heavy read load — sanity check that the lock-free + * read path remains as fast as before (and ideally faster, since there is no RW-lock + * machinery to traverse). + */ + @Benchmark + @Threads(8) + public Integer getHot8() { + int key = ThreadLocalRandom.current().nextInt(totalKeys); + return prepopulated.get(key); + } + + @Benchmark + @Threads(1) + public Integer getHot1() { + int key = ThreadLocalRandom.current().nextInt(totalKeys); + return prepopulated.get(key); + } + + /** + * Mixed read/write workload approximating the graph-build steady state: 7 readers for each + * writer doing a CAS update. Uses JMH groups so both run against the same shared map. + */ + @Benchmark + @Group("mixed") + @GroupThreads(7) + public Integer mixedRead() { + int key = ThreadLocalRandom.current().nextInt(totalKeys); + return prepopulated.get(key); + } + + @Benchmark + @Group("mixed") + @GroupThreads(1) + public boolean mixedWrite() { + int key = ThreadLocalRandom.current().nextInt(totalKeys); + Integer current = prepopulated.get(key); + return prepopulated.compareAndPut(key, current, key + 1); + } +} diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacyDenseIntMap.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacyDenseIntMap.java new file mode 100644 index 000000000..c65f62df6 --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacyDenseIntMap.java @@ -0,0 +1,130 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.util.ArrayUtil; +import io.github.jbellis.jvector.util.IntMap; +import io.github.jbellis.jvector.util.RamUsageEstimator; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * Verbatim copy of the previous {@code DenseIntMap} implementation (volatile + * {@link AtomicReferenceArray} + {@link ReentrantReadWriteLock} with Lucene-style + * {@code ArrayUtil.oversize} growth). Kept in the benchmarks module so + * {@link DenseIntMapConcurrentBenchmark} can measure the new segmented impl against + * the old one under identical conditions, without needing a separate checkout. + */ +public class LegacyDenseIntMap implements IntMap { + private final ReadWriteLock rwl = new ReentrantReadWriteLock(); + private volatile AtomicReferenceArray objects; + private final AtomicInteger size; + + public LegacyDenseIntMap(int initialCapacity) { + objects = new AtomicReferenceArray<>(initialCapacity); + size = new AtomicInteger(); + } + + @Override + public boolean compareAndPut(int key, T existing, T value) { + if (value == null) { + throw new IllegalArgumentException("compareAndPut() value cannot be null -- use remove() instead"); + } + + ensureCapacity(key); + rwl.readLock().lock(); + try { + var success = objects.compareAndSet(key, existing, value); + var isInsert = success && existing == null; + if (isInsert) { + size.incrementAndGet(); + } + return success; + } finally { + rwl.readLock().unlock(); + } + } + + @Override + public int size() { + return size.get(); + } + + @Override + public T get(int key) { + if (key >= objects.length()) { + return null; + } + return objects.get(key); + } + + private void ensureCapacity(int node) { + if (node < objects.length()) { + return; + } + + rwl.writeLock().lock(); + try { + var oldArray = objects; + if (node >= oldArray.length()) { + int newSize = ArrayUtil.oversize(node + 1, RamUsageEstimator.NUM_BYTES_OBJECT_REF); + var newArray = new AtomicReferenceArray(newSize); + for (int i = 0; i < oldArray.length(); i++) { + newArray.set(i, oldArray.get(i)); + } + objects = newArray; + } + } finally { + rwl.writeLock().unlock(); + } + } + + @Override + public T remove(int key) { + if (key >= objects.length()) { + return null; + } + var old = objects.get(key); + if (old == null) { + return null; + } + + rwl.readLock().lock(); + try { + if (objects.compareAndSet(key, old, null)) { + size.decrementAndGet(); + return old; + } else { + return null; + } + } finally { + rwl.readLock().unlock(); + } + } + + @Override + public boolean containsKey(int key) { + return get(key) != null; + } + + @Override + public void forEach(IntBiConsumer consumer) { + var ref = objects; + for (int i = 0; i < ref.length(); i++) { + var value = get(i); + if (value != null) { + consumer.consume(i, value); + } + } + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java index 891fda756..24be6eaa4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java @@ -42,7 +42,11 @@ public class ConcurrentNeighborMap { public final int maxOverflowDegree; public ConcurrentNeighborMap(DiversityProvider diversityProvider, int maxDegree, int maxOverflowDegree) { - this(new DenseIntMap<>(1024), diversityProvider, maxDegree, maxOverflowDegree); + this(diversityProvider, maxDegree, maxOverflowDegree, 1024); + } + + public ConcurrentNeighborMap(DiversityProvider diversityProvider, int maxDegree, int maxOverflowDegree, int initialCapacity) { + this(new DenseIntMap<>(initialCapacity), diversityProvider, maxDegree, maxOverflowDegree); } public ConcurrentNeighborMap(IntMap neighbors, DiversityProvider diversityProvider, int maxDegree, int maxOverflowDegree) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 9d8af192e..04670cfff 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -301,6 +301,44 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) + { + this(scoreProvider, dimension, maxDegrees, beamWidth, neighborOverflow, alpha, addHierarchy, refineFinalGraph, simdExecutor, parallelExecutor, 1024); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param scoreProvider describes how to determine the similarities between vectors + * @param maxDegrees the maximum number of connections a node can have in each layer; if fewer entries + * are specified than the number of layers, the last entry is used for all remaining layers. + * @param beamWidth the size of the beam search to use when finding nearest neighbors. + * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a + * node. larger values will build more efficiently, but use more memory. + * @param alpha how aggressive pruning diverse neighbors should be. Set alpha > 1.0 to + * allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of + * an HNSW graph will be created, which is usually not what you want. + * @param addHierarchy whether we want to add an HNSW-style hierarchy on top of the Vamana index. + * @param refineFinalGraph whether we do a second pass over each node in the graph to refine its connections + * @param simdExecutor ForkJoinPool instance for SIMD operations, best is to use a pool with the size of + * the number of physical cores. + * @param parallelExecutor ForkJoinPool instance for parallel stream operations + * @param initialCapacity initial capacity hint for the dense base layer (layer 0). When the upper bound on + * node count is known in advance (e.g. a fixed dataset size), passing it here lets the + * base-layer map skip resizes during concurrent build, eliminating contention on the + * internal resize lock. Use {@code 1024} (the default) when the size is unknown. + */ + public GraphIndexBuilder(BuildScoreProvider scoreProvider, + int dimension, + List maxDegrees, + int beamWidth, + float neighborOverflow, + float alpha, + boolean addHierarchy, + boolean refineFinalGraph, + ForkJoinPool simdExecutor, + ForkJoinPool parallelExecutor, + int initialCapacity) { if (maxDegrees.stream().anyMatch(i -> i <= 0)) { throw new IllegalArgumentException("layer degrees must be positive"); @@ -317,6 +355,9 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, if (alpha <= 0) { throw new IllegalArgumentException("alpha must be positive"); } + if (initialCapacity <= 0) { + throw new IllegalArgumentException("initialCapacity must be positive"); + } this.scoreProvider = scoreProvider; this.dimension = dimension; @@ -328,7 +369,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider, this.simdExecutor = simdExecutor; this.parallelExecutor = parallelExecutor; - this.graph = new OnHeapGraphIndex(maxDegrees, dimension, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha), addHierarchy); + this.graph = new OnHeapGraphIndex(maxDegrees, dimension, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha), addHierarchy, initialCapacity); this.searchers = ExplicitThreadLocal.withInitial(() -> { var gs = new GraphSearcher(graph); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 9ed1a92dd..b3a45ff90 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -85,15 +85,22 @@ public class OnHeapGraphIndex implements MutableGraphIndex { private final boolean isHierarchical; OnHeapGraphIndex(List maxDegrees, int dimension, double overflowRatio, DiversityProvider diversityProvider, boolean isHierarchical) { + this(maxDegrees, dimension, overflowRatio, diversityProvider, isHierarchical, 1024); + } + + OnHeapGraphIndex(List maxDegrees, int dimension, double overflowRatio, DiversityProvider diversityProvider, boolean isHierarchical, int baseLayerInitialCapacity) { + if (baseLayerInitialCapacity <= 0) { + throw new IllegalArgumentException("baseLayerInitialCapacity must be positive, got " + baseLayerInitialCapacity); + } this.overflowRatio = overflowRatio; this.maxDegrees = new IntArrayList(); this.dimension = dimension; setDegrees(maxDegrees); entryPoint = new AtomicReference<>(); - this.completions = new CompletionTracker(1024); + this.completions = new CompletionTracker(baseLayerInitialCapacity); // Initialize the base layer (layer 0) with a dense map. this.layers.add(new ConcurrentNeighborMap( - new DenseIntMap<>(1024), + new DenseIntMap<>(baseLayerInitialCapacity), diversityProvider, getDegree(0), (int) (getDegree(0) * overflowRatio)) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java index 683cfb5dc..6f1132538 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java @@ -15,34 +15,72 @@ */ package io.github.jbellis.jvector.util; -import io.github.jbellis.jvector.graph.NodesIterator; - import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceArray; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.stream.IntStream; /** - * A map (but not a Map) of int -> T where the int keys are dense-ish and start at zero, - * but the size of the map is not known in advance. This provides fast, concurrent - * updates and minimizes contention when the map is resized. + * A map (but not a Map) of int -> T where the int keys are dense-ish and start at zero, + * but the size of the map is not known in advance. This provides fast, concurrent + * updates with zero lock contention on the common path. *

* "Dense-ish" means that space is allocated for all keys from 0 to the highest key, but - * it is valid to have gaps in the keys. The value associated with "gap" keys is null. + * it is valid to have gaps in the keys. The value associated with "gap" keys is null. + *

+ * Storage layout. The map uses a two-tier structure: + *

    + *
  • A base {@link AtomicReferenceArray} sized from the constructor's + * {@code initialCapacity}. The base is immortal — allocated once, never + * resized, never copied — so reads and writes for keys below {@code initialCapacity} + * are a single volatile load + slot access (and, for writes, a CAS + an + * {@code AtomicInteger} increment). This matches the legacy implementation's + * read throughput exactly, and beats its write throughput because no + * {@link java.util.concurrent.locks.ReentrantReadWriteLock} traversal is required.
  • + *
  • A lazily-allocated overflow tier of fixed-size segments (1024 slots each) + * for keys at or beyond {@code initialCapacity}. Once a segment is installed it is + * never reallocated, so writes through the overflow are also lock-free on the steady + * state. Only spine growth + first-time segment install share a single + * {@code synchronized} block — the hot path never takes it.
  • + *
+ * Callers that know their upper bound on node count (e.g. a fixed dataset size) should + * pass it as {@code initialCapacity} so all operations stay on the fast base path. */ public class DenseIntMap implements IntMap { - // locking strategy: - // - writelock to resize the array - // - readlock to update the array with put or remove - // - no lock to read the array, volatile is enough - private final ReadWriteLock rwl = new ReentrantReadWriteLock(); - private volatile AtomicReferenceArray objects; - private final AtomicInteger size; + private static final int OVERFLOW_SEGMENT_BITS = 10; + private static final int OVERFLOW_SEGMENT_SIZE = 1 << OVERFLOW_SEGMENT_BITS; + private static final int OVERFLOW_SEGMENT_MASK = OVERFLOW_SEGMENT_SIZE - 1; + + /** Immortal base array. Sized at construction and never reassigned. */ + private final AtomicReferenceArray base; + /** Cached {@code base.length()} so the hot path avoids the extra volatile read. */ + private final int baseLen; + + /** Lazily-installed segmented overflow for keys at or beyond {@link #baseLen}. */ + private volatile Overflow overflow; + private final Object overflowInitLock = new Object(); + + private final AtomicInteger size = new AtomicInteger(); public DenseIntMap(int initialCapacity) { - objects = new AtomicReferenceArray<>(initialCapacity); - size = new AtomicInteger(); + if (initialCapacity <= 0) { + throw new IllegalArgumentException("initialCapacity must be positive, got " + initialCapacity); + } + this.base = new AtomicReferenceArray<>(initialCapacity); + this.baseLen = initialCapacity; + } + + @Override + public T get(int key) { + if (key < 0) { + return null; + } + if (key < baseLen) { + return base.get(key); + } + Overflow o = overflow; + if (o == null) { + return null; + } + return o.get(key - baseLen); } @Override @@ -50,19 +88,17 @@ public boolean compareAndPut(int key, T existing, T value) { if (value == null) { throw new IllegalArgumentException("compareAndPut() value cannot be null -- use remove() instead"); } - - ensureCapacity(key); - rwl.readLock().lock(); - try { - var success = objects.compareAndSet(key, existing, value); - var isInsert = success && existing == null; - if (isInsert) { + if (key < 0) { + throw new IndexOutOfBoundsException("key must be non-negative, got " + key); + } + if (key < baseLen) { + boolean success = base.compareAndSet(key, existing, value); + if (success && existing == null) { size.incrementAndGet(); } return success; - } finally { - rwl.readLock().unlock(); } + return overflowForWrite().compareAndPut(key - baseLen, existing, value, size); } @Override @@ -71,70 +107,161 @@ public int size() { } @Override - public T get(int key) { - if (key >= objects.length()) { + public T remove(int key) { + if (key < 0) { return null; } + if (key < baseLen) { + T old = base.get(key); + if (old == null) { + return null; + } + if (base.compareAndSet(key, old, null)) { + size.decrementAndGet(); + return old; + } + return null; + } + Overflow o = overflow; + if (o == null) { + return null; + } + return o.remove(key - baseLen, size); + } - return objects.get(key); + @Override + public boolean containsKey(int key) { + return get(key) != null; } - private void ensureCapacity(int node) { - if (node < objects.length()) { - return; + @Override + public void forEach(IntBiConsumer consumer) { + for (int i = 0; i < baseLen; i++) { + T value = base.get(i); + if (value != null) { + consumer.consume(i, value); + } + } + Overflow o = overflow; + if (o != null) { + o.forEach(baseLen, consumer); } + } - rwl.writeLock().lock(); - try { - var oldArray = objects; - if (node >= oldArray.length()) { - int newSize = ArrayUtil.oversize(node + 1, RamUsageEstimator.NUM_BYTES_OBJECT_REF); - var newArray = new AtomicReferenceArray(newSize); - for (int i = 0; i < oldArray.length(); i++) { - newArray.set(i, oldArray.get(i)); - } - objects = newArray; + private Overflow overflowForWrite() { + Overflow o = overflow; + if (o != null) { + return o; + } + synchronized (overflowInitLock) { + if (overflow == null) { + overflow = new Overflow<>(); } - } finally { - rwl.writeLock().unlock(); + return overflow; } } - @Override - public T remove(int key) { - if (key >= objects.length()) { - return null; + /** + * Segmented overflow tier for keys whose slot is not in {@link DenseIntMap#base}. + * Keys are stored at relative offsets from {@code baseLen} (so the first overflow slot is + * relKey 0). Segments are fixed-size and immortal once installed; the spine itself grows + * under {@link #spineLock} but the hot path never touches it. + */ + private static final class Overflow { + private volatile AtomicReferenceArray> spine; + private final Object spineLock = new Object(); + + Overflow() { + this.spine = new AtomicReferenceArray<>(1); } - var old = objects.get(key); - if (old == null) { - return null; + + T get(int relKey) { + AtomicReferenceArray> spineSnap = spine; + int segIdx = relKey >>> OVERFLOW_SEGMENT_BITS; + if (segIdx >= spineSnap.length()) { + return null; + } + AtomicReferenceArray seg = spineSnap.get(segIdx); + if (seg == null) { + return null; + } + return seg.get(relKey & OVERFLOW_SEGMENT_MASK); } - rwl.readLock().lock(); - try { - if (objects.compareAndSet(key, old, null)) { + boolean compareAndPut(int relKey, T existing, T value, AtomicInteger size) { + AtomicReferenceArray seg = segmentFor(relKey); + boolean success = seg.compareAndSet(relKey & OVERFLOW_SEGMENT_MASK, existing, value); + if (success && existing == null) { + size.incrementAndGet(); + } + return success; + } + + T remove(int relKey, AtomicInteger size) { + AtomicReferenceArray> spineSnap = spine; + int segIdx = relKey >>> OVERFLOW_SEGMENT_BITS; + if (segIdx >= spineSnap.length()) { + return null; + } + AtomicReferenceArray seg = spineSnap.get(segIdx); + if (seg == null) { + return null; + } + int slot = relKey & OVERFLOW_SEGMENT_MASK; + T old = seg.get(slot); + if (old == null) { + return null; + } + if (seg.compareAndSet(slot, old, null)) { size.decrementAndGet(); return old; - } else { - return null; } - } finally { - rwl.readLock().unlock(); + return null; } - } - @Override - public boolean containsKey(int key) { - return get(key) != null; - } + void forEach(int baseOffset, IntBiConsumer consumer) { + AtomicReferenceArray> spineSnap = spine; + for (int s = 0; s < spineSnap.length(); s++) { + AtomicReferenceArray seg = spineSnap.get(s); + if (seg == null) { + continue; + } + int segBase = s << OVERFLOW_SEGMENT_BITS; + for (int i = 0; i < OVERFLOW_SEGMENT_SIZE; i++) { + T value = seg.get(i); + if (value != null) { + consumer.consume(baseOffset + segBase + i, value); + } + } + } + } - @Override - public void forEach(IntBiConsumer consumer) { - var ref = objects; - for (int i = 0; i < ref.length(); i++) { - var value = get(i); - if (value != null) { - consumer.consume(i, value); + private AtomicReferenceArray segmentFor(int relKey) { + int segIdx = relKey >>> OVERFLOW_SEGMENT_BITS; + AtomicReferenceArray> spineSnap = spine; + if (segIdx < spineSnap.length()) { + AtomicReferenceArray seg = spineSnap.get(segIdx); + if (seg != null) { + return seg; + } + } + synchronized (spineLock) { + spineSnap = spine; + if (segIdx >= spineSnap.length()) { + int newLen = Math.max(spineSnap.length() * 2, segIdx + 1); + AtomicReferenceArray> next = new AtomicReferenceArray<>(newLen); + for (int i = 0; i < spineSnap.length(); i++) { + next.set(i, spineSnap.get(i)); + } + spineSnap = next; + spine = next; + } + AtomicReferenceArray seg = spineSnap.get(segIdx); + if (seg == null) { + seg = new AtomicReferenceArray<>(OVERFLOW_SEGMENT_SIZE); + spineSnap.set(segIdx, seg); + } + return seg; } } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java index 59b248584..c8f15ea87 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java @@ -127,6 +127,56 @@ public void testRescore(boolean addHierarchy) { } + @Test + public void testInitialCapacityHintProducesEquivalentGraph() { + int dimension = 8; + int size = 200; + var ravv = MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, getRandom())); + var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.COSINE); + + var builderDefault = new GraphIndexBuilder(bsp, dimension, 8, 10, 1.2f, 1.0f, true, false); + var graphDefault = TestUtil.buildSequentially(builderDefault, ravv); + + // Same build, but the caller passes an initialCapacity hint equal to the known vector count. + // This should produce an identical graph — the hint only changes internal sizing. + var builderHinted = new GraphIndexBuilder(bsp, + dimension, + java.util.List.of(8), + 10, + 1.2f, + 1.0f, + true, + false, + java.util.concurrent.ForkJoinPool.commonPool(), + java.util.concurrent.ForkJoinPool.commonPool(), + size); + var graphHinted = TestUtil.buildSequentially(builderHinted, ravv); + + assertEquals(graphDefault.size(0), graphHinted.size(0)); + for (int i = 0; i < ravv.size(); i++) { + assertTrue(graphDefault.containsNode(i)); + assertTrue(graphHinted.containsNode(i)); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testInitialCapacityHintRejectsNonPositive() { + int dimension = 4; + var ravv = MockVectorValues.fromValues(createRandomFloatVectors(8, dimension, getRandom())); + var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.COSINE); + new GraphIndexBuilder(bsp, + dimension, + java.util.List.of(4), + 10, + 1.0f, + 1.0f, + false, + false, + java.util.concurrent.ForkJoinPool.commonPool(), + java.util.concurrent.ForkJoinPool.commonPool(), + 0); + } + @Test public void testSaveAndLoad() throws IOException { int dimension = randomIntBetween(2, 32); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMapSegmented.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMapSegmented.java new file mode 100644 index 000000000..396623378 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestDenseIntMapSegmented.java @@ -0,0 +1,225 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package io.github.jbellis.jvector.util; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Extra coverage for the segmented DenseIntMap implementation, exercising: + * - cross-segment-boundary writes (keys around multiples of 1024) + * - concurrent inserts that force the spine to grow from a small initial capacity + * - concurrent inserts + removes on the same dense key range + * - forEach ascending-key iteration and visibility of prior writes + * + * Complements the existing {@link TestIntMap} tests which cover both Dense and Sparse + * implementations against a small key range. + */ +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class TestDenseIntMapSegmented extends RandomizedTest { + + @Test + public void testCrossSegmentBoundary() { + // SEGMENT_SIZE is 1024 internally; pick keys straddling the first few boundaries + // so a successful round trip proves routing and ascending iteration work across segments. + DenseIntMap map = new DenseIntMap<>(1); + int[] keys = {0, 1023, 1024, 2047, 2048, 4096, 10_000}; + for (int k : keys) { + Assert.assertTrue(map.compareAndPut(k, null, "v" + k)); + } + Assert.assertEquals(keys.length, map.size()); + for (int k : keys) { + Assert.assertEquals("v" + k, map.get(k)); + Assert.assertTrue(map.containsKey(k)); + } + + // forEach visits in ascending key order + List visited = new ArrayList<>(); + map.forEach((key, value) -> { + visited.add(key); + Assert.assertEquals("v" + key, value); + }); + Assert.assertEquals(keys.length, visited.size()); + for (int i = 1; i < visited.size(); i++) { + Assert.assertTrue("forEach must iterate ascending", visited.get(i) > visited.get(i - 1)); + } + } + + /** + * Under a tiny initial capacity, many concurrent writers across a wide key range force + * the spine to grow repeatedly. Every inserted key must be visible afterwards, and the + * {@code size()} counter must match the number of successful inserts. Regressions here + * would indicate a lost write during spine growth. + */ + @Test + public void testConcurrentInsertForcesSpineGrowth() throws InterruptedException { + DenseIntMap map = new DenseIntMap<>(1); // starts with a single segment + int nThreads = 16; + int perThread = 5_000; + int totalKeys = nThreads * perThread; + + CountDownLatch start = new CountDownLatch(1); + CountDownLatch done = new CountDownLatch(nThreads); + AtomicInteger successes = new AtomicInteger(); + + for (int t = 0; t < nThreads; t++) { + final int threadId = t; + new Thread(() -> { + try { + start.await(); + for (int i = 0; i < perThread; i++) { + int key = threadId * perThread + i; + if (map.compareAndPut(key, null, key)) { + successes.incrementAndGet(); + } + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } finally { + done.countDown(); + } + }, "inserter-" + t).start(); + } + start.countDown(); + done.await(); + + Assert.assertEquals(totalKeys, successes.get()); + Assert.assertEquals(totalKeys, map.size()); + for (int k = 0; k < totalKeys; k++) { + Assert.assertEquals("missing key " + k, Integer.valueOf(k), map.get(k)); + } + } + + /** + * Concurrent {@code compareAndPut(null, v)} races on the same key must give exactly one + * winner. A bug where two threads both observe the slot as null and both increment + * {@code size} would show up as an inflated size. + */ + @Test + public void testConcurrentInsertSameKey() throws InterruptedException { + for (int trial = 0; trial < 20; trial++) { + DenseIntMap map = new DenseIntMap<>(1); + int nThreads = 8; + int sameKey = 12_345; // far enough to require spine growth and segment install + CountDownLatch start = new CountDownLatch(1); + CountDownLatch done = new CountDownLatch(nThreads); + AtomicInteger winners = new AtomicInteger(); + + for (int t = 0; t < nThreads; t++) { + final int tid = t; + new Thread(() -> { + try { + start.await(); + if (map.compareAndPut(sameKey, null, "tid=" + tid)) { + winners.incrementAndGet(); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } finally { + done.countDown(); + } + }).start(); + } + start.countDown(); + done.await(); + + Assert.assertEquals("exactly one thread must win the initial CAS", 1, winners.get()); + Assert.assertEquals(1, map.size()); + Assert.assertNotNull(map.get(sameKey)); + } + } + + /** + * Concurrent insert + remove on disjoint key subsets: the remover handles its own keys + * after the inserter has populated them. Final state is deterministic: all inserted keys + * present, none of the removed-then-re-inserted keys missing. + */ + @Test + public void testConcurrentInsertAndRemove() throws InterruptedException { + DenseIntMap map = new DenseIntMap<>(1); + int keys = 20_000; + // First populate + for (int k = 0; k < keys; k++) { + Assert.assertTrue(map.compareAndPut(k, null, k)); + } + Assert.assertEquals(keys, map.size()); + + int nThreads = 8; + CountDownLatch start = new CountDownLatch(1); + CountDownLatch done = new CountDownLatch(nThreads); + for (int t = 0; t < nThreads; t++) { + final int tid = t; + new Thread(() -> { + try { + start.await(); + // Each thread touches its own slice: remove every 2nd, re-insert, remove again. + for (int k = tid; k < keys; k += nThreads) { + Integer removed = map.remove(k); + Assert.assertEquals(Integer.valueOf(k), removed); + Assert.assertTrue(map.compareAndPut(k, null, k)); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } finally { + done.countDown(); + } + }).start(); + } + start.countDown(); + done.await(); + + Assert.assertEquals(keys, map.size()); + for (int k = 0; k < keys; k++) { + Assert.assertEquals(Integer.valueOf(k), map.get(k)); + } + } + + /** + * The initialCapacity hint must actually widen the spine so that keys up to the hint + * are reachable without any spine grow. The contract is semantic (the map works at + * that size), not structural — we don't expose internal counters — but this test + * serves as a regression guard against the hint being silently ignored. + */ + @Test + public void testInitialCapacityHintSupportsInsertsUpToHint() { + int capacity = 4 * 1024 + 17; // more than one segment's worth, with a remainder + DenseIntMap map = new DenseIntMap<>(capacity); + for (int k = 0; k < capacity; k++) { + Assert.assertTrue("insert failed at k=" + k, map.compareAndPut(k, null, k)); + } + Assert.assertEquals(capacity, map.size()); + for (int k = 0; k < capacity; k++) { + Assert.assertEquals(Integer.valueOf(k), map.get(k)); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testRejectsZeroInitialCapacity() { + new DenseIntMap(0); + } + + @Test(expected = IllegalArgumentException.class) + public void testRejectsNegativeInitialCapacity() { + new DenseIntMap(-1); + } + + @Test(expected = IllegalArgumentException.class) + public void testRejectsNullValue() { + DenseIntMap map = new DenseIntMap<>(16); + map.compareAndPut(0, null, null); + } +}