From d869af530a14f49eafe14b7d6b1290774ece5078 Mon Sep 17 00:00:00 2001
From: Prashant Pandey
Date: Thu, 30 Apr 2026 15:14:16 +0530
Subject: [PATCH] Batch updated by key groups

---
Review notes (changes relative to the patch as first posted):
- Update groups are keyed in the order the updates were supplied instead of
  path-sorted order, so a key is never bound to a statement whose placeholder
  order differs from its own parameter order. The Comparator import is no
  longer needed and its hunk is dropped.
- Per-key parameters are built with the same helpers as the SQL template; a
  key whose SET fragments differ from the template is updated on its own.
- Warn logs pass the exception so the stack trace is kept.
- Added a regression test for keys that list the same updates in a different
  order.

 .../FlatCollectionWriteTest.java              | 151 ++++++++++
 .../postgres/FlatPostgresCollection.java      | 257 ++++++++++++++----
 2 files changed, 367 insertions(+), 41 deletions(-)

diff --git a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java
index a52a55b2..c7c2ffd1 100644
--- a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java
+++ b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java
@@ -3531,6 +3531,157 @@ void testBulkUpdateAllOperatorTypes() throws Exception {
       }
     }
 
+    @Test
+    @DisplayName(
+        "Should efficiently batch updates across multiple key groups with complex operations")
+    void testBulkUpdateMultipleGroupsComplexOperations() throws Exception {
+      Map<Key, Collection<SubDocumentUpdate>> updates = new LinkedHashMap<>();
+
+      // ===== Group 1: Top-level primitive + top-level array (3 keys: 1, 5, 8) =====
+      // All have item="Soap" - these should be batched together
+      // This tests: SET on primitive field, APPEND_TO_LIST on array field
+      List<SubDocumentUpdate> group1Updates =
+          List.of(
+              SubDocumentUpdate.of("price", 99), // SET operator (top-level primitive)
+              SubDocumentUpdate.builder()
+                  .subDocument("tags")
+                  .operator(UpdateOperator.APPEND_TO_LIST)
+                  .subDocumentValue(SubDocumentValue.of(new String[] {"updated-tag", "batch-test"}))
+                  .build()); // APPEND_TO_LIST on top-level array
+
+      updates.put(rawKey("1"), group1Updates);
+      updates.put(rawKey("5"), group1Updates);
+      updates.put(rawKey("8"), group1Updates);
+
+      // ===== Group 2: Nested JSONB updates (2 keys: 3, 7) =====
+      // Both have props - these should be batched together
+      // This tests: SET on nested JSONB fields
+      List<SubDocumentUpdate> group2Updates =
+          List.of(
+              SubDocumentUpdate.builder()
+                  .subDocument("props.brand")
+                  .operator(UpdateOperator.SET)
+                  .subDocumentValue(SubDocumentValue.of("PremiumBrand"))
+                  .build(), // SET on nested JSONB primitive
+              SubDocumentUpdate.builder()
+                  .subDocument("props.size")
+                  .operator(UpdateOperator.SET)
+                  .subDocumentValue(SubDocumentValue.of("XL"))
+                  .build()); // SET on another nested field
+
+      updates.put(rawKey("3"), group2Updates);
+      updates.put(rawKey("7"), group2Updates);
+
+      // ===== Group 3: ADD operator + REMOVE_ALL_FROM_LIST (2 keys: 2, 6) =====
+      // Both have quantity and tags - these should be batched together
+      // This tests: ADD on numeric field, REMOVE_ALL_FROM_LIST on array
+      List<SubDocumentUpdate> group3Updates =
+          List.of(
+              SubDocumentUpdate.builder()
+                  .subDocument("quantity")
+                  .operator(UpdateOperator.ADD)
+                  .subDocumentValue(SubDocumentValue.of(100))
+                  .build(), // ADD to numeric field
+              SubDocumentUpdate.builder()
+                  .subDocument("tags")
+                  .operator(UpdateOperator.REMOVE_ALL_FROM_LIST)
+                  .subDocumentValue(SubDocumentValue.of(new String[] {"glass", "plastic"}))
+                  .build()); // REMOVE_ALL_FROM_LIST
+
+      updates.put(rawKey("2"), group3Updates);
+      updates.put(rawKey("6"), group3Updates);
+
+      // Execute bulk update - should have 3 groups with 2-3 keys each
+      BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build());
+
+      // Total unique keys: 1, 2, 3, 5, 6, 7, 8 = 7 keys
+      assertEquals(7, result.getUpdatedCount(), "Should update 7 rows");
+
+      // Verify keys 1, 5, 8 have Group 1 updates (top-level primitive + array)
+      for (String id : List.of("1", "5", "8")) {
+        try (CloseableIterator<Document> iter = flatCollection.find(queryById(id))) {
+          assertTrue(iter.hasNext());
+          JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
+          assertEquals(99, json.get("price").asInt(), "Key " + id + " price should be 99");
+          JsonNode tags = json.get("tags");
+          List<String> tagList = new ArrayList<>();
+          tags.forEach(t -> tagList.add(t.asText()));
+          assertTrue(
+              tagList.contains("updated-tag"), "Key " + id + " should contain 'updated-tag'");
+          assertTrue(tagList.contains("batch-test"), "Key " + id + " should contain 'batch-test'");
+        }
+      }
+
+      // Verify keys 3, 7 have Group 2 updates (nested JSONB)
+      for (String id : List.of("3", "7")) {
+        try (CloseableIterator<Document> iter = flatCollection.find(queryById(id))) {
+          assertTrue(iter.hasNext());
+          JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
+          JsonNode props = json.get("props");
+          assertNotNull(props, "Key " + id + " should have props");
+          assertEquals(
+              "PremiumBrand",
+              props.get("brand").asText(),
+              "Key " + id + " brand should be updated");
+          assertEquals("XL", props.get("size").asText(), "Key " + id + " size should be XL");
+        }
+      }
+
+      // Verify keys 2, 6 have Group 3 updates (ADD + REMOVE_ALL_FROM_LIST)
+      try (CloseableIterator<Document> iter = flatCollection.find(queryById("2"))) {
+        assertTrue(iter.hasNext());
+        JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
+        assertEquals(101, json.get("quantity").asInt()); // 1 + 100
+        JsonNode tags = json.get("tags");
+        List<String> tagList = new ArrayList<>();
+        tags.forEach(t -> tagList.add(t.asText()));
+        assertFalse(tagList.contains("glass"), "Key 2 should not have 'glass' tag");
+      }
+
+      try (CloseableIterator<Document> iter = flatCollection.find(queryById("6"))) {
+        assertTrue(iter.hasNext());
+        JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
+        assertEquals(105, json.get("quantity").asInt()); // 5 + 100
+        JsonNode tags = json.get("tags");
+        List<String> tagList = new ArrayList<>();
+        tags.forEach(t -> tagList.add(t.asText()));
+        assertFalse(tagList.contains("plastic"), "Key 6 should not have 'plastic' tag");
+      }
+    }
+
+    @Test
+    @DisplayName(
+        "Should bind values to the right columns when keys list updates in a different order")
+    void testBulkUpdateSameUpdatesDifferentOrder() throws Exception {
+      Map<Key, Collection<SubDocumentUpdate>> updates = new LinkedHashMap<>();
+
+      // Same columns and operators, listed in opposite order for the two keys
+      updates.put(
+          rawKey("1"),
+          List.of(SubDocumentUpdate.of("price", 11), SubDocumentUpdate.of("quantity", 22)));
+      updates.put(
+          rawKey("5"),
+          List.of(SubDocumentUpdate.of("quantity", 33), SubDocumentUpdate.of("price", 44)));
+
+      BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build());
+
+      assertEquals(2, result.getUpdatedCount(), "Should update 2 rows");
+
+      try (CloseableIterator<Document> iter = flatCollection.find(queryById("1"))) {
+        assertTrue(iter.hasNext());
+        JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
+        assertEquals(11, json.get("price").asInt());
+        assertEquals(22, json.get("quantity").asInt());
+      }
+
+      try (CloseableIterator<Document> iter = flatCollection.find(queryById("5"))) {
+        assertTrue(iter.hasNext());
+        JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
+        assertEquals(44, json.get("price").asInt());
+        assertEquals(33, json.get("quantity").asInt());
+      }
+    }
+
     @Test
     @DisplayName("Should handle edge cases: empty map, null map, non-existent keys")
     void testBulkUpdateEdgeCases() throws Exception {
diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java
index 440b295b..3cef66cc 100644
--- a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java
+++ b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java
@@ -874,59 +874,35 @@ public BulkUpdateResult bulkUpdate(
     String tableName = tableIdentifier.getTableName();
     String quotedPkColumn =
         PostgresUtils.wrapFieldNamesWithDoubleQuotes(getPKForTable(tableName));
-
-    Set<Key> updatedKeys = new HashSet<>();
-
     long batchUpdateTimestamp = System.currentTimeMillis();
 
-    try (Connection connection = client.getPooledConnection()) {
-      for (Map.Entry<Key, Collection<SubDocumentUpdate>> entry : updates.entrySet()) {
-        Key key = entry.getKey();
-        Collection<SubDocumentUpdate> keyUpdates = entry.getValue();
+    // Group keys by their "SQL shape" (same update operations in the same order)
+    Map<String, KeyUpdateGroup> keyGroups = groupKeysByUpdateShape(updates, tableName);
 
-        if (keyUpdates == null || keyUpdates.isEmpty()) {
-          continue;
-        }
+    int totalUpdated = 0;
+    try (Connection connection = client.getPooledConnection()) {
+      // Execute one JDBC batch per group of keys that share the same UPDATE statement
+      for (Map.Entry<String, KeyUpdateGroup> entry : keyGroups.entrySet()) {
 
         try {
-          boolean updated =
-              updateSingleKey(
-                  connection, key, keyUpdates, tableName, quotedPkColumn, batchUpdateTimestamp);
-          if (updated) {
-            updatedKeys.add(key);
-          }
+          int updated =
+              executeBatchUpdate(
+                  connection, entry.getValue(), tableName, quotedPkColumn, batchUpdateTimestamp);
+          totalUpdated += updated;
         } catch (Exception e) {
-          LOGGER.warn("Failed to update key {}: {}", key, e.getMessage());
-          // Continue with other keys - no cross-key atomicity
+          LOGGER.warn(
+              "Failed to update key group (size: {}): {}",
+              entry.getValue().getKeys().size(),
+              e.getMessage(),
+              e);
+          // Continue with other groups - no cross-group atomicity
         }
       }
     } catch (SQLException e) {
       throw new IOException("Failed to get connection for bulk update", e);
     }
 
-    return new BulkUpdateResult(updatedKeys.size());
-  }
-
-  private boolean updateSingleKey(
-      Connection connection,
-      Key key,
-      Collection<SubDocumentUpdate> keyUpdates,
-      String tableName,
-      String quotedPkColumn,
-      long keyUpdateTimestamp)
-      throws IOException, SQLException {
-
-    updateValidator.validate(keyUpdates);
-    Map<String, String> resolvedColumns = resolvePathsToColumns(keyUpdates, tableName);
-
-    return executeKeyUpdate(
-        connection,
-        key,
-        keyUpdates,
-        tableName,
-        quotedPkColumn,
-        resolvedColumns,
-        keyUpdateTimestamp);
+    return new BulkUpdateResult(totalUpdated);
   }
 
   private boolean executeKeyUpdate(
@@ -972,6 +948,205 @@ private boolean executeKeyUpdate(
     }
   }
 
+  /**
+   * Groups keys that have identical update operations together. Keys with the same "shape" can be
+   * updated through a single prepared statement. A key whose updates fail validation or column
+   * resolution is logged and left out of every group.
+   *
+   * @param updates updates requested per key; entries with null or empty updates are skipped
+   * @param tableName table used to resolve update paths to columns
+   * @return groups keyed by update shape, in the order each shape was first seen
+   */
+  private Map<String, KeyUpdateGroup> groupKeysByUpdateShape(
+      Map<Key, Collection<SubDocumentUpdate>> updates, String tableName) {
+
+    Map<String, KeyUpdateGroup> groups = new LinkedHashMap<>();
+
+    for (Map.Entry<Key, Collection<SubDocumentUpdate>> entry : updates.entrySet()) {
+      Key key = entry.getKey();
+      Collection<SubDocumentUpdate> keyUpdates = entry.getValue();
+
+      if (keyUpdates == null || keyUpdates.isEmpty()) {
+        continue;
+      }
+
+      try {
+        updateValidator.validate(keyUpdates);
+        Map<String, String> resolvedColumns = resolvePathsToColumns(keyUpdates, tableName);
+
+        String shapeKey = computeUpdateShapeKey(keyUpdates, resolvedColumns);
+
+        groups
+            .computeIfAbsent(shapeKey, k -> new KeyUpdateGroup(resolvedColumns))
+            .addKeyWithUpdates(key, keyUpdates);
+
+      } catch (Exception e) {
+        LOGGER.warn("Failed to group key {}: {}", key, e.getMessage(), e);
+      }
+    }
+
+    return groups;
+  }
+
+  /**
+   * Builds the grouping key for one key's updates from each update's resolved column, operator
+   * and path, in the order the updates were supplied.
+   *
+   * <p>The key is order-sensitive on purpose. A group's SQL template comes from its first key's
+   * updates, while every key's parameters are built from that key's own updates; treating two
+   * differently ordered update lists as one shape could bind values to the wrong placeholders.
+   */
+  private String computeUpdateShapeKey(
+      Collection<SubDocumentUpdate> updates, Map<String, String> resolvedColumns) {
+
+    StringBuilder sb = new StringBuilder();
+    for (SubDocumentUpdate update : updates) {
+      String path = update.getSubDocument().getPath();
+      String column = resolvedColumns.get(path);
+      sb.append(column)
+          .append(":")
+          .append(update.getOperator())
+          .append(":")
+          .append(path)
+          .append(";");
+    }
+
+    return sb.toString();
+  }
+
+  /**
+   * Executes the updates for all keys in the group using JDBC batching. The first key's updates
+   * define the SQL template; a key whose own SET fragments differ from that template is updated
+   * individually through {@code executeKeyUpdate} rather than bound to a statement it does not
+   * match.
+   *
+   * @return the number of keys whose update reported at least one affected row
+   * @throws SQLException if the statement cannot be prepared or the batch fails
+   */
+  private int executeBatchUpdate(
+      Connection connection,
+      KeyUpdateGroup keyGroup,
+      String tableName,
+      String quotedPkColumn,
+      long epochMillis)
+      throws SQLException {
+
+    List<Key> keys = keyGroup.getKeys();
+    List<Collection<SubDocumentUpdate>> allKeyUpdates = keyGroup.getKeyUpdates();
+    Map<String, String> resolvedColumns = keyGroup.getResolvedColumns();
+
+    // Use the first key's updates to build the SQL template
+    Collection<SubDocumentUpdate> templateUpdates = allKeyUpdates.get(0);
+    List<String> setFragments = new ArrayList<>();
+    List<Object> templateParams = new ArrayList<>();
+
+    boolean hasUpdates =
+        buildSetClauseFragments(
+            connection, templateUpdates, tableName, resolvedColumns, setFragments, templateParams);
+
+    if (!hasUpdates) {
+      return 0;
+    }
+
+    appendLastUpdatedTimestamp(setFragments, templateParams, tableName, epochMillis);
+
+    // Build UPDATE SQL (same for all keys in this group)
+    String sql =
+        String.format(
+            "UPDATE %s SET %s WHERE %s = ?",
+            tableIdentifier, String.join(", ", setFragments), quotedPkColumn);
+
+    LOGGER.debug("Executing batch update SQL: {} for {} keys", sql, keys.size());
+
+    // Use JDBC batching to execute all updates in one round-trip
+    try (PreparedStatement ps = connection.prepareStatement(sql)) {
+      int totalUpdated = 0;
+
+      for (int i = 0; i < keys.size(); i++) {
+        Key key = keys.get(i);
+        Collection<SubDocumentUpdate> keyUpdates = allKeyUpdates.get(i);
+
+        // Build this key's fragments and parameters the same way the template was built, so the
+        // parameter list lines up with the template's placeholders
+        List<String> keySetFragments = new ArrayList<>();
+        List<Object> keyParams = new ArrayList<>();
+        buildSetClauseFragments(
+            connection, keyUpdates, tableName, resolvedColumns, keySetFragments, keyParams);
+        appendLastUpdatedTimestamp(keySetFragments, keyParams, tableName, epochMillis);
+
+        if (!keySetFragments.equals(setFragments)) {
+          // This key needs different SQL than the template; update it on its own
+          try {
+            boolean updated =
+                executeKeyUpdate(
+                    connection,
+                    key,
+                    keyUpdates,
+                    tableName,
+                    quotedPkColumn,
+                    resolvedColumns,
+                    epochMillis);
+            if (updated) {
+              totalUpdated++;
+            }
+          } catch (Exception e) {
+            LOGGER.warn("Failed to update key {}: {}", key, e.getMessage(), e);
+          }
+          continue;
+        }
+
+        // Bind parameters for this key
+        int idx = 1;
+        for (Object param : keyParams) {
+          ps.setObject(idx++, param);
+        }
+        ps.setObject(idx, key.toString()); // WHERE clause parameter
+
+        ps.addBatch();
+      }
+
+      for (int result : ps.executeBatch()) {
+        if (result > 0) {
+          totalUpdated++;
+        }
+      }
+
+      LOGGER.debug("Batch update affected {} rows out of {} keys", totalUpdated, keys.size());
+      return totalUpdated;
+    } catch (SQLException e) {
+      LOGGER.warn("Failed to execute batch update. SQL: {}, Error: {}", sql, e.getMessage());
+      throw e;
+    }
+  }
+
+  /** Holds a group of keys that share the same update shape. */
+  private static class KeyUpdateGroup {
+    private final Map<String, String> resolvedColumns;
+    private final List<Key> keys = new ArrayList<>();
+    private final List<Collection<SubDocumentUpdate>> keyUpdates = new ArrayList<>();
+
+    KeyUpdateGroup(Map<String, String> resolvedColumns) {
+      this.resolvedColumns = resolvedColumns;
+    }
+
+    void addKeyWithUpdates(Key key, Collection<SubDocumentUpdate> updates) {
+      keys.add(key);
+      keyUpdates.add(updates);
+    }
+
+    Map<String, String> getResolvedColumns() {
+      return resolvedColumns;
+    }
+
+    List<Key> getKeys() {
+      return keys;
+    }
+
+    List<Collection<SubDocumentUpdate>> getKeyUpdates() {
+      return keyUpdates;
+    }
+  }
+
   /**
    * Validates all updates and resolves column names.
    *