diff --git a/graalpython/com.oracle.graal.python.test/src/com/oracle/graal/python/test/objects/ObjectHashMapTests.java b/graalpython/com.oracle.graal.python.test/src/com/oracle/graal/python/test/objects/ObjectHashMapTests.java index 510d275c03..d209b42c56 100644 --- a/graalpython/com.oracle.graal.python.test/src/com/oracle/graal/python/test/objects/ObjectHashMapTests.java +++ b/graalpython/com.oracle.graal.python.test/src/com/oracle/graal/python/test/objects/ObjectHashMapTests.java @@ -74,9 +74,6 @@ import com.oracle.truffle.api.frame.Frame; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.nodes.Node; -import com.oracle.truffle.api.profiles.InlinedBranchProfile; -import com.oracle.truffle.api.profiles.InlinedConditionProfile; -import com.oracle.truffle.api.profiles.InlinedCountingConditionProfile; public class ObjectHashMapTests { public static final class DictKey implements TruffleObject { @@ -220,8 +217,7 @@ private static void popValues(ObjectHashMap map, LinkedHashMap exp var keys = expected.keySet().stream().toList().reversed().stream().limit(count).toArray(Long[]::new); for (int i = 0; i < keys.length; i++) { Long key = keys[i]; - Object[] popped = PopNode.doPopWithRestart(null, map, InlinedConditionProfile.getUncached(), InlinedCountingConditionProfile.getUncached(), InlinedCountingConditionProfile.getUncached(), - InlinedBranchProfile.getUncached()); + Object[] popped = PopNode.doPopWithRestartForTests(map); Assert.assertEquals(Integer.toString(i), key, popped[0]); Assert.assertEquals(Integer.toString(i), expected.get(key), popped[1]); expected.remove(key); @@ -376,25 +372,14 @@ private static long getKeyHash(Object key) { } private static Object get(ObjectHashMap map, Object key, long hash) { - InlinedCountingConditionProfile uncachedCounting = InlinedCountingConditionProfile.getUncached(); - return ObjectHashMap.GetNode.doGetWithRestart(null, null, map, key, hash, - InlinedBranchProfile.getUncached(), uncachedCounting, uncachedCounting, uncachedCounting, - uncachedCounting, uncachedCounting, - new EqNodeStub()); + return ObjectHashMap.GetNode.doGetWithRestartForTests(map, key, hash, new EqNodeStub()); } private static void remove(ObjectHashMap map, Object key, long hash) { - InlinedCountingConditionProfile uncachedCounting = InlinedCountingConditionProfile.getUncached(); - ObjectHashMap.RemoveNode.doRemoveWithRestart(null, null, map, key, hash, - InlinedBranchProfile.getUncached(), uncachedCounting, uncachedCounting, uncachedCounting, - uncachedCounting, InlinedBranchProfile.getUncached(), new EqNodeStub()); + ObjectHashMap.RemoveNode.doRemoveWithRestartForTests(map, key, hash, new EqNodeStub()); } private static void put(ObjectHashMap map, Object key, long hash, Object value) { - InlinedCountingConditionProfile uncachedCounting = InlinedCountingConditionProfile.getUncached(); - PutNode.doPutWithRestart(null, null, map, key, hash, value, - InlinedBranchProfile.getUncached(), uncachedCounting, uncachedCounting, uncachedCounting, - uncachedCounting, InlinedBranchProfile.getUncached(), InlinedBranchProfile.getUncached(), - new EqNodeStub()); + PutNode.doPutWithRestartForTests(map, key, hash, value, new EqNodeStub()); } } diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/ObjectHashMap.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/ObjectHashMap.java index 4f37786bc4..182211f628 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/ObjectHashMap.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/ObjectHashMap.java @@ -50,10 +50,12 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.HostCompilerDirectives.InliningCutoff; import com.oracle.truffle.api.TruffleSafepoint; +import com.oracle.truffle.api.dsl.Bind; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.GenerateCached; import com.oracle.truffle.api.dsl.GenerateInline; import com.oracle.truffle.api.dsl.GenerateUncached; +import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.Frame; import com.oracle.truffle.api.nodes.LoopNode; @@ -121,6 +123,8 @@ public class ObjectHashMap extends HashingStorage { private static final int TIGHT_ENTRY_CAPACITY_LIMIT = 8; + static final int INDEX_BYTE_SIZE_CACHE_LIMIT = 3; + /** * Every hash map will preallocate at least this many buckets (and corresponding # of slots for * the real items). @@ -139,6 +143,9 @@ public class ObjectHashMap extends HashingStorage { private static final int COLLISION_MASK = 1 << 31; private static final int BYTE_COLLISION_MASK = 1 << 7; private static final int SHORT_COLLISION_MASK = 1 << 15; + private static final int BYTE_COLLISION_SHIFT = 24; + private static final int SHORT_COLLISION_SHIFT = 16; + private static final int INT_COLLISION_SHIFT = 0; /** * Sparse table indices use 0 and 1 as reserved markers and store real compact-array indices as @@ -150,12 +157,6 @@ public class ObjectHashMap extends HashingStorage { private static final int MAX_BYTE_INDEX = (BYTE_COLLISION_MASK - 1) - INDEX_OFFSET; private static final int MAX_SHORT_INDEX = (SHORT_COLLISION_MASK - 1) - INDEX_OFFSET; - private static void markCollision(byte[] metadata, int entryCapacity, int compactIndex) { - int indexByteSize = getIndexByteSize(entryCapacity); - int indicesOffset = getIndicesOffset(entryCapacity); - markCollision(metadata, indicesOffset, indexByteSize, getPhysicalCollisionMaskForIndexByteSize(indexByteSize), compactIndex); - } - private static void markCollision(byte[] metadata, int indicesOffset, int indexByteSize, int physicalCollisionMask, int compactIndex) { int index = getIndex(metadata, indicesOffset, indexByteSize, compactIndex); assert index != EMPTY_INDEX; @@ -298,23 +299,19 @@ public Object getValue() { } } - private int getEntryCapacity() { + int getEntryCapacity() { return keysAndValues.length >> 1; } - private int getBucketsCount() { - return getBucketsCount(metadata, getEntryCapacity()); - } - private static byte[] createMetadata(int bucketsCount, int usableSize) { return new byte[getMetadataLength(bucketsCount, usableSize)]; } - private static int getBucketsCount(byte[] metadata, int entryCapacity) { - return (metadata.length - getIndicesOffset(entryCapacity)) / getIndexByteSize(entryCapacity); + private static int getBucketsCount(byte[] metadata, int entryCapacity, int indexByteSize) { + return (metadata.length - getIndicesOffset(entryCapacity)) / indexByteSize; } - private static int getIndexByteSize(int entryCapacity) { + static int getIndexByteSize(int entryCapacity) { if (entryCapacity - 1 <= MAX_BYTE_INDEX) { return Byte.BYTES; } else if (entryCapacity - 1 <= MAX_SHORT_INDEX) { @@ -324,6 +321,11 @@ private static int getIndexByteSize(int entryCapacity) { } } + @TruffleBoundary + private static int getIndexByteSizeAfterRestart(int entryCapacity) { + return getIndexByteSize(entryCapacity); + } + private static int getIndicesOffset(int entryCapacity) { return castMetadataInt(getIndicesOffsetLong(entryCapacity)); } @@ -336,10 +338,6 @@ private static int getHashOffset(int index) { return castMetadataInt((long) index * Long.BYTES); } - private static int getIndexOffset(int entryCapacity, int compactIndex) { - return castMetadataInt((long) getIndicesOffset(entryCapacity) + ((long) compactIndex * getIndexByteSize(entryCapacity))); - } - private static int getIndexOffset(int indicesOffset, int indexByteSize, int compactIndex) { return indicesOffset + compactIndex * indexByteSize; } @@ -375,35 +373,20 @@ private static int getPhysicalCollisionMaskForIndexByteSize(int indexByteSize) { } } - private static int getPhysicalCollisionMask(int entryCapacity) { - return getPhysicalCollisionMaskForIndexByteSize(getIndexByteSize(entryCapacity)); - } - - private static int getIndex(byte[] metadata, int entryCapacity, int compactIndex) { - return getIndex(metadata, getIndicesOffset(entryCapacity), getIndexByteSize(entryCapacity), compactIndex); - } - private static int getIndex(byte[] metadata, int indicesOffset, int indexByteSize, int compactIndex) { int offset = getIndexOffset(indicesOffset, indexByteSize, compactIndex); if (indexByteSize == Byte.BYTES) { - return decodeIndex(metadata[offset] & 0xFF, BYTE_COLLISION_MASK); + return decodeIndex(metadata[offset] & 0xFF, BYTE_COLLISION_MASK, BYTE_COLLISION_SHIFT); } else if (indexByteSize == Short.BYTES) { - return decodeIndex(PythonUtils.ARRAY_ACCESSOR.getShort(metadata, offset) & 0xFFFF, SHORT_COLLISION_MASK); + return decodeIndex(PythonUtils.ARRAY_ACCESSOR.getShort(metadata, offset) & 0xFFFF, SHORT_COLLISION_MASK, SHORT_COLLISION_SHIFT); } else { - return decodeIndex(PythonUtils.ARRAY_ACCESSOR.getInt(metadata, offset), COLLISION_MASK); + return decodeIndex(PythonUtils.ARRAY_ACCESSOR.getInt(metadata, offset), COLLISION_MASK, INT_COLLISION_SHIFT); } } - private static int decodeIndex(int encodedValue, int physicalCollisionMask) { + private static int decodeIndex(int encodedValue, int physicalCollisionMask, int collisionShift) { int value = encodedValue & (physicalCollisionMask - 1); - if ((encodedValue & physicalCollisionMask) != 0) { - value |= COLLISION_MASK; - } - return value; - } - - private static void setIndex(byte[] metadata, int entryCapacity, int compactIndex, int logicalValue) { - setIndex(metadata, getIndicesOffset(entryCapacity), getIndexByteSize(entryCapacity), getPhysicalCollisionMask(entryCapacity), compactIndex, logicalValue); + return value | ((encodedValue & physicalCollisionMask) << collisionShift); } private static void setIndex(byte[] metadata, int indicesOffset, int indexByteSize, int physicalCollisionMask, int compactIndex, int logicalValue) { @@ -433,11 +416,10 @@ long getHash(int index) { return getHash(metadata, index); } - private boolean needsResize(byte[] localMetadata) { + private boolean needsResize(int entryCapacity, int bucketsCount) { // Keep one slot empty at all times. For the smallest table, that means resizing once 2 of // the 4 buckets are already in use instead of allowing the table to become completely full. - int bucketsCount = getBucketsCount(localMetadata, getEntryCapacity()); - return usedHashes >= getEntryCapacity() || usedIndices >= getUsableSize(bucketsCount); + return usedHashes >= entryCapacity || usedIndices >= getUsableSize(bucketsCount); } public int size() { @@ -447,20 +429,33 @@ public int size() { @GenerateUncached @GenerateInline @GenerateCached(false) + @ImportStatic(ObjectHashMap.class) public abstract static class PopNode extends Node { public abstract Object[] execute(Node inliningTarget, ObjectHashMap map); - @Specialization - public static Object[] doPopWithRestart(Node inliningTarget, ObjectHashMap map, + @TruffleBoundary + public static Object[] doPopWithRestartForTests(ObjectHashMap map) { + // Public entry point for Java tests that bypass the generated node. + int entryCapacity = map.getEntryCapacity(); + return doPopWithRestart(null, map, entryCapacity, getIndexByteSize(entryCapacity), InlinedConditionProfile.getUncached(), + InlinedCountingConditionProfile.getUncached(), InlinedCountingConditionProfile.getUncached(), InlinedBranchProfile.getUncached()); + } + + @Specialization(guards = "indexByteSize == getIndexByteSize(entryCapacity)", limit = "INDEX_BYTE_SIZE_CACHE_LIMIT") + static Object[] doPopWithRestart(Node inliningTarget, ObjectHashMap map, + @Bind("map.getEntryCapacity()") int entryCapacity, + @Cached(value = "getIndexByteSize(entryCapacity)", allowUncached = true) int indexByteSize, @Cached InlinedConditionProfile emptyMapProfile, @Cached InlinedCountingConditionProfile hasValueProfile, @Cached InlinedCountingConditionProfile hasCollisionProfile, @Cached InlinedBranchProfile lookupRestart) { while (true) { try { - return doPop(inliningTarget, map, map.metadata, emptyMapProfile, hasValueProfile, hasCollisionProfile); + return doPop(inliningTarget, map, map.metadata, entryCapacity, indexByteSize, emptyMapProfile, hasValueProfile, hasCollisionProfile); } catch (RestartLookupException ignore) { lookupRestart.enter(inliningTarget); + entryCapacity = map.getEntryCapacity(); + indexByteSize = getIndexByteSizeAfterRestart(entryCapacity); } } } @@ -469,7 +464,7 @@ private static boolean isIndex(int indexInIndices, int indexToFind) { return indexInIndices != DUMMY_INDEX && indexInIndices != EMPTY_INDEX && indexToFind == unwrapIndex(indexInIndices); } - private static Object[] doPop(Node inliningTarget, ObjectHashMap map, byte[] metadata, + private static Object[] doPop(Node inliningTarget, ObjectHashMap map, byte[] metadata, int entryCapacity, int indexByteSize, @Cached InlinedConditionProfile emptyMapProfile, @Cached InlinedCountingConditionProfile hasValueProfile, @Cached InlinedCountingConditionProfile hasCollisionProfile) throws RestartLookupException { @@ -477,9 +472,7 @@ private static Object[] doPop(Node inliningTarget, ObjectHashMap map, byte[] met return null; } Object[] localKeysAndValues = map.keysAndValues; - int entryCapacity = map.getEntryCapacity(); - int indicesLen = getBucketsCount(metadata, entryCapacity); - int indexByteSize = getIndexByteSize(entryCapacity); + int indicesLen = getBucketsCount(metadata, entryCapacity, indexByteSize); int indicesOffset = getIndicesOffset(entryCapacity); int physicalCollisionMask = getPhysicalCollisionMaskForIndexByteSize(indexByteSize); int usedHashes = map.usedHashes; @@ -538,12 +531,23 @@ private static void removeBucketWithIndex(ObjectHashMap map, byte[] metadata, in @GenerateUncached @GenerateInline @GenerateCached(false) + @ImportStatic(ObjectHashMap.class) public abstract static class GetNode extends Node { public abstract Object execute(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash); - // "public" for testing... - @Specialization - public static Object doGetWithRestart(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, + @TruffleBoundary + public static Object doGetWithRestartForTests(ObjectHashMap map, Object key, long keyHash, PyObjectRichCompareBool eqNode) { + // Public entry point for Java tests that bypass the generated node. + int entryCapacity = map.getEntryCapacity(); + InlinedCountingConditionProfile uncachedCounting = InlinedCountingConditionProfile.getUncached(); + return doGetWithRestart(null, null, map, key, keyHash, entryCapacity, getIndexByteSize(entryCapacity), InlinedBranchProfile.getUncached(), uncachedCounting, + uncachedCounting, uncachedCounting, uncachedCounting, uncachedCounting, eqNode); + } + + @Specialization(guards = "indexByteSize == getIndexByteSize(entryCapacity)", limit = "INDEX_BYTE_SIZE_CACHE_LIMIT") + static Object doGetWithRestart(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, + @Bind("map.getEntryCapacity()") int entryCapacity, + @Cached(value = "getIndexByteSize(entryCapacity)", allowUncached = true) int indexByteSize, @Cached InlinedBranchProfile lookupRestart, @Cached InlinedCountingConditionProfile foundNullKey, @Cached InlinedCountingConditionProfile foundSameHashKey, @@ -553,16 +557,18 @@ public static Object doGetWithRestart(Frame frame, Node inliningTarget, ObjectHa @Cached PyObjectRichCompareBool eqNode) { while (true) { try { - return doGet(frame, map, key, keyHash, inliningTarget, foundNullKey, foundSameHashKey, + return doGet(frame, map, key, keyHash, indexByteSize, entryCapacity, inliningTarget, foundNullKey, foundSameHashKey, foundEqKey, collisionFoundNoValue, collisionFoundEqKey, eqNode); } catch (RestartLookupException ignore) { lookupRestart.enter(inliningTarget); TruffleSafepoint.poll(inliningTarget); + entryCapacity = map.getEntryCapacity(); + indexByteSize = getIndexByteSizeAfterRestart(entryCapacity); } } } - static Object doGet(Frame frame, ObjectHashMap map, Object key, long keyHash, + static Object doGet(Frame frame, ObjectHashMap map, Object key, long keyHash, int indexByteSize, int entryCapacity, Node inliningTarget, InlinedCountingConditionProfile foundNullKey, InlinedCountingConditionProfile foundSameHashKey, @@ -570,11 +576,9 @@ static Object doGet(Frame frame, ObjectHashMap map, Object key, long keyHash, InlinedCountingConditionProfile collisionFoundNoValue, InlinedCountingConditionProfile collisionFoundEqKey, PyObjectRichCompareBool eqNode) throws RestartLookupException { - assert map.checkInternalState(); + assert map.checkInternalState(entryCapacity, indexByteSize); byte[] metadata = map.metadata; - int entryCapacity = map.getEntryCapacity(); - int indicesLen = getBucketsCount(metadata, entryCapacity); - int indexByteSize = getIndexByteSize(entryCapacity); + int indicesLen = getBucketsCount(metadata, entryCapacity, indexByteSize); int indicesOffset = getIndicesOffset(entryCapacity); int compactIndex = getIndex(indicesLen, keyHash); @@ -644,6 +648,7 @@ private static Object getCollision(Frame frame, ObjectHashMap map, Object key, l @GenerateUncached @GenerateInline @GenerateCached(false) + @ImportStatic(ObjectHashMap.class) public abstract static class PutNode extends Node { public static PutNode getUncached() { return PutNodeGen.getUncached(); @@ -663,9 +668,19 @@ public static void putUncached(ObjectHashMap map, Object key, long keyHash, Obje abstract void execute(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, Object value); - // "public" for testing... - @Specialization - public static void doPutWithRestart(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, Object value, + @TruffleBoundary + public static void doPutWithRestartForTests(ObjectHashMap map, Object key, long keyHash, Object value, PyObjectRichCompareBool eqNode) { + // Public entry point for Java tests that bypass the generated node. + int entryCapacity = map.getEntryCapacity(); + InlinedCountingConditionProfile uncachedCounting = InlinedCountingConditionProfile.getUncached(); + doPutWithRestart(null, null, map, key, keyHash, value, entryCapacity, getIndexByteSize(entryCapacity), InlinedBranchProfile.getUncached(), uncachedCounting, + uncachedCounting, uncachedCounting, uncachedCounting, InlinedBranchProfile.getUncached(), InlinedBranchProfile.getUncached(), eqNode); + } + + @Specialization(guards = "indexByteSize == getIndexByteSize(entryCapacity)", limit = "INDEX_BYTE_SIZE_CACHE_LIMIT") + static void doPutWithRestart(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, Object value, + @Bind("map.getEntryCapacity()") int entryCapacity, + @Cached(value = "getIndexByteSize(entryCapacity)", allowUncached = true) int indexByteSize, @Cached InlinedBranchProfile lookupRestart, @Cached InlinedCountingConditionProfile foundNullKey, @Cached InlinedCountingConditionProfile foundEqKey, @@ -676,18 +691,20 @@ public static void doPutWithRestart(Frame frame, Node inliningTarget, ObjectHash @Cached PyObjectRichCompareBool eqNode) { while (true) { try { - doPut(frame, map, key, keyHash, value, inliningTarget, foundNullKey, foundEqKey, + doPut(frame, map, key, keyHash, value, entryCapacity, indexByteSize, inliningTarget, foundNullKey, foundEqKey, collisionFoundNoValue, collisionFoundEqKey, rehash1Profile, rehash2Profile, eqNode); return; } catch (RestartLookupException ignore) { lookupRestart.enter(inliningTarget); TruffleSafepoint.poll(inliningTarget); + entryCapacity = map.getEntryCapacity(); + indexByteSize = getIndexByteSizeAfterRestart(entryCapacity); } } } - static void doPut(Frame frame, ObjectHashMap map, Object key, long keyHash, Object value, + static void doPut(Frame frame, ObjectHashMap map, Object key, long keyHash, Object value, int entryCapacity, int indexByteSize, Node inliningTarget, InlinedCountingConditionProfile foundNullKey, InlinedCountingConditionProfile foundEqKey, @@ -696,18 +713,16 @@ static void doPut(Frame frame, ObjectHashMap map, Object key, long keyHash, Obje InlinedBranchProfile rehash1Profile, InlinedBranchProfile rehash2Profile, PyObjectRichCompareBool eqNode) throws RestartLookupException { - assert map.checkInternalState(); + assert map.checkInternalState(entryCapacity, indexByteSize); byte[] metadata = map.metadata; - int entryCapacity = map.getEntryCapacity(); - int indicesLen = getBucketsCount(metadata, entryCapacity); - int indexByteSize = getIndexByteSize(entryCapacity); + int indicesLen = getBucketsCount(metadata, entryCapacity, indexByteSize); int indicesOffset = getIndicesOffset(entryCapacity); int physicalCollisionMask = getPhysicalCollisionMaskForIndexByteSize(indexByteSize); int compactIndex = getIndex(indicesLen, keyHash); int index = getIndex(metadata, indicesOffset, indexByteSize, compactIndex); if (foundNullKey.profile(inliningTarget, index == EMPTY_INDEX)) { - map.putInNewSlot(metadata, entryCapacity, inliningTarget, rehash1Profile, key, keyHash, value, compactIndex); + map.putInNewSlot(metadata, indicesOffset, indexByteSize, physicalCollisionMask, entryCapacity, indicesLen, inliningTarget, rehash1Profile, key, keyHash, value, compactIndex); return; } @@ -741,7 +756,7 @@ private static void putCollision(Frame frame, ObjectHashMap map, Object key, lon compactIndex = nextIndex(indicesLen, compactIndex, perturb); int index = getIndex(metadata, indicesOffset, indexByteSize, compactIndex); if (collisionFoundNoValue.profile(inliningTarget, index == EMPTY_INDEX)) { - map.putInNewSlot(metadata, entryCapacity, inliningTarget, rehash2Profile, key, keyHash, value, compactIndex); + map.putInNewSlot(metadata, indicesOffset, indexByteSize, physicalCollisionMask, entryCapacity, indicesLen, inliningTarget, rehash2Profile, key, keyHash, value, compactIndex); return; } if (collisionFoundEqKey.profile(inliningTarget, index != DUMMY_INDEX && map.keysEqual(metadata, frame, inliningTarget, unwrapIndex(index), key, keyHash, eqNode))) { @@ -766,17 +781,12 @@ private static void putCollision(Frame frame, ObjectHashMap map, Object key, lon // Internal helper: it is not profiling, never rehashes, and it assumes that the hash map never // contains the key that we are inserting - private void insertNewKey(byte[] localMetadata, Object key, long keyHash, Object value) { + private void insertNewKey(byte[] localMetadata, int indicesLen, int indicesOffset, int indexByteSize, int physicalCollisionMask, Object key, long keyHash, Object value) { assert localMetadata == this.metadata; - int entryCapacity = getEntryCapacity(); - int indicesLen = getBucketsCount(localMetadata, entryCapacity); - int indexByteSize = getIndexByteSize(entryCapacity); - int indicesOffset = getIndicesOffset(entryCapacity); - int physicalCollisionMask = getPhysicalCollisionMaskForIndexByteSize(indexByteSize); int compactIndex = getIndex(indicesLen, keyHash); int index = getIndex(localMetadata, indicesOffset, indexByteSize, compactIndex); if (index == EMPTY_INDEX) { - putInNewSlot(localMetadata, entryCapacity, key, keyHash, value, compactIndex); + putInNewSlot(localMetadata, indicesOffset, indexByteSize, physicalCollisionMask, key, keyHash, value, compactIndex); return; } @@ -789,7 +799,7 @@ private void insertNewKey(byte[] localMetadata, Object key, long keyHash, Object compactIndex = nextIndex(indicesLen, compactIndex, perturb); index = getIndex(localMetadata, indicesOffset, indexByteSize, compactIndex); if (index == EMPTY_INDEX) { - putInNewSlot(localMetadata, entryCapacity, key, keyHash, value, compactIndex); + putInNewSlot(localMetadata, indicesOffset, indexByteSize, physicalCollisionMask, key, keyHash, value, compactIndex); return; } markCollision(localMetadata, indicesOffset, indexByteSize, physicalCollisionMask, compactIndex); @@ -800,30 +810,31 @@ private void insertNewKey(byte[] localMetadata, Object key, long keyHash, Object throw CompilerDirectives.shouldNotReachHere(); } - private void putInNewSlot(byte[] localMetadata, int entryCapacity, Node inliningTarget, InlinedBranchProfile rehashProfile, Object key, long keyHash, Object value, int compactIndex) { + private void putInNewSlot(byte[] localMetadata, int indicesOffset, int indexByteSize, int physicalCollisionMask, int entryCapacity, int bucketsCount, Node inliningTarget, + InlinedBranchProfile rehashProfile, Object key, long keyHash, Object value, int compactIndex) { assert metadata == localMetadata; assert entryCapacity == getEntryCapacity(); - if (CompilerDirectives.injectBranchProbability(SLOWPATH_PROBABILITY, needsResize(localMetadata))) { + if (CompilerDirectives.injectBranchProbability(SLOWPATH_PROBABILITY, needsResize(entryCapacity, bucketsCount))) { rehashProfile.enter(inliningTarget); rehashAndPut(key, keyHash, value); return; } - putInNewSlot(localMetadata, entryCapacity, key, keyHash, value, compactIndex); + putInNewSlot(localMetadata, indicesOffset, indexByteSize, physicalCollisionMask, key, keyHash, value, compactIndex); } - private void putInNewSlot(byte[] localMetadata, int entryCapacity, Object key, long keyHash, Object value, int compactIndex) { + private void putInNewSlot(byte[] localMetadata, int indicesOffset, int indexByteSize, int physicalCollisionMask, Object key, long keyHash, Object value, int compactIndex) { size++; usedIndices++; int newIndex = usedHashes++; - setIndex(localMetadata, entryCapacity, compactIndex, newIndex + INDEX_OFFSET); + setIndex(localMetadata, indicesOffset, indexByteSize, physicalCollisionMask, compactIndex, newIndex + INDEX_OFFSET); setValue(newIndex, value); setKey(newIndex, key); setHash(localMetadata, newIndex, keyHash); } - private boolean needsCompaction() { + private boolean needsCompaction(int entryCapacity) { // if more than quarter of all the slots are occupied by dummy values -> compact - int quarterOfUsable = getEntryCapacity() >> 2; + int quarterOfUsable = entryCapacity >> 2; int dummyCnt = usedHashes - size; return dummyCnt > quarterOfUsable; } @@ -831,6 +842,7 @@ private boolean needsCompaction() { @GenerateUncached @GenerateInline @GenerateCached(false) + @ImportStatic(ObjectHashMap.class) public abstract static class RemoveNode extends Node { public static Object removeUncached(ObjectHashMap map, Object key, long keyHash) { return RemoveNodeGen.getUncached().execute(null, null, map, key, keyHash); @@ -838,9 +850,19 @@ public static Object removeUncached(ObjectHashMap map, Object key, long keyHash) public abstract Object execute(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash); - // "public" for testing... - @Specialization - public static Object doRemoveWithRestart(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, + @TruffleBoundary + public static Object doRemoveWithRestartForTests(ObjectHashMap map, Object key, long keyHash, PyObjectRichCompareBool eqNode) { + // Public entry point for Java tests that bypass the generated node. + int entryCapacity = map.getEntryCapacity(); + InlinedCountingConditionProfile uncachedCounting = InlinedCountingConditionProfile.getUncached(); + return doRemoveWithRestart(null, null, map, key, keyHash, entryCapacity, getIndexByteSize(entryCapacity), InlinedBranchProfile.getUncached(), uncachedCounting, + uncachedCounting, uncachedCounting, uncachedCounting, InlinedBranchProfile.getUncached(), eqNode); + } + + @Specialization(guards = "indexByteSize == getIndexByteSize(entryCapacity)", limit = "INDEX_BYTE_SIZE_CACHE_LIMIT") + static Object doRemoveWithRestart(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, + @Bind("map.getEntryCapacity()") int entryCapacity, + @Cached(value = "getIndexByteSize(entryCapacity)", allowUncached = true) int indexByteSize, @Cached InlinedBranchProfile lookupRestart, @Cached InlinedCountingConditionProfile foundNullKey, @Cached InlinedCountingConditionProfile foundEqKey, @@ -850,32 +872,32 @@ public static Object doRemoveWithRestart(Frame frame, Node inliningTarget, Objec @Cached PyObjectRichCompareBool eqNode) { while (true) { try { - return doRemove(frame, inliningTarget, map, key, keyHash, foundNullKey, foundEqKey, + return doRemove(frame, inliningTarget, map, key, keyHash, entryCapacity, indexByteSize, foundNullKey, foundEqKey, collisionFoundNoValue, collisionFoundEqKey, compactProfile, eqNode); } catch (RestartLookupException ignore) { lookupRestart.enter(inliningTarget); + entryCapacity = map.getEntryCapacity(); + indexByteSize = getIndexByteSizeAfterRestart(entryCapacity); } } } - static Object doRemove(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, + static Object doRemove(Frame frame, Node inliningTarget, ObjectHashMap map, Object key, long keyHash, int entryCapacity, int indexByteSize, InlinedCountingConditionProfile foundNullKey, InlinedCountingConditionProfile foundEqKey, InlinedCountingConditionProfile collisionFoundNoValue, InlinedCountingConditionProfile collisionFoundEqKey, InlinedBranchProfile compactProfile, PyObjectRichCompareBool eqNode) throws RestartLookupException { - assert map.checkInternalState(); + assert map.checkInternalState(entryCapacity, indexByteSize); // TODO: move this to the point after we find the value to remove? - if (CompilerDirectives.injectBranchProbability(SLOWPATH_PROBABILITY, map.needsCompaction())) { + if (CompilerDirectives.injectBranchProbability(SLOWPATH_PROBABILITY, map.needsCompaction(entryCapacity))) { compactProfile.enter(inliningTarget); map.compact(); } byte[] metadata = map.metadata; - int entryCapacity = map.getEntryCapacity(); - int indicesLen = getBucketsCount(metadata, entryCapacity); - int indexByteSize = getIndexByteSize(entryCapacity); + int indicesLen = getBucketsCount(metadata, entryCapacity, indexByteSize); int indicesOffset = getIndicesOffset(entryCapacity); int physicalCollisionMask = getPhysicalCollisionMaskForIndexByteSize(indexByteSize); @@ -1003,15 +1025,20 @@ private void rehashAndPut(Object newKey, long newKeyHash, Object newValue) { usedHashes = 0; usedIndices = 0; byte[] localMetadata = this.metadata; + int entryCapacity = getEntryCapacity(); + int indexByteSize = getIndexByteSize(entryCapacity); + int indicesLen = getBucketsCount(localMetadata, entryCapacity, indexByteSize); + int indicesOffset = getIndicesOffset(entryCapacity); + int physicalCollisionMask = getPhysicalCollisionMaskForIndexByteSize(indexByteSize); for (int i = 0; i < oldUsedSize; i++) { if (getValue(i, oldKeysAndValues) != null) { final Object key = getKey(i, oldKeysAndValues); - insertNewKey(localMetadata, key, getHash(oldMetadata, i), getValue(i, oldKeysAndValues)); + insertNewKey(localMetadata, indicesLen, indicesOffset, indexByteSize, physicalCollisionMask, key, getHash(oldMetadata, i), getValue(i, oldKeysAndValues)); } } assert size == oldSize : String.format("size=%d, oldSize=%d, oldUsedSize=%d, usedHashes=%d, usedIndices=%d", size, oldSize, oldUsedSize, usedHashes, usedIndices); - insertNewKey(localMetadata, newKey, newKeyHash, newValue); + insertNewKey(localMetadata, indicesLen, indicesOffset, indexByteSize, physicalCollisionMask, newKey, newKeyHash, newValue); } private static int getRequestedEntryCapacity(int requestedCapacity, int bucketsCount) { @@ -1046,16 +1073,20 @@ private void compact() { usedHashes -= dummyCount; // We've "removed" the dummy entries byte[] localMetadata = metadata; int entryCapacity = getEntryCapacity(); - int localIndicesLength = getBucketsCount(localMetadata, entryCapacity); + int indexByteSize = getIndexByteSize(entryCapacity); + int localIndicesLength = getBucketsCount(localMetadata, entryCapacity, indexByteSize); + int indicesOffset = getIndicesOffset(entryCapacity); + int physicalCollisionMask = getPhysicalCollisionMaskForIndexByteSize(indexByteSize); for (int i = 0; i < localIndicesLength; i++) { - int index = getIndex(localMetadata, entryCapacity, i); + int index = getIndex(localMetadata, indicesOffset, indexByteSize, i); if (index != EMPTY_INDEX && index != DUMMY_INDEX) { boolean collision = isCollision(index); int unwrapped = unwrapIndex(index); int newIndex = unwrapped - shuffle[unwrapped]; - setIndex(localMetadata, entryCapacity, i, newIndex + INDEX_OFFSET); if (collision) { - markCollision(localMetadata, entryCapacity, i); + setIndex(localMetadata, indicesOffset, indexByteSize, physicalCollisionMask, i, newIndex + INDEX_OFFSET | COLLISION_MASK); + } else { + setIndex(localMetadata, indicesOffset, indexByteSize, physicalCollisionMask, i, newIndex + INDEX_OFFSET); } } else if (index == DUMMY_INDEX) { dummyCount--; @@ -1117,10 +1148,10 @@ public void setKey(int index, Object key) { keysAndValues[(index << 1)] = key; } - private boolean checkInternalState() { + private boolean checkInternalState(int entryCapacity, int indexByteSize) { // We must have at least one empty slot, collision resolution relies on the fact that it is // always going to find an empty slot - assert usedIndices < getBucketsCount() : usedIndices; + assert usedIndices < getBucketsCount(metadata, entryCapacity, indexByteSize) : usedIndices; return true; }