diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacySparseIntMap.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacySparseIntMap.java new file mode 100644 index 000000000..a9e55e8ec --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/LegacySparseIntMap.java @@ -0,0 +1,77 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.util.IntMap; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.IntStream; + +/** + * Verbatim copy of the previous {@code SparseIntMap} implementation + * ({@link ConcurrentHashMap}{@code }). Kept in the benchmarks module so + * {@link SparseIntMapConcurrentBenchmark} can compare the new striped Agrona-backed impl + * against the old boxing one under identical conditions, without requiring a separate + * checkout of an older revision. + *

+ * Do not use in production code. Per jvector#3, this representation pays a boxed + * {@code Integer} per key and a {@code ConcurrentHashMap.Node} per entry — both eliminated + * by the current production {@code SparseIntMap}. + */ +public class LegacySparseIntMap implements IntMap { + private final ConcurrentHashMap map; + + public LegacySparseIntMap() { + this.map = new ConcurrentHashMap<>(); + } + + @Override + public boolean compareAndPut(int key, T existing, T value) { + if (value == null) { + throw new IllegalArgumentException("compareAndPut() value cannot be null -- use remove() instead"); + } + + if (existing == null) { + T result = map.putIfAbsent(key, value); + return result == null; + } + + return map.replace(key, existing, value); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public T get(int key) { + return map.get(key); + } + + @Override + public T remove(int key) { + return map.remove(key); + } + + @Override + public boolean containsKey(int key) { + return map.containsKey(key); + } + + public IntStream keysStream() { + return map.keySet().stream().mapToInt(key -> key); + } + + @Override + public void forEach(IntBiConsumer consumer) { + map.forEach(consumer::consume); + } +} diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/SparseIntMapConcurrentBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/SparseIntMapConcurrentBenchmark.java new file mode 100644 index 000000000..91ce21eea --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/SparseIntMapConcurrentBenchmark.java @@ -0,0 +1,202 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.util.IntMap; +import io.github.jbellis.jvector.util.SparseIntMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Group; +import org.openjdk.jmh.annotations.GroupThreads; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +/** + * Measures the throughput of concurrent {@link SparseIntMap} operations. The map sits on the + * hot path of {@code ConcurrentNeighborMap} for HNSW layers above the base — every + * {@code addNode}, {@code insertEdge}, {@code get}/{@code containsKey} during search, and + * {@code keysStream}/{@code forEachKey} during traversal goes through it. + *

+ * Parameters: + *

+ * Thread counts are expressed via {@code @Threads} on each benchmark method: 1 and 8. + */ +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +@Fork(value = 1, jvmArgsAppend = {"--enable-preview", "--add-modules=jdk.incubator.vector"}) +@Warmup(iterations = 3, time = 2) +@Measurement(iterations = 5, time = 3) +@State(Scope.Benchmark) +public class SparseIntMapConcurrentBenchmark { + + public enum Impl { + legacy(LegacySparseIntMap::new), + striped(SparseIntMap::new); + + final Supplier> factory; + + Impl(Supplier> factory) { + this.factory = factory; + } + } + + public enum KeyDensity { dense, sparse } + + @Param + public Impl impl; + + @Param + public KeyDensity keyDensity; + + @Param({"100000", "1000000"}) + public int totalKeys; + + /** Pre-populated map used by get/CAS/forEach benchmarks. */ + private IntMap prepopulated; + + /** Keys that exist in {@link #prepopulated} (random access benchmarks pick from these). */ + private int[] livekeys; + + /** Monotonic counter used by insert benchmarks so threads never collide on keys. */ + private final AtomicInteger insertCursor = new AtomicInteger(); + + /** The map written to by insert benchmarks. Replaced once exhausted. */ + private IntMap insertMap; + + @Setup + public void setup() { + this.prepopulated = impl.factory.get(); + this.livekeys = new int[totalKeys]; + java.util.Random rnd = new java.util.Random(0xCAFEBABEL); + for (int i = 0; i < totalKeys; i++) { + int k = (keyDensity == KeyDensity.dense) ? 
i : rnd.nextInt(totalKeys * 100); + livekeys[i] = k; + prepopulated.compareAndPut(k, null, i); + } + this.insertMap = impl.factory.get(); + this.insertCursor.set(0); + } + + @Benchmark + @Threads(8) + public Integer getHot8() { + int idx = ThreadLocalRandom.current().nextInt(livekeys.length); + return prepopulated.get(livekeys[idx]); + } + + @Benchmark + @Threads(1) + public Integer getHot1() { + int idx = ThreadLocalRandom.current().nextInt(livekeys.length); + return prepopulated.get(livekeys[idx]); + } + + @Benchmark + @Threads(8) + public boolean casChurn8(Blackhole bh) { + return doCasChurn(bh); + } + + @Benchmark + @Threads(1) + public boolean casChurn1(Blackhole bh) { + return doCasChurn(bh); + } + + private boolean doCasChurn(Blackhole bh) { + int idx = ThreadLocalRandom.current().nextInt(livekeys.length); + int k = livekeys[idx]; + Integer cur = prepopulated.get(k); + boolean ok = prepopulated.compareAndPut(k, cur, idx); + bh.consume(ok); + return ok; + } + + /** Models the upper-layer {@code addNode} pressure: many threads inserting disjoint keys. */ + @Benchmark + @Threads(8) + public boolean insertSequential8(Blackhole bh) { + return doInsert(bh); + } + + @Benchmark + @Threads(1) + public boolean insertSequential1(Blackhole bh) { + return doInsert(bh); + } + + private boolean doInsert(Blackhole bh) { + int key = insertCursor.getAndIncrement(); + if (key >= totalKeys) { + synchronized (this) { + if (insertCursor.get() >= totalKeys) { + insertMap = impl.factory.get(); + insertCursor.set(0); + } + } + key = insertCursor.getAndIncrement(); + } + boolean ok = insertMap.compareAndPut(key, null, key); + bh.consume(ok); + return ok; + } + + /** + * Iteration cost — measured single-threaded since the production callers + * ({@code OnHeapGraphIndex.nodeStream}) walk it from one thread at a time. 
+ */ + @Benchmark + @Threads(1) + public void forEachKey(Blackhole bh) { + prepopulated.forEachKey((int k) -> bh.consume(k)); + } + + /** 90 % reads + 10 % CAS-updates: closest to HNSW build's upper-layer steady state. */ + @Benchmark + @Group("mixed90r10w") + @GroupThreads(7) + public Integer mixedRead() { + int idx = ThreadLocalRandom.current().nextInt(livekeys.length); + return prepopulated.get(livekeys[idx]); + } + + @Benchmark + @Group("mixed90r10w") + @GroupThreads(1) + public boolean mixedWrite() { + int idx = ThreadLocalRandom.current().nextInt(livekeys.length); + int k = livekeys[idx]; + Integer cur = prepopulated.get(k); + return prepopulated.compareAndPut(k, cur, idx); + } +} diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/SparseIntMapMemoryBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/SparseIntMapMemoryBenchmark.java new file mode 100644 index 000000000..b7c4c88de --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/SparseIntMapMemoryBenchmark.java @@ -0,0 +1,143 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.util.IntMap; +import io.github.jbellis.jvector.util.SparseIntMap; +import org.openjdk.jmh.annotations.AuxCounters; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.lang.management.ManagementFactory; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +/** + * Measures the per-entry memory footprint of {@link SparseIntMap} versus the legacy + * {@link java.util.concurrent.ConcurrentHashMap}{@code }-backed implementation + * ({@link LegacySparseIntMap}). The point is to make the heap-savings claim from + * jvector#3 reproducible without launching a 100M-vector workload. + *

+ * Caveats: heap accounting via {@link ManagementFactory#getMemoryMXBean()} is intrinsically + * GC-noisy. We trigger {@link System#gc()} before/after the populate step and average over a + * handful of warmup + measurement iterations; even so, treat the absolute numbers as an + * upper bound and trust the {@code legacy / striped} ratio rather than any single run. + *

+ * Run with: {@code -bm ss} (single-shot) so each iteration starts from a fresh heap snapshot. + */ +@BenchmarkMode(Mode.SingleShotTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(value = 1, jvmArgsAppend = {"--enable-preview", "--add-modules=jdk.incubator.vector", + "-Xms2g", "-Xmx2g"}) +@Warmup(iterations = 2) +@Measurement(iterations = 5) +@State(Scope.Benchmark) +public class SparseIntMapMemoryBenchmark { + + public enum Impl { + legacy(LegacySparseIntMap::new), + striped(SparseIntMap::new); + + final Supplier> factory; + + Impl(Supplier> factory) { + this.factory = factory; + } + } + + @Param + public Impl impl; + + @Param({"100000", "1000000"}) + public int totalKeys; + + /** + * The map being measured. Held in a field so it isn't GC'd between the + * populate step and the {@code System.gc()} that captures used-bytes. + */ + private IntMap map; + + /** A shared dummy value held by every entry, so the per-entry value cost is constant. */ + private Object value; + + /** + * Carries the measured heap delta out of the benchmark as JMH secondary metrics. JMH's + * primary score for a {@link Mode#SingleShotTime} run is execution time; the byte numbers + * come through these counters in the standard "Secondary result" output. + */ + @State(Scope.Thread) + @AuxCounters(AuxCounters.Type.EVENTS) + public static class Counters { + public long bytesUsed; + public long bytesPerEntry; + + @Setup(Level.Iteration) + public void reset() { + bytesUsed = 0; + bytesPerEntry = 0; + } + } + + @Setup + public void setup() { + this.value = new Object(); + } + + @Benchmark + public long populateAndMeasure(Counters c, Blackhole bh) { + gcQuiesce(); + long before = usedHeap(); + + IntMap m = impl.factory.get(); + for (int i = 0; i < totalKeys; i++) { + m.compareAndPut(i, null, value); + } + // Pin the map across the GC. + this.map = m; + + gcQuiesce(); + long after = usedHeap(); + long delta = after - before; + c.bytesUsed = delta; + c.bytesPerEntry = m.size() == 0 ? 
0 : delta / m.size(); + bh.consume(m.size()); + + // Drop the strong reference so the next iteration starts from a clean slate. + this.map = null; + return delta; + } + + private static long usedHeap() { + return ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + } + + private static void gcQuiesce() { + for (int i = 0; i < 3; i++) { + System.gc(); + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + } + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index b3a45ff90..1ab73d19c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -252,7 +252,7 @@ public IntStream nodeStream(int level) { var layer = layers.get(level); return level == 0 ? IntStream.range(0, getIdUpperBound()).filter(i -> layer.get(i) != null) - : ((SparseIntMap) layer.neighbors).keysStream(); + : layer.neighbors.keysStream(); } @Override @@ -276,6 +276,27 @@ public long ramBytesUsedOneNode(int level) { return REF_BYTES + Neighbors.ramBytesUsed(layers.get(level).nodeArrayLength()); } + /** + * Estimate the per-node graph overhead (in bytes) for a layer with the given degree + * parameters, without requiring a built {@link OnHeapGraphIndex} instance. + *

+ * Intended for callers (e.g. an external indexing service) that need to predict graph + * memory cost when sizing a memory budget. Multiply by the expected node count for that + * layer to get a total estimate. + * + * @param maxDegree the maximum number of neighbors per node (M in HNSW) + * @param overflowRatio the multiplier on {@code maxDegree} that bounds the temporary + * neighborhood size during construction (matches the value passed to + * {@link io.github.jbellis.jvector.graph.GraphIndexBuilder}) + * @return the estimated per-node bytes (one slot reference + {@code Neighbors} object) + */ + public static long estimatedBytesPerNode(int maxDegree, float overflowRatio) { + int maxOverflowDegree = (int) (maxDegree * overflowRatio); + int nodeArrayLength = maxOverflowDegree + 1; // matches ConcurrentNeighborMap.nodeArrayLength() + int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + return REF_BYTES + Neighbors.ramBytesUsed(nodeArrayLength); + } + @Override public void enforceDegree(int node) { for (int level = 0; level <= getMaxLevel(); level++) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java index 6f1132538..8f9d5fd9c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/DenseIntMap.java @@ -17,6 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.function.IntConsumer; /** * A map (but not a Map) of int -> T where the int keys are dense-ish and start at zero, @@ -148,6 +149,19 @@ public void forEach(IntBiConsumer consumer) { } } + @Override + public void forEachKey(IntConsumer consumer) { + for (int i = 0; i < baseLen; i++) { + if (base.get(i) != null) { + consumer.accept(i); + } + } + Overflow o = overflow; + if (o != null) { + o.forEachKey(baseLen, 
consumer); + } + } + private Overflow overflowForWrite() { Overflow o = overflow; if (o != null) { @@ -236,6 +250,22 @@ void forEach(int baseOffset, IntBiConsumer consumer) { } } + void forEachKey(int baseOffset, IntConsumer consumer) { + AtomicReferenceArray> spineSnap = spine; + for (int s = 0; s < spineSnap.length(); s++) { + AtomicReferenceArray seg = spineSnap.get(s); + if (seg == null) { + continue; + } + int segBase = s << OVERFLOW_SEGMENT_BITS; + for (int i = 0; i < OVERFLOW_SEGMENT_SIZE; i++) { + if (seg.get(i) != null) { + consumer.accept(baseOffset + segBase + i); + } + } + } + } + private AtomicReferenceArray segmentFor(int relKey) { int segIdx = relKey >>> OVERFLOW_SEGMENT_BITS; AtomicReferenceArray> spineSnap = spine; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java index 713e9a3ab..5315cc93a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/IntMap.java @@ -16,8 +16,7 @@ package io.github.jbellis.jvector.util; -import io.github.jbellis.jvector.graph.NodesIterator; - +import java.util.function.IntConsumer; import java.util.stream.IntStream; public interface IntMap { @@ -49,10 +48,37 @@ public interface IntMap { boolean containsKey(int key); /** - * Iterates keys in ascending order and calls the consumer for each non-null key-value pair. + * Iterates over each non-null key/value pair and invokes {@code consumer}. + *

+ * Iteration order is implementation-defined; only {@link DenseIntMap} guarantees ascending + * keys. Iteration is weakly consistent: entries may be added or removed concurrently and the + * traversal will reflect the state at some point during the call. */ void forEach(IntBiConsumer consumer); + /** + * Iterates over each key currently present in the map and invokes {@code consumer} with + * the primitive int. Implementations should override to avoid boxing; the default delegates + * to {@link #forEach(IntBiConsumer)} and discards the value. + *

+ * Iteration order and consistency follow {@link #forEach(IntBiConsumer)}. + */ + default void forEachKey(IntConsumer consumer) { + forEach((k, v) -> consumer.accept(k)); + } + + /** + * Returns a primitive {@link IntStream} of every key currently present in the map. + *

+ * The default builds the stream from {@link #forEachKey(IntConsumer)}. Specialised + * implementations may override for efficiency. No boxing occurs in either path. + */ + default IntStream keysStream() { + IntStream.Builder b = IntStream.builder(); + forEachKey(b::add); + return b.build(); + } + @FunctionalInterface interface IntBiConsumer { void consume(int key, T2 value); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java index a8fc555e5..699d15928 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/SparseIntMap.java @@ -16,16 +16,71 @@ package io.github.jbellis.jvector.util; -import io.github.jbellis.jvector.graph.NodesIterator; +import org.agrona.collections.Int2ObjectHashMap; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.StampedLock; +import java.util.function.IntConsumer; import java.util.stream.IntStream; +/** + * A concurrent {@link IntMap} backed by a striped array of Agrona {@link Int2ObjectHashMap} + * shards. + *

+ * Compared to a {@link java.util.concurrent.ConcurrentHashMap}{@code }, this avoids + * boxing the {@code int} key (no {@code Integer} per entry) and avoids the per-entry + * {@code Node}/{@code TreeNode} overhead — both of which dominate the heap when storing tens of + * millions of entries (see jvector#3). + *

+ * Concurrency. Each shard is guarded by its own {@link StampedLock}. Reads + * ({@link #get}, {@link #containsKey}) attempt a lock-free optimistic read first, falling back + * to a read lock only if a writer interleaves; this matches the throughput of + * {@link java.util.concurrent.ConcurrentHashMap}'s volatile-read fast path on uncontended + * reads. Writes ({@link #compareAndPut}, {@link #remove}) take the shard write lock and are + * serialised per shard. {@link #size()} is O(1) via a global {@link AtomicInteger}. + *

+ * Iteration. {@link #forEach}, {@link #forEachKey} and {@link #keysStream} snapshot each + * shard under its read lock into a primitive {@code int[]} (no boxing) plus an {@code Object[]} + * of values, then release the lock before invoking the consumer. This (a) avoids deadlock if + * the consumer re-enters the map and (b) preserves the weakly-consistent semantics that + * callers had with the previous {@code ConcurrentHashMap}-backed implementation. There is no + * global atomic snapshot — entries added or removed during iteration may or may not be visible. + *

+ * Footprint note. This map is intended for HNSW upper layers (typically thousands to + * millions of entries). The 32-shard structure has a fixed ~6 KB idle footprint, dwarfing the + * cost of an empty map; do not use it for tiny maps. + */ public class SparseIntMap implements IntMap { - private final ConcurrentHashMap map; + /** Number of shards; must be a power of two. */ + static final int SHARD_COUNT = 32; + private static final int SHARD_MASK = SHARD_COUNT - 1; + + private final Int2ObjectHashMap[] shards; + private final StampedLock[] locks; + private final AtomicInteger size = new AtomicInteger(); + @SuppressWarnings("unchecked") public SparseIntMap() { - this.map = new ConcurrentHashMap<>(); + this.shards = (Int2ObjectHashMap[]) new Int2ObjectHashMap[SHARD_COUNT]; + this.locks = new StampedLock[SHARD_COUNT]; + for (int i = 0; i < SHARD_COUNT; i++) { + this.shards[i] = new Int2ObjectHashMap<>(); + this.locks[i] = new StampedLock(); + } + } + + /** + * Avalanche-mix the key before sharding. Agrona uses identity hashing internally, so a raw + * monotonically-increasing key (typical in HNSW node IDs) would pile onto a single shard + * without mixing. 
+ */ + static int shardIndex(int key) { + int h = key; + h ^= (h >>> 16); + h *= 0x85EBCA6B; + h ^= (h >>> 13); + return h & SHARD_MASK; } @Override @@ -34,40 +89,177 @@ public boolean compareAndPut(int key, T existing, T value) { throw new IllegalArgumentException("compareAndPut() value cannot be null -- use remove() instead"); } - if (existing == null) { - T result = map.putIfAbsent(key, value); - return result == null; + int idx = shardIndex(key); + StampedLock lock = locks[idx]; + long stamp = lock.writeLock(); + try { + Int2ObjectHashMap shard = shards[idx]; + T cur = shard.get(key); + if (existing == null) { + if (cur != null) { + return false; + } + shard.put(key, value); + size.incrementAndGet(); + return true; + } + // Reference equality matches CHM.replace(k, expected, new) for our value types, + // which do not override equals(). + if (cur != existing) { + return false; + } + shard.put(key, value); + return true; + } finally { + lock.unlockWrite(stamp); } - - return map.replace(key, existing, value); } @Override public int size() { - return map.size(); + return size.get(); } @Override public T get(int key) { - return map.get(key); + int idx = shardIndex(key); + StampedLock lock = locks[idx]; + Int2ObjectHashMap shard = shards[idx]; + + // Optimistic read first — Agrona's open-addressing get may transiently observe a + // resize-in-progress and return null or throw, so always validate and fall back to a + // pessimistic read when the optimistic snapshot couldn't be confirmed. 
+ long stamp = lock.tryOptimisticRead(); + if (stamp != 0) { + T value; + try { + value = shard.get(key); + } catch (Throwable ignored) { + value = null; + } + if (lock.validate(stamp)) { + return value; + } + } + stamp = lock.readLock(); + try { + return shard.get(key); + } finally { + lock.unlockRead(stamp); + } } @Override public T remove(int key) { - return map.remove(key); + int idx = shardIndex(key); + StampedLock lock = locks[idx]; + long stamp = lock.writeLock(); + try { + T old = shards[idx].remove(key); + if (old != null) { + size.decrementAndGet(); + } + return old; + } finally { + lock.unlockWrite(stamp); + } } @Override public boolean containsKey(int key) { - return map.containsKey(key); + // Cheap and correct: get() already does the optimistic-read dance. + return get(key) != null; } - public IntStream keysStream() { - return map.keySet().stream().mapToInt(key -> key); + @Override + public void forEach(IntBiConsumer consumer) { + for (int s = 0; s < SHARD_COUNT; s++) { + int[] keys; + Object[] values; + StampedLock lock = locks[s]; + long stamp = lock.readLock(); + try { + Int2ObjectHashMap shard = shards[s]; + int n = shard.size(); + int[] kBuf = new int[n]; + Object[] vBuf = new Object[n]; + int[] pos = {0}; + shard.forEachInt((k, v) -> { + int p = pos[0]; + if (p < kBuf.length) { + kBuf[p] = k; + vBuf[p] = v; + pos[0] = p + 1; + } + }); + int filled = pos[0]; + if (filled == kBuf.length) { + keys = kBuf; + values = vBuf; + } else { + keys = Arrays.copyOf(kBuf, filled); + values = Arrays.copyOf(vBuf, filled); + } + } finally { + lock.unlockRead(stamp); + } + for (int i = 0; i < keys.length; i++) { + @SuppressWarnings("unchecked") + T v = (T) values[i]; + if (v != null) { + consumer.consume(keys[i], v); + } + } + } } @Override - public void forEach(IntBiConsumer consumer) { - map.forEach(consumer::consume); + public void forEachKey(IntConsumer consumer) { + for (int s = 0; s < SHARD_COUNT; s++) { + int[] keys = snapshotKeys(s); + for (int k : keys) { 
+ consumer.accept(k); + } + } + } + + @Override + public IntStream keysStream() { + int total = size.get(); + // Allocate slack to absorb concurrent inserts; growth handled per-shard below. + int[] all = new int[Math.max(total + (total >> 3) + 16, 16)]; + int filled = 0; + for (int s = 0; s < SHARD_COUNT; s++) { + int[] keys = snapshotKeys(s); + if (filled + keys.length > all.length) { + int newLen = Math.max(all.length * 2, filled + keys.length); + all = Arrays.copyOf(all, newLen); + } + System.arraycopy(keys, 0, all, filled, keys.length); + filled += keys.length; + } + return Arrays.stream(all, 0, filled); + } + + private int[] snapshotKeys(int shardIdx) { + StampedLock lock = locks[shardIdx]; + long stamp = lock.readLock(); + try { + Int2ObjectHashMap shard = shards[shardIdx]; + int n = shard.size(); + int[] keys = new int[n]; + int[] pos = {0}; + shard.forEachInt((k, v) -> { + int p = pos[0]; + if (p < keys.length) { + keys[p] = k; + pos[0] = p + 1; + } + }); + int filled = pos[0]; + return filled == keys.length ? keys : Arrays.copyOf(keys, filled); + } finally { + lock.unlockRead(stamp); + } } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestOnHeapGraphIndexEstimator.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestOnHeapGraphIndexEstimator.java new file mode 100644 index 000000000..eaf760b88 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestOnHeapGraphIndexEstimator.java @@ -0,0 +1,118 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.github.jbellis.jvector.graph; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.util.ArrayList; + +/** + * Validates {@link OnHeapGraphIndex#estimatedBytesPerNode(int, float)} — the static helper + * external services use to predict graph overhead before building. The estimator must agree + * with the formula the running graph uses ({@link OnHeapGraphIndex#ramBytesUsedOneNode(int)}). + */ +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class TestOnHeapGraphIndexEstimator extends RandomizedTest { + + private static final VectorTypeSupport VTS = VectorizationProvider.getInstance().getVectorTypeSupport(); + + @Test + public void testEstimatorMatchesInstanceMethod() throws IOException { + int M = 16; + float overflowRatio = 1.2f; + OnHeapGraphIndex graph = buildSmallGraph(M, overflowRatio, /* nodes */ 200, /* dim */ 16, /* hierarchy */ true); + try { + long estimate = OnHeapGraphIndex.estimatedBytesPerNode(M, overflowRatio); + long actual = graph.ramBytesUsedOneNode(0); + Assert.assertEquals("static estimator must equal instance per-node cost for level 0", + actual, estimate); + } finally { + graph.close(); + } + } + + @Test + public void testEstimatorScalesAcrossDegrees() { + // Higher M → higher per-node cost, monotonically. 
+ long prev = 0; + for (int m : new int[]{4, 8, 16, 32, 64}) { + long est = OnHeapGraphIndex.estimatedBytesPerNode(m, 1.2f); + Assert.assertTrue("per-node cost must grow with M (m=" + m + ", est=" + est + ", prev=" + prev + ")", + est > prev); + prev = est; + } + } + + @Test + public void testEstimatorScalesWithOverflow() { + long lo = OnHeapGraphIndex.estimatedBytesPerNode(16, 1.0f); + long mid = OnHeapGraphIndex.estimatedBytesPerNode(16, 1.5f); + long hi = OnHeapGraphIndex.estimatedBytesPerNode(16, 2.0f); + Assert.assertTrue(lo < mid); + Assert.assertTrue(mid < hi); + } + + @Test + public void testEstimatorMatchesActualPerNodeAtScale() throws IOException { + // Build a small graph with a hierarchy. The dense base layer's + // ramBytesUsedOneLayer(0) / size(0) (after subtracting fixed-cost per-layer + // overhead) must agree with the static estimator within ~5% — they share the + // same Neighbors.ramBytesUsed formula. + int M = 16; + float overflowRatio = 1.2f; + OnHeapGraphIndex graph = buildSmallGraph(M, overflowRatio, 250, 16, true); + try { + long estimate = OnHeapGraphIndex.estimatedBytesPerNode(M, overflowRatio); + int sz = graph.size(0); + Assert.assertTrue("graph too small to evaluate", sz >= 50); + // ramBytesUsedOneNode is a public instance method — it returns exactly the + // same per-node figure as the static estimator does for the same parameters. 
+ long actualPerNode = graph.ramBytesUsedOneNode(0); + double ratio = (double) estimate / actualPerNode; + Assert.assertTrue("estimate=" + estimate + " actual=" + actualPerNode + + " ratio=" + ratio, ratio > 0.95 && ratio < 1.05); + } finally { + graph.close(); + } + } + + private OnHeapGraphIndex buildSmallGraph(int M, float overflowRatio, int n, int dim, boolean addHierarchy) throws IOException { + ArrayList> vectors = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + VectorFloat v = VTS.createFloatVector(dim); + for (int d = 0; d < dim; d++) { + v.set(d, (float) Math.random()); + } + vectors.add(v); + } + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dim); + BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider( + ravv, VectorSimilarityFunction.EUCLIDEAN); + try (GraphIndexBuilder builder = new GraphIndexBuilder( + bsp, dim, M, /* beamWidth */ 100, overflowRatio, /* alpha */ 1.2f, addHierarchy)) { + return (OnHeapGraphIndex) builder.build(ravv); + } + } +} diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMap.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMap.java index ac4695bbc..5618a6b74 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMap.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMap.java @@ -28,8 +28,12 @@ import org.junit.Assert; import org.junit.Test; +import java.util.HashSet; +import java.util.Set; +import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; +import java.util.function.Supplier; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class TestIntMap extends RandomizedTest { @@ -91,6 +95,106 @@ private void testRemoveInternal(IntMap map) { } } + @Test + public void testForEach() { + for (Supplier> factory : factories()) { + IntMap map = factory.get(); + int[] keys = {0, 7, 99, 1024, 12345}; + for (int k : keys) { + 
map.compareAndPut(k, null, "v" + k); + } + Set seen = new HashSet<>(); + map.forEach((k, v) -> { + Assert.assertEquals("v" + k, v); + Assert.assertTrue("duplicate key " + k, seen.add(k)); + }); + Set expected = new HashSet<>(); + for (int k : keys) expected.add(k); + Assert.assertEquals(expected, seen); + } + } + + @Test + public void testForEachKey() { + for (Supplier> factory : factories()) { + IntMap map = factory.get(); + int[] keys = {0, 7, 99, 1024, 12345, 99999}; + for (int k : keys) { + map.compareAndPut(k, null, "v" + k); + } + Set seen = new HashSet<>(); + map.forEachKey(k -> Assert.assertTrue("duplicate key " + k, seen.add(k))); + Set expected = new HashSet<>(); + for (int k : keys) expected.add(k); + Assert.assertEquals(expected, seen); + } + } + + @Test + public void testKeysStream() { + for (Supplier> factory : factories()) { + IntMap map = factory.get(); + int[] keys = {0, 7, 99, 1024, 12345, 99999}; + for (int k : keys) { + map.compareAndPut(k, null, "v" + k); + } + int[] streamed = map.keysStream().sorted().toArray(); + int[] expected = keys.clone(); + java.util.Arrays.sort(expected); + Assert.assertArrayEquals(expected, streamed); + } + } + + @Test + public void testForEachKeyEmpty() { + for (Supplier> factory : factories()) { + IntMap map = factory.get(); + map.forEachKey(k -> Assert.fail("empty map should not visit any keys")); + Assert.assertEquals(0, map.keysStream().count()); + } + } + + @Test + public void testCompareAndPutContract() { + for (Supplier> factory : factories()) { + IntMap map = factory.get(); + + // (null, v) on absent → success + Assert.assertTrue(map.compareAndPut(42, null, "v1")); + Assert.assertEquals("v1", map.get(42)); + + // (null, v) on present → failure + Assert.assertFalse(map.compareAndPut(42, null, "v2")); + Assert.assertEquals("v1", map.get(42)); + + // (mismatched, v) → failure (different reference even if equals) + String wrong = new String("v1"); + Assert.assertNotSame("v1", wrong); + 
Assert.assertFalse(map.compareAndPut(42, wrong, "v3")); + Assert.assertEquals("v1", map.get(42)); + + // (matching, v') → success + Assert.assertTrue(map.compareAndPut(42, map.get(42), "v3")); + Assert.assertEquals("v3", map.get(42)); + + // (_, null) → IllegalArgumentException + try { + map.compareAndPut(42, map.get(42), null); + Assert.fail("expected IllegalArgumentException for null value"); + } catch (IllegalArgumentException expected) { + // OK + } + } + } + + @SuppressWarnings("unchecked") + private Supplier>[] factories() { + return new Supplier[]{ + () -> new DenseIntMap(100), + () -> new SparseIntMap() + }; + } + @Test public void testConcurrency() throws InterruptedException { for (int i = 0; i < 100; i++) { diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMapConcurrency.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMapConcurrency.java new file mode 100644 index 000000000..0e1b5d154 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestIntMapConcurrency.java @@ -0,0 +1,449 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */
package io.github.jbellis.jvector.util;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

/**
 * Deep concurrency tests for {@link IntMap} implementations. The shipped {@code TestIntMap}
 * exercises the interface contract on a single thread plus a coarse "do random ops vs a CHM
 * source-of-truth" check; this file targets the contract that {@code ConcurrentNeighborMap}'s
 * CAS retry loops actually rely on (linearizability of compareAndPut on the same key,
 * happens-before across successful writes, weakly-consistent iteration, no re-entrant
 * deadlock, accurate size accounting under contention).
 */
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class TestIntMapConcurrency extends RandomizedTest {

    // Upper bound for every Future.get() in this file; generous enough that a pass is
    // never timing-sensitive, small enough that a deadlock fails the test quickly.
    private static final int TEST_TIMEOUT_MS = 5_000;

    /**
     * Both production {@link IntMap} implementations, freshly constructed per test so no
     * state leaks between the per-factory iterations.
     */
    @SuppressWarnings("unchecked")
    private static Supplier<IntMap<Object>>[] factories() {
        return new Supplier[]{
                () -> new DenseIntMap<>(1024),
                () -> new SparseIntMap<>()
        };
    }

    /**
     * N threads compete on a single key, each looping
     * {@code compareAndPut(key, get(key), newRef)} and counting successful CASes. The number
     * of successful CASes must be at least the number of distinct replacement values any
     * thread ever observed as the current value (no lost updates — every observed value was
     * published by exactly one successful CAS), and the final value must be one of the
     * values written.
 */
    @Test
    public void testCompareAndPutLinearizability() throws Exception {
        for (Supplier<IntMap<Object>> factory : factories()) {
            int nThreads = 8;
            int opsPerThread = 5_000;
            IntMap<Object> map = factory.get();
            int key = 12345;
            // Seed so that compareAndPut(k, observed, v) only ever runs with non-null expected
            Object seed = new Object();
            Assert.assertTrue(map.compareAndPut(key, null, seed));

            ExecutorService pool = Executors.newFixedThreadPool(nThreads);
            try {
                CountDownLatch start = new CountDownLatch(1);
                AtomicInteger totalSuccesses = new AtomicInteger();
                // Every value reference that was ever observed in, or published to, the map.
                Set<Object> witnessed = ConcurrentHashMap.newKeySet();
                witnessed.add(seed);
                Future<?>[] futures = new Future[nThreads];
                for (int t = 0; t < nThreads; t++) {
                    futures[t] = pool.submit(() -> {
                        try {
                            start.await();
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return;
                        }
                        int local = 0;
                        for (int i = 0; i < opsPerThread; i++) {
                            Object expected = map.get(key);
                            witnessed.add(expected);
                            Object next = new Object();
                            if (map.compareAndPut(key, expected, next)) {
                                local++;
                                witnessed.add(next);
                            }
                        }
                        totalSuccesses.addAndGet(local);
                    });
                }
                start.countDown();
                for (Future<?> f : futures) {
                    f.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                }

                // Every successful CAS published a brand-new Object that was either witnessed by
                // a later thread (counted in `witnessed`) or remains as the final value.
                Object finalVal = map.get(key);
                Assert.assertNotNull(finalVal);
                Assert.assertTrue("final value must have been written by some thread",
                        witnessed.contains(finalVal));
                // Each successful CAS is the *only* operation that ever publishes a fresh
                // value, so the witnessed-set size minus 1 (seed) is a lower bound on total
                // successes (a writer may also see its own write before another thread does).
                int writes = totalSuccesses.get();
                int witnessedWrites = witnessed.size() - 1; // exclude seed
                Assert.assertTrue(
                        "writes=" + writes + " must be >= witnessedWrites=" + witnessedWrites,
                        writes >= witnessedWrites);
                Assert.assertEquals("size must remain 1", 1, map.size());
            } finally {
                pool.shutdown();
            }
        }
    }

    /**
     * One thread iterates {@code forEachKey} while writers insert and remove. Asserts no
     * exception is ever thrown (the snapshot-based iteration must tolerate concurrent
     * mutation) and that emitted keys were live at some point during the iteration.
     */
    @Test
    public void testForEachKeyDuringMutation() throws Exception {
        for (Supplier<IntMap<Object>> factory : factories()) {
            IntMap<Object> map = factory.get();
            // Pre-populate keys [0, 200); writers mutate the wider range [0, 400).
            for (int i = 0; i < 200; i++) {
                map.compareAndPut(i, null, new Object());
            }

            int nWriters = 4;
            ExecutorService pool = Executors.newFixedThreadPool(nWriters + 1);
            AtomicBoolean stop = new AtomicBoolean();
            // First failure from any worker thread; main thread rethrows it at the end.
            AtomicReference<Throwable> error = new AtomicReference<>();
            try {
                Future<?>[] writerFutures = new Future[nWriters];
                for (int w = 0; w < nWriters; w++) {
                    final int seed = w;
                    writerFutures[w] = pool.submit(() -> {
                        java.util.Random rnd = new java.util.Random(seed);
                        while (!stop.get()) {
                            try {
                                int k = rnd.nextInt(400);
                                if (rnd.nextBoolean()) {
                                    // null expected → insert; non-null → replace. Races with
                                    // other writers are expected; a false return is fine.
                                    map.compareAndPut(k, map.get(k), new Object());
                                } else {
                                    map.remove(k);
                                }
                            } catch (Throwable t) {
                                error.compareAndSet(null, t);
                                // Precise rethrow: the try body throws only unchecked
                                // exceptions, so this compiles inside a Runnable lambda.
                                throw t;
                            }
                        }
                    });
                }

                Future<?> reader = pool.submit(() -> {
                    long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(2);
                    try {
                        while (System.nanoTime() < deadline && !stop.get()) {
                            int[] count = {0};
                            map.forEachKey(k -> {
                                // Only keys the writers could have inserted may appear.
                                Assert.assertTrue(k >= 0);
                                Assert.assertTrue(k < 400);
                                count[0]++;
                            });
                            // The size-vs-iteration count is weakly consistent. Just assert
                            // we got a sane bound.
                            Assert.assertTrue(count[0] >= 0);
                            Assert.assertTrue(count[0] <= 400);
                        }
                    } catch (Throwable t) {
                        error.compareAndSet(null, t);
                        throw t;
                    }
                });

                reader.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                stop.set(true);
                for (Future<?> f : writerFutures) {
                    f.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                }
            } finally {
                stop.set(true); // also stops writers if reader.get() threw
                pool.shutdown();
            }
            if (error.get() != null) {
                throw new AssertionError("concurrent mutation broke iteration", error.get());
            }
        }
    }

    /**
     * Mutating the map from inside the {@code forEachKey} consumer must not deadlock. The
     * snapshot semantics in the striped {@code SparseIntMap} were chosen specifically to
     * avoid this case (we never hold a shard lock across the consumer callback).
     */
    @Test
    public void testForEachKeyReentrant() throws Exception {
        for (Supplier<IntMap<Object>> factory : factories()) {
            IntMap<Object> map = factory.get();
            for (int i = 0; i < 64; i++) {
                map.compareAndPut(i, null, new Object());
            }

            // Run the iteration on a worker thread so a deadlock surfaces as a
            // TimeoutException from Future.get() rather than hanging the test JVM.
            ExecutorService pool = Executors.newSingleThreadExecutor();
            try {
                Future<?> f = pool.submit(() -> {
                    map.forEachKey(k -> {
                        // Re-enter: read and write the map from inside the callback.
                        Object cur = map.get(k);
                        if (cur != null) {
                            map.compareAndPut(k, cur, new Object());
                        }
                        // Insert something brand-new keyed past the snapshot range.
                        map.compareAndPut(k + 10_000, null, new Object());
                    });
                });
                f.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
            } finally {
                pool.shutdown();
            }
        }
    }

    /**
     * Insert N keys concurrently across N threads, then remove a known subset across N
     * threads; assert {@code size()} matches the live count exactly.
 */
    @Test
    public void testSizeUnderConcurrentInsertRemove() throws Exception {
        for (Supplier<IntMap<Object>> factory : factories()) {
            int nThreads = 8;
            int keysPerThread = 1_000;
            int totalKeys = nThreads * keysPerThread;
            IntMap<Object> map = factory.get();

            ExecutorService pool = Executors.newFixedThreadPool(nThreads);
            try {
                CountDownLatch start = new CountDownLatch(1);

                // Phase 1: each thread inserts its own non-overlapping range, so every
                // compareAndPut(key, null, v) must succeed exactly once.
                Future<?>[] inserters = new Future[nThreads];
                for (int t = 0; t < nThreads; t++) {
                    final int base = t * keysPerThread;
                    inserters[t] = pool.submit(() -> {
                        try {
                            start.await();
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return;
                        }
                        for (int i = 0; i < keysPerThread; i++) {
                            Assert.assertTrue(map.compareAndPut(base + i, null, new Object()));
                        }
                    });
                }
                start.countDown();
                for (Future<?> f : inserters) {
                    f.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                }
                Assert.assertEquals(totalKeys, map.size());

                // Phase 2: half of every range is removed in parallel; each remove targets a
                // key known to be present, so the return value must be non-null.
                int halfPerThread = keysPerThread / 2;
                Future<?>[] removers = new Future[nThreads];
                CountDownLatch removeStart = new CountDownLatch(1);
                for (int t = 0; t < nThreads; t++) {
                    final int base = t * keysPerThread;
                    removers[t] = pool.submit(() -> {
                        try {
                            removeStart.await();
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return;
                        }
                        for (int i = 0; i < halfPerThread; i++) {
                            Assert.assertNotNull(map.remove(base + i));
                        }
                    });
                }
                removeStart.countDown();
                for (Future<?> f : removers) {
                    f.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                }

                int expectedLive = totalKeys - nThreads * halfPerThread;
                Assert.assertEquals(expectedLive, map.size());

                // Cross-check: walking the map yields the same count as size().
                AtomicInteger walked = new AtomicInteger();
                map.forEachKey(k -> walked.incrementAndGet());
                Assert.assertEquals(expectedLive, walked.get());
            } finally {
                pool.shutdown();
            }
        }
    }

    /** A successful compareAndPut must establish happens-before for the published value. */
    @Test
    public void testHappensBeforeOnSuccessfulCAS() throws Exception {
        for (Supplier<IntMap<Object>> factory : factories()) {
            IntMap<Object> map = factory.get();
            int key = 7;
            int rounds = 1_000;

            ExecutorService pool = Executors.newFixedThreadPool(2);
            try {
                for (int round = 0; round < rounds; round++) {
                    // Reset on the main thread; submit() below establishes ordering with
                    // the workers, so both see the key absent at round start.
                    map.remove(key);
                    CountDownLatch start = new CountDownLatch(1);
                    AtomicReference<Throwable> err = new AtomicReference<>();
                    Future<?> writer = pool.submit(() -> {
                        try {
                            start.await();
                            Payload p = new Payload();
                            // Plain (non-volatile) write; visibility must come from the
                            // map's publishing compareAndPut, whatever its mechanism.
                            p.field = 99;
                            map.compareAndPut(key, null, p);
                        } catch (Throwable t) {
                            err.compareAndSet(null, t);
                        }
                    });
                    Future<?> reader = pool.submit(() -> {
                        try {
                            start.await();
                            // Spin until the writer's value becomes visible.
                            Object o;
                            while ((o = map.get(key)) == null) {
                                Thread.onSpinWait();
                            }
                            Payload p = (Payload) o;
                            if (p.field != 99) {
                                err.compareAndSet(null,
                                        new AssertionError("payload.field=" + p.field
                                                + " — happens-before violated"));
                            }
                        } catch (Throwable t) {
                            err.compareAndSet(null, t);
                        }
                    });
                    start.countDown();
                    writer.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                    reader.get(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                    if (err.get() != null) {
                        throw new AssertionError("round " + round, err.get());
                    }
                }
            } finally {
                pool.shutdown();
            }
        }
    }

    /** Carrier object for the happens-before test. */
    private static class Payload {
        // intentionally non-volatile — visibility must come from the IntMap CAS, not from
        // the field declaration.
        int field;
    }

    /** Stale expected → no-op false return; current value unchanged.
 */
    @Test
    public void testCompareAndPutWithStaleExpected() {
        for (Supplier<IntMap<Object>> factory : factories()) {
            IntMap<Object> map = factory.get();
            int k = 3;
            Object v0 = new Object();
            Assert.assertTrue(map.compareAndPut(k, null, v0));

            Object v1 = new Object();
            Assert.assertTrue(map.compareAndPut(k, v0, v1));

            // Try to overwrite using the stale v0 reference — must fail and leave v1 in place.
            Object v2 = new Object();
            Assert.assertFalse(map.compareAndPut(k, v0, v2));
            Assert.assertSame(v1, map.get(k));
        }
    }

    /** High-volume mixed workload; assert internal consistency. */
    @Test
    public void testStressManyKeys() throws Exception {
        for (Supplier<IntMap<Object>> factory : factories()) {
            int nThreads = 16;
            int opsPerThread = 25_000;
            int keySpace = 5_000;
            IntMap<Object> map = factory.get();

            ExecutorService pool = Executors.newFixedThreadPool(nThreads);
            try {
                CountDownLatch start = new CountDownLatch(1);
                Future<?>[] futures = new Future[nThreads];
                for (int t = 0; t < nThreads; t++) {
                    final int seed = t;
                    futures[t] = pool.submit(() -> {
                        try {
                            start.await();
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            return;
                        }
                        java.util.Random rnd = new java.util.Random(seed);
                        for (int i = 0; i < opsPerThread; i++) {
                            int k = rnd.nextInt(keySpace);
                            // Op mix: 60% upsert, 20% remove, 10% get, 10% containsKey.
                            int op = rnd.nextInt(10);
                            if (op < 6) {
                                // get-then-CAS may race with a concurrent remove; a false
                                // return is acceptable here — we only check invariants below.
                                map.compareAndPut(k, map.get(k), new Object());
                            } else if (op < 8) {
                                map.remove(k);
                            } else if (op < 9) {
                                map.get(k);
                            } else {
                                map.containsKey(k);
                            }
                        }
                    });
                }
                start.countDown();
                for (Future<?> f : futures) {
                    // Double timeout: this is the heaviest test in the file.
                    f.get(2 * TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                }

                // Final check: size() agrees with walking the map.
                AtomicInteger walked = new AtomicInteger();
                Set<Integer> seen = new HashSet<>();
                map.forEachKey(k -> {
                    Assert.assertTrue("dup key in walk: " + k, seen.add(k));
                    walked.incrementAndGet();
                });
                Assert.assertEquals(map.size(), walked.get());
            } finally {
                pool.shutdown();
            }
        }
    }
}
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestSparseIntMapShards.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestSparseIntMapShards.java
new file mode 100644
index 000000000..adfb0b492
--- /dev/null
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/util/TestSparseIntMapShards.java
/*
 * Copyright DataStax, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.github.jbellis.jvector.util;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import org.junit.Assert;
import org.junit.Test;

/**
 * White-box tests for {@link SparseIntMap}'s shard distribution. Sequential keys (typical of
 * HNSW node IDs assigned by {@code addGraphNode}) are the worst case for Agrona's identity
 * hashing — without the avalanche mix in {@link SparseIntMap#shardIndex(int)} they would all
 * pile onto one shard.
+ */ +public class TestSparseIntMapShards extends RandomizedTest { + + @Test + public void testSequentialKeysSpreadAcrossShards() { + int n = 100_000; + int[] perShard = new int[SparseIntMap.SHARD_COUNT]; + for (int i = 0; i < n; i++) { + perShard[SparseIntMap.shardIndex(i)]++; + } + assertWellDistributed(perShard, n); + } + + @Test + public void testGappyKeysSpreadAcrossShards() { + // HNSW upper layers have sparse, gappy node IDs. + int n = 50_000; + int[] perShard = new int[SparseIntMap.SHARD_COUNT]; + java.util.Random rnd = new java.util.Random(0xCAFEBABEL); + for (int i = 0; i < n; i++) { + perShard[SparseIntMap.shardIndex(rnd.nextInt())]++; + } + assertWellDistributed(perShard, n); + } + + private static void assertWellDistributed(int[] perShard, int total) { + int min = Integer.MAX_VALUE; + int max = Integer.MIN_VALUE; + for (int c : perShard) { + if (c < min) min = c; + if (c > max) max = c; + } + double avg = (double) total / SparseIntMap.SHARD_COUNT; + // Every shard should hold at least 80% of the average and at most 120%. + Assert.assertTrue("min shard " + min + " too small (avg=" + avg + ")", + min >= 0.80 * avg); + Assert.assertTrue("max shard " + max + " too large (avg=" + avg + ")", + max <= 1.20 * avg); + } + + @Test + public void testShardIndexBoundedAndStable() { + for (int k : new int[]{0, 1, -1, Integer.MAX_VALUE, Integer.MIN_VALUE, 12345, 0x55555555}) { + int s = SparseIntMap.shardIndex(k); + Assert.assertTrue("shard for " + k + " out of range: " + s, + s >= 0 && s < SparseIntMap.SHARD_COUNT); + Assert.assertEquals("shardIndex must be deterministic", s, SparseIntMap.shardIndex(k)); + } + } +}