From 4bf1d1747da61c6f7042b943fa1c7d7c863cef14 Mon Sep 17 00:00:00 2001 From: umi Date: Fri, 27 Feb 2026 19:27:29 +0800 Subject: [PATCH 1/4] proto --- .../pypaimon/tests/table_update_test.py | 192 ++++++++ .../pypaimon/write/commit_rollback.py | 157 +++++++ .../pypaimon/write/conflict_detection.py | 420 ++++++++++++++++++ .../pypaimon/write/file_store_commit.py | 107 ++++- paimon-python/pypaimon/write/table_commit.py | 16 + paimon-python/pypaimon/write/table_update.py | 4 +- .../pypaimon/write/table_update_by_row_id.py | 7 +- 7 files changed, 890 insertions(+), 13 deletions(-) create mode 100644 paimon-python/pypaimon/write/commit_rollback.py create mode 100644 paimon-python/pypaimon/write/conflict_detection.py diff --git a/paimon-python/pypaimon/tests/table_update_test.py b/paimon-python/pypaimon/tests/table_update_test.py index ad9158e9febf..5dc161e06813 100644 --- a/paimon-python/pypaimon/tests/table_update_test.py +++ b/paimon-python/pypaimon/tests/table_update_test.py @@ -15,6 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import logging import os import shutil import tempfile @@ -859,6 +860,197 @@ def test_update_partial_rows_across_two_files(self): 'Seattle', 'Boston', 'Denver', 'Miami', 'Atlanta'] self.assertEqual(expected_cities, cities, "Cities should remain unchanged") + def test_concurrent_updates_with_retry(self): + """Test data evolution with multiple threads performing concurrent updates. + + Each thread updates different rows of the same column. If a conflict occurs, + the thread retries until the update succeeds. After all threads complete, + the final result is verified to ensure all updates were applied correctly. + """ + import threading + import traceback + table = self._create_table() + + # Table has 5 rows (row_id 0-4) after _create_table: + # row 0: age=25, row 1: age=30, row 2: age=35, row 3: age=40, row 4: age=45 + + # Thread 1 updates rows 0 and 1 + # Thread 2 updates rows 2 and 3 + # Thread 3 updates row 4 + thread_updates = [ + {'row_ids': [0, 1], 'ages': [100, 200]}, + {'row_ids': [2, 3], 'ages': [300, 400]}, + {'row_ids': [4], 'ages': [500]}, + ] + + errors = [] + success_counts = [0] * len(thread_updates) + + def update_worker(thread_index, update_spec): + max_retries = 20 + for attempt in range(max_retries): + try: + print("hello0") + write_builder = table.new_batch_write_builder() + table_update = write_builder.new_update().with_update_type(['age']) + + update_data = pa.Table.from_pydict({ + '_ROW_ID': update_spec['row_ids'], + 'age': update_spec['ages'], + }) + + commit_messages, snapshot_id = table_update.update_by_arrow_with_row_id(update_data) + + table_commit = write_builder.new_commit().row_id_check_conflict(snapshot_id) + table_commit.commit(commit_messages) + table_commit.close() + + success_counts[thread_index] = attempt + 1 + return + except Exception as e: + print( + "Thread-{} attempt {} failed: {}\n{}".format( + thread_index, attempt + 1, e, traceback.format_exc() + ) + ) + if attempt == max_retries - 1: + errors.append( + "Thread-{} failed after {} retries: {}".format( + thread_index, max_retries, e + ) + ) + + threads = [] + for idx, spec in enumerate(thread_updates): + thread = threading.Thread(target=update_worker, args=(idx, spec)) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join(timeout=120) + + if errors: + self.fail("Some threads failed:\n" + "\n".join(errors)) + + for idx, count in enumerate(success_counts): + 
self.assertGreater(
+                count, 0,
+                "Thread-{} did not succeed".format(idx)
+            )
+
+        # Verify the final data
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().plan().splits()
+        result = table_read.to_arrow(splits)
+
+        ages = result['age'].to_pylist()
+        expected_ages = [100, 200, 300, 400, 500]
+        self.assertEqual(expected_ages, ages,
+                         "Concurrent updates did not produce correct final result")
+
+    def test_concurrent_updates_same_rows_with_retry(self):
+        """Test data evolution with multiple threads updating overlapping rows.
+
+        Multiple threads compete to update the same rows. Each thread retries
+        on conflict until success. The final result should reflect one of the
+        successful updates for each row (last-writer-wins).
+        """
+        import threading
+
+        table = self._create_table()
+
+        # All threads update the same rows but with different values
+        thread_updates = [
+            {'row_ids': [0, 1, 2], 'ages': [101, 201, 301], 'thread_name': 'A'},
+            {'row_ids': [0, 1, 2], 'ages': [102, 202, 302], 'thread_name': 'B'},
+            {'row_ids': [0, 1, 2], 'ages': [103, 203, 303], 'thread_name': 'C'},
+        ]
+
+        errors = []
+        completion_order = []
+        order_lock = threading.Lock()
+
+        def update_worker(thread_index, update_spec):
+            max_retries = 30
+            for attempt in range(max_retries):
+                try:
+                    write_builder = table.new_batch_write_builder()
+                    table_update = write_builder.new_update().with_update_type(['age'])
+
+                    update_data = pa.Table.from_pydict({
+                        '_ROW_ID': update_spec['row_ids'],
+                        'age': update_spec['ages'],
+                    })
+
+                    # update_by_arrow_with_row_id now returns a
+                    # (commit_messages, snapshot_id) tuple; this test does not
+                    # use conflict checking, so the snapshot id is discarded.
+                    commit_messages, _ = table_update.update_by_arrow_with_row_id(update_data)
+
+                    table_commit = write_builder.new_commit()
+                    table_commit.commit(commit_messages)
+                    table_commit.close()
+
+                    with order_lock:
+                        completion_order.append(thread_index)
+                    return
+                except Exception as e:
+                    print(
+                        "Thread-{} ({}) attempt {} failed: {}".format(
+                            thread_index, update_spec['thread_name'],
+                            attempt + 1, e
+                        )
+                    )
+                    if attempt == max_retries - 1:
+                        errors.append(
+                            "Thread-{} ({}) failed after {} retries: {}".format(
+                                thread_index, update_spec['thread_name'],
+                                max_retries, e
+                            )
+                        )
+
+        threads = []
+        for idx, spec in enumerate(thread_updates):
+            thread = threading.Thread(target=update_worker, args=(idx, spec))
+            threads.append(thread)
+
+        for thread in threads:
+            thread.start()
+
+        for thread in threads:
+            thread.join(timeout=120)
+
+        if errors:
+            self.fail("Some threads failed:\n" + "\n".join(errors))
+
+        self.assertEqual(
+            len(completion_order), len(thread_updates),
+            "Not all threads completed successfully"
+        )
+
+        # Verify the final data: the last thread to commit wins
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().plan().splits()
+        result = table_read.to_arrow(splits)
+
+        ages = result['age'].to_pylist()
+
+        # The last thread to successfully commit determines rows 0-2
+        last_winner = completion_order[-1]
+        winner_ages = thread_updates[last_winner]['ages']
+        self.assertEqual(winner_ages[0], ages[0],
+                         "Row 0 should reflect last writer's value")
+        self.assertEqual(winner_ages[1], ages[1],
+                         "Row 1 should reflect last writer's value")
+        self.assertEqual(winner_ages[2], ages[2],
+                         "Row 2 should reflect last writer's value")
+
+        # Rows 3 and 4 should remain unchanged
+        self.assertEqual(40, ages[3], "Row 3 should remain unchanged")
+        self.assertEqual(45, ages[4], "Row 4 should remain unchanged")
+
 if __name__ == '__main__':
     unittest.main()
diff --git
a/paimon-python/pypaimon/write/commit_rollback.py b/paimon-python/pypaimon/write/commit_rollback.py new file mode 100644 index 000000000000..d612c90850ca --- /dev/null +++ b/paimon-python/pypaimon/write/commit_rollback.py @@ -0,0 +1,157 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Commit rollback to rollback 'COMPACT' commits for resolving conflicts. + +Follows the design of Java's org.apache.paimon.operation.commit.CommitRollback +and org.apache.paimon.table.RollbackHelper. +""" + +import logging + +logger = logging.getLogger(__name__) + + +class CommitRollback: + """Rollback COMPACT commits to resolve conflicts. + + When a conflict is detected during commit, if the latest snapshot is a + COMPACT commit, it can be rolled back by cleaning all snapshots and tags + with IDs larger than the retained snapshot, following the logic of Java's + RollbackHelper.cleanLargerThan. + """ + + def __init__(self, snapshot_manager, tag_manager, file_io): + """Initialize CommitRollback. + + Args: + snapshot_manager: Manager for reading snapshot metadata. + tag_manager: Manager for tag operations. + file_io: FileIO instance for file operations. + """ + self.snapshot_manager = snapshot_manager + self.tag_manager = tag_manager + self.file_io = file_io + + def try_to_rollback(self, latest_snapshot): + """Try to rollback a COMPACT commit to resolve conflicts. + + Only rolls back COMPACT type commits. Follows Java + CommitRollback.tryToRollback and RollbackHelper.cleanLargerThan: + cleans all snapshots and tags with IDs larger than the retained + snapshot, then updates the LATEST hint. + + Args: + latest_snapshot: The latest snapshot that may need to be rolled back. + + Returns: + True if rollback succeeded, False otherwise. + """ + if latest_snapshot.commit_kind == "COMPACT": + latest_id = latest_snapshot.id + previous_id = latest_id - 1 + try: + previous_snapshot = self.snapshot_manager.get_snapshot_by_id( + previous_id) + if previous_snapshot is None: + logger.warning( + "Cannot rollback: previous snapshot %d does not exist.", + previous_id, + ) + return False + + self._clean_larger_than(previous_snapshot) + logger.info( + "Rolled back COMPACT snapshot %d to snapshot %d " + "to resolve conflict.", + latest_id, previous_id, + ) + return True + except Exception: + logger.warning( + "Failed to rollback COMPACT snapshot %d.", + latest_id, + exc_info=True, + ) + return False + + def _clean_larger_than(self, retained_snapshot): + """Clean snapshots and tags whose ID is larger than the retained snapshot. + + Follows Java RollbackHelper.cleanLargerThan logic: + 1. 
Clean snapshots with ID > retained and update LATEST hint + 2. Clean tags with ID > retained + + Args: + retained_snapshot: The snapshot to retain; all later ones are removed. + """ + self._clean_snapshots(retained_snapshot) + self._clean_tags(retained_snapshot) + + def _clean_snapshots(self, retained_snapshot): + """Clean snapshots with ID larger than the retained snapshot. + + Follows Java RollbackHelper.cleanSnapshots logic: + updates the LATEST hint first, then deletes snapshot files + from latest down to retained + 1. + + Args: + retained_snapshot: The snapshot to retain. + """ + earliest_snapshot = self.snapshot_manager.try_get_earliest_snapshot() + earliest_id = earliest_snapshot.id if earliest_snapshot is not None else 1 + + latest_content = self.snapshot_manager.read_latest_file() + latest_id = int(latest_content) + + # Update the LATEST hint to point to the retained snapshot + self.file_io.overwrite_file_utf8( + self.snapshot_manager.latest_file, str(retained_snapshot.id)) + + # Delete snapshot files from latest down to retained + 1 + lower_bound = max(earliest_id, retained_snapshot.id + 1) + for snapshot_id in range(latest_id, lower_bound - 1, -1): + snapshot_path = self.snapshot_manager.get_snapshot_path(snapshot_id) + if self.file_io.exists(snapshot_path): + self.file_io.delete_quietly(snapshot_path) + + def _clean_tags(self, retained_snapshot): + """Clean tags whose snapshot ID is larger than the retained snapshot. + + Follows Java RollbackHelper.cleanTags logic: + iterates all tags and deletes those pointing to snapshots + with IDs larger than the retained snapshot. + + Args: + retained_snapshot: The snapshot to retain. + """ + try: + tag_names = self.tag_manager.list_tags() + except Exception: + return + + if not tag_names: + return + + for tag_name in tag_names: + tag = self.tag_manager.get(tag_name) + if tag is None: + continue + if tag.id > retained_snapshot.id: + tag_path = self.tag_manager.tag_path(tag_name) + self.file_io.delete_quietly(tag_path) diff --git a/paimon-python/pypaimon/write/conflict_detection.py b/paimon-python/pypaimon/write/conflict_detection.py new file mode 100644 index 000000000000..1f1ab0d7c9ac --- /dev/null +++ b/paimon-python/pypaimon/write/conflict_detection.py @@ -0,0 +1,420 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Conflict detection for commit operations. + +Follows the design of Java's org.apache.paimon.operation.commit.ConflictDetection. 
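+
+A minimal usage sketch (illustrative; it mirrors how FileStoreCommit in this
+patch wires the class up, with variable names taken from that code)::
+
+    detection = ConflictDetection(
+        data_evolution_enabled=True,
+        snapshot_manager=snapshot_manager,
+        manifest_list_manager=manifest_list_manager,
+        table=table,
+    )
+    detection.set_row_id_check_from_snapshot(snapshot_id)
+    conflict = detection.check_conflicts(latest_snapshot, base_entries,
+                                         delta_entries, "APPEND")
+    if conflict is not None:
+        raise conflict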
+""" + +import logging +from typing import List, Optional, Set + +from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.manifest.schema.manifest_entry import ManifestEntry +from pypaimon.read.scanner.file_scanner import FileScanner +from pypaimon.snapshot.snapshot import Snapshot +from pypaimon.utils.range import Range + +logger = logging.getLogger(__name__) + + +class ConflictDetection: + """Detects conflicts between base and delta files during commit. + + Follows the design of Java's ConflictDetection class, providing + row ID range conflict checks and row ID from snapshot conflict checks + for Data Evolution tables. + """ + + def __init__(self, data_evolution_enabled, snapshot_manager, + manifest_list_manager, table): + """Initialize ConflictDetection. + + Args: + data_evolution_enabled: Whether data evolution feature is enabled. + snapshot_manager: Manager for reading snapshot metadata. + manifest_list_manager: Manager for reading manifest lists. + table: The FileStoreTable instance. + """ + self.data_evolution_enabled = data_evolution_enabled + self.snapshot_manager = snapshot_manager + self.manifest_list_manager = manifest_list_manager + self.table = table + self._row_id_check_from_snapshot = None + + def set_row_id_check_from_snapshot(self, row_id_check_from_snapshot): + """Set the snapshot ID from which to check row ID conflicts.""" + self._row_id_check_from_snapshot = row_id_check_from_snapshot + + def should_be_overwrite_commit(self, commit_entries): + """Check if the commit should be treated as an overwrite commit. + + Follows Java ConflictDetection.shouldBeOverwriteCommit logic: + returns True if any entry is a DELETE (kind=1), or if + rowIdCheckFromSnapshot is set. + + Args: + commit_entries: The entries being committed. + + Returns: + True if the commit should be treated as OVERWRITE. + """ + for entry in commit_entries: + if entry.kind == 1: + return True + print(self._row_id_check_from_snapshot) + return self._row_id_check_from_snapshot is not None + + def check_conflicts(self, latest_snapshot, base_entries, delta_entries, commit_kind): + """Run all conflict checks and return the first detected conflict. + + Follows Java ConflictDetection.checkConflicts logic: + merges base_entries and delta_entries, then runs conflict checks + on the merged result. + + Args: + latest_snapshot: The latest snapshot at commit time. + base_entries: All entries read from the latest snapshot. + delta_entries: The delta entries being committed. + commit_kind: The kind of commit (e.g. "APPEND", "COMPACT", "OVERWRITE"). + + Returns: + A RuntimeError if a conflict is detected, otherwise None. + """ + all_entries = list(base_entries) + list(delta_entries) + + try: + merged_entries = merge_entries(all_entries) + except Exception as e: + return RuntimeError( + "File deletion conflicts detected! Give up committing. " + str(e)) + print("hello1") + conflict = self.check_row_id_range_conflicts(commit_kind, merged_entries) + if conflict is not None: + return conflict + + return self.check_for_row_id_from_snapshot(latest_snapshot, delta_entries) + + def check_row_id_range_conflicts(self, commit_kind, commit_entries): + """Check for row ID range conflicts among merged entries. + + Follows Java ConflictDetection.checkRowIdRangeConflicts logic: + only enabled when data evolution is active, and checks that + overlapping row ID ranges in non-blob data files are identical. + + Args: + commit_kind: The kind of commit (e.g. "APPEND", "COMPACT"). 
+            commit_entries: The entries being committed.
+
+        Returns:
+            A RuntimeError if conflict is detected, otherwise None.
+        """
+        if not self.data_evolution_enabled:
+            return None
+        if self._row_id_check_from_snapshot is None and commit_kind != "COMPACT":
+            return None
+
+        entries_with_row_id = [
+            entry for entry in commit_entries
+            if entry.file.first_row_id is not None
+        ]
+
+        if not entries_with_row_id:
+            return None
+
+        merged_groups = _merge_overlapping_row_id_ranges(entries_with_row_id)
+
+        for group in merged_groups:
+            data_files = [
+                entry for entry in group
+                if not DataFileMeta.is_blob_file(entry.file.file_name)
+            ]
+            if not _are_all_row_id_ranges_same(data_files):
+                file_descriptions = [
+                    "{name}(rowId={row_id}, count={count})".format(
+                        name=entry.file.file_name,
+                        row_id=entry.file.first_row_id,
+                        count=entry.file.row_count,
+                    )
+                    for entry in data_files
+                ]
+                return RuntimeError(
+                    "For Data Evolution table, multiple 'MERGE INTO' and 'COMPACT' "
+                    "operations have encountered conflicts, data files: "
+                    + str(file_descriptions))
+
+        return None
+
+    def check_for_row_id_from_snapshot(self, latest_snapshot, commit_entries):
+        """Check for row ID conflicts from a specific snapshot onwards.
+
+        Follows Java ConflictDetection.checkForRowIdFromSnapshot logic:
+        collects row ID ranges from delta entries, then checks if any
+        incremental changes between the check snapshot and latest snapshot
+        have overlapping row ID ranges.
+
+        Args:
+            latest_snapshot: The latest snapshot at commit time.
+            commit_entries: The delta entries being committed.
+
+        Returns:
+            A RuntimeError if conflict is detected, otherwise None.
+        """
+        if not self.data_evolution_enabled:
+            return None
+        if self._row_id_check_from_snapshot is None:
+            return None
+
+        changed_parts = changed_partitions(commit_entries)
+
+        history_id_ranges = []
+        for entry in commit_entries:
+            first_row_id = entry.file.first_row_id
+            row_count = entry.file.row_count
+            if first_row_id is not None:
+                history_id_ranges.append(
+                    Range(first_row_id, first_row_id + row_count - 1))
+
+        check_snapshot = self.snapshot_manager.get_snapshot_by_id(
+            self._row_id_check_from_snapshot)
+        if check_snapshot is None or check_snapshot.next_row_id is None:
+            raise RuntimeError(
+                "Next row id cannot be null for snapshot "
+                "{snapshot}.".format(snapshot=self._row_id_check_from_snapshot))
+        check_next_row_id = check_snapshot.next_row_id
+
+        for snapshot_id in range(
+                self._row_id_check_from_snapshot + 1,
+                latest_snapshot.id + 1):
+            snapshot = self.snapshot_manager.get_snapshot_by_id(snapshot_id)
+            if snapshot is None:
+                continue
+            if snapshot.commit_kind == "COMPACT":
+                continue
+
+            incremental_entries = self._read_incremental_entries(
+                snapshot, changed_parts)
+            for entry in incremental_entries:
+                file_range = entry.file.row_id_range()
+                if file_range is None:
+                    continue
+                if file_range.from_ < check_next_row_id:
+                    for history_range in history_id_ranges:
+                        if history_range.overlaps(file_range):
+                            return RuntimeError(
+                                "For Data Evolution table, multiple 'MERGE INTO' "
+                                "operations have encountered conflicts, updating "
+                                "the same file, which can render some updates "
+                                "ineffective.")
+
+        return None
+
+    def _read_incremental_entries(self, snapshot, partition_filter):
+        """Read incremental manifest entries from a snapshot's delta manifest list.
+
+        Follows Java CommitScanner.readIncrementalEntries logic:
+        reads the delta manifest list and filters entries by partition.
+ + Args: + snapshot: The snapshot to read incremental entries from. + partition_filter: Set of partition tuples to filter by. + + Returns: + List of ManifestEntry matching the partition filter. + """ + delta_manifests = self.manifest_list_manager.read_delta(snapshot) + if not delta_manifests: + return [] + + all_entries = FileScanner( + self.table, lambda: [], None + ).read_manifest_entries(delta_manifests) + + if not partition_filter: + return all_entries + + return [ + entry for entry in all_entries + if tuple(entry.partition.values) in partition_filter + ] + + +def changed_partitions(commit_entries): + """Extract unique changed partitions from commit entries. + + Follows Java ManifestEntryChanges.changedPartitions logic. + + Args: + commit_entries: List of ManifestEntry to extract partitions from. + + Returns: + Set of partition tuples. + """ + partitions = set() + for entry in commit_entries: + partition_key = tuple(entry.partition.values) + partitions.add(partition_key) + return partitions + + +def _merge_overlapping_row_id_ranges(entries): + """Merge entries with overlapping row ID ranges into groups. + + Follows Java RangeHelper.mergeOverlappingRanges logic: + sorts entries by row ID range start, then merges overlapping groups. + + Args: + entries: List of ManifestEntry with non-null first_row_id. + + Returns: + List of groups, where each group is a list of entries + with overlapping row ID ranges. + """ + if not entries: + return [] + + indexed = [] + for i, entry in enumerate(entries): + row_range = entry.file.row_id_range() + if row_range is not None: + indexed.append((entry, row_range, i)) + + if not indexed: + return [] + + indexed.sort(key=lambda item: (item[1].from_, item[1].to)) + + groups = [] + current_group = [indexed[0]] + current_end = indexed[0][1].to + + for i in range(1, len(indexed)): + entry, row_range, original_index = indexed[i] + if row_range.from_ <= current_end: + current_group.append(indexed[i]) + if row_range.to > current_end: + current_end = row_range.to + else: + groups.append(current_group) + current_group = [indexed[i]] + current_end = row_range.to + + groups.append(current_group) + + result = [] + for group in groups: + group.sort(key=lambda item: item[2]) + result.append([item[0] for item in group]) + + return result + + +def _are_all_row_id_ranges_same(entries): + """Check if all entries have the same row ID range. + + Follows Java RangeHelper.areAllRangesSame logic. + + Args: + entries: List of ManifestEntry to check. + + Returns: + True if all entries have the same row ID range, False otherwise. + """ + if not entries: + return True + + first_range = entries[0].file.row_id_range() + if first_range is None: + return False + + for entry in entries[1:]: + entry_range = entry.file.row_id_range() + if entry_range is None: + return False + if entry_range.from_ != first_range.from_ or entry_range.to != first_range.to: + return False + + return True + + +def _entry_identifier(entry): + """Build a unique identifier tuple for a ManifestEntry. + + Follows Java FileEntry.Identifier logic: uses partition, bucket, level, + fileName, extraFiles, embeddedIndex and externalPath to identify a file. + + Args: + entry: A ManifestEntry instance. + + Returns: + A hashable identifier tuple. 
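+
+    Example (illustrative): an unpartitioned entry in bucket 0, level 0 with a
+    hypothetical file name yields roughly
+    ``((), 0, 0, 'data-1.parquet', (), None, None)``.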
+ """ + partition_key = tuple(entry.partition.values) + extra_files = tuple(entry.file.extra_files) if entry.file.extra_files else () + embedded_index = (bytes(entry.file.embedded_index) + if entry.file.embedded_index is not None else None) + return ( + partition_key, + entry.bucket, + entry.file.level, + entry.file.file_name, + extra_files, + embedded_index, + entry.file.external_path, + ) + + +def merge_entries(entries): + """Merge manifest entries: ADD and DELETE of the same file cancel each other. + + Follows Java FileEntry.mergeEntries logic: + - ADD: if identifier already in map, raise error; otherwise add to map + - DELETE: if identifier already in map, remove both (cancel); otherwise add to map + + Args: + entries: Iterable of ManifestEntry. + + Returns: + List of merged ManifestEntry values. + + Raises: + RuntimeError: If trying to add a file that is already in the map. + """ + entry_map = {} + insertion_order = [] + + for entry in entries: + identifier = _entry_identifier(entry) + if entry.kind == 0: # ADD + if identifier in entry_map: + raise RuntimeError( + "Trying to add file {} which is already added.".format( + entry.file.file_name)) + entry_map[identifier] = entry + insertion_order.append(identifier) + elif entry.kind == 1: # DELETE + if identifier in entry_map: + del entry_map[identifier] + else: + entry_map[identifier] = entry + insertion_order.append(identifier) + else: + raise RuntimeError("Unknown entry kind: {}".format(entry.kind)) + + return [entry_map[key] for key in insertion_order if key in entry_map] diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py index 80eb858087ef..db7a9c19338b 100644 --- a/paimon-python/pypaimon/write/file_store_commit.py +++ b/paimon-python/pypaimon/write/file_store_commit.py @@ -27,6 +27,9 @@ from pypaimon.manifest.manifest_list_manager import ManifestListManager from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.manifest.schema.manifest_entry import ManifestEntry +from pypaimon.tag.tag_manager import TagManager +from pypaimon.write.commit_rollback import CommitRollback +from pypaimon.write.conflict_detection import ConflictDetection from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta from pypaimon.manifest.schema.simple_stats import SimpleStats from pypaimon.read.scanner.file_scanner import FileScanner @@ -93,6 +96,28 @@ def __init__(self, snapshot_commit: SnapshotCommit, table, commit_user: str): self.commit_min_retry_wait = table.options.commit_min_retry_wait() self.commit_max_retry_wait = table.options.commit_max_retry_wait() + self.conflict_detection = ConflictDetection( + data_evolution_enabled=table.options.data_evolution_enabled(), + snapshot_manager=self.snapshot_manager, + manifest_list_manager=self.manifest_list_manager, + table=table, + ) + self.tag_manager = TagManager( + file_io=table.file_io, + table_path=table.table_path, + ) + self.rollback = CommitRollback( + snapshot_manager=self.snapshot_manager, + tag_manager=self.tag_manager, + file_io=table.file_io, + ) + + def row_id_check_conflict(self, row_id_check_from_snapshot): + """Set the snapshot ID from which to check row ID conflicts.""" + self.conflict_detection.set_row_id_check_from_snapshot( + row_id_check_from_snapshot) + return self + def commit(self, commit_messages: List[CommitMessage], commit_identifier: int): """Commit the given commit messages in normal append mode.""" if not commit_messages: @@ -116,9 +141,20 @@ def commit(self, commit_messages: 
List[CommitMessage], commit_identifier: int): )) logger.info("Finished collecting changes, including: %d entries", len(commit_entries)) - self._try_commit(commit_kind="APPEND", + + commit_kind = "APPEND" + detect_conflicts = False + allow_rollback = False + if self.conflict_detection.should_be_overwrite_commit(commit_entries): + # commit_kind = "OVERWRITE" + detect_conflicts = True + allow_rollback = True + + self._try_commit(commit_kind=commit_kind, commit_identifier=commit_identifier, - commit_entries_plan=lambda snapshot: commit_entries) + commit_entries_plan=lambda snapshot: commit_entries, + detect_conflicts=detect_conflicts, + allow_rollback=allow_rollback) def overwrite(self, overwrite_partition, commit_messages: List[CommitMessage], commit_identifier: int): """Commit the given commit messages in overwrite mode.""" @@ -149,7 +185,9 @@ def overwrite(self, overwrite_partition, commit_messages: List[CommitMessage], c commit_kind="OVERWRITE", commit_identifier=commit_identifier, commit_entries_plan=lambda snapshot: self._generate_overwrite_entries( - snapshot, partition_filter, commit_messages) + snapshot, partition_filter, commit_messages), + detect_conflicts=True, + allow_rollback=True, ) def drop_partitions(self, partitions: List[Dict[str, str]], commit_identifier: int) -> None: @@ -187,10 +225,13 @@ def drop_partitions(self, partitions: List[Dict[str, str]], commit_identifier: i commit_kind="OVERWRITE", commit_identifier=commit_identifier, commit_entries_plan=lambda snapshot: self._generate_overwrite_entries( - snapshot, partition_filter, []) + snapshot, partition_filter, []), + detect_conflicts=True, + allow_rollback=True, ) - def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan): + def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan, + detect_conflicts=False, allow_rollback=False): import threading retry_count = 0 @@ -211,7 +252,9 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan): commit_kind=commit_kind, commit_entries=commit_entries, commit_identifier=commit_identifier, - latest_snapshot=latest_snapshot + latest_snapshot=latest_snapshot, + detect_conflicts=detect_conflicts, + allow_rollback=allow_rollback, ) if result.is_success(): @@ -267,11 +310,13 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan): def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str, commit_entries: List[ManifestEntry], commit_identifier: int, - latest_snapshot: Optional[Snapshot]) -> CommitResult: + latest_snapshot: Optional[Snapshot], + detect_conflicts: bool = False, + allow_rollback: bool = False) -> CommitResult: start_millis = int(time.time() * 1000) if self._is_duplicate_commit(retry_result, latest_snapshot, commit_identifier, commit_kind): return SuccessResult() - + unique_id = uuid.uuid4() base_manifest_list = f"manifest-list-{unique_id}-0" delta_manifest_list = f"manifest-list-{unique_id}-1" @@ -296,8 +341,20 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str # Assign row IDs to new files and get the next row ID for the snapshot commit_entries, next_row_id = self._assign_row_tracking_meta(first_row_id_start, commit_entries) + # Conflict detection: read base entries from latest snapshot, then check conflicts + if detect_conflicts and latest_snapshot is not None: + base_entries = self._read_all_entries_from_changed_partitions( + latest_snapshot, commit_entries) + conflict_exception = self.conflict_detection.check_conflicts( + latest_snapshot, 
base_entries, commit_entries, commit_kind) + + if conflict_exception is not None: + if allow_rollback: + if self.rollback.try_to_rollback(latest_snapshot): + return RetryResult(latest_snapshot, conflict_exception) + raise conflict_exception + try: - # TODO: implement noConflictsOrFail logic new_manifest_file_meta = self._write_manifest_file(commit_entries, new_manifest_file) self.manifest_list_manager.write(delta_manifest_list, [new_manifest_file_meta]) @@ -452,6 +509,38 @@ def _is_duplicate_commit(self, retry_result, latest_snapshot, commit_identifier, return True return False + def _read_all_entries_from_changed_partitions(self, latest_snapshot, delta_entries): + """Read all entries from the latest snapshot for partitions that are changed. + + Follows Java CommitScanner.readAllEntriesFromChangedPartitions logic: + extracts changed partitions from delta entries, then reads all entries + from the latest snapshot filtered by those partitions. + + Args: + latest_snapshot: The latest snapshot to read entries from. + delta_entries: The delta entries being committed, used to determine + which partitions have changed. + + Returns: + List of ManifestEntry from the latest snapshot for changed partitions. + """ + if latest_snapshot is None: + return [] + + changed_partition_set = set() + for entry in delta_entries: + changed_partition_set.add(tuple(entry.partition.values)) + + all_manifests = self.manifest_list_manager.read_all(latest_snapshot) + all_entries = FileScanner( + self.table, lambda: [], None + ).read_manifest_entries(all_manifests) + + return [ + entry for entry in all_entries + if tuple(entry.partition.values) in changed_partition_set + ] + def _generate_overwrite_entries(self, latestSnapshot, partition_filter, commit_messages): """Generate commit entries for OVERWRITE mode based on latest snapshot.""" entries = [] diff --git a/paimon-python/pypaimon/write/table_commit.py b/paimon-python/pypaimon/write/table_commit.py index 1eafafefc09f..9c37a27eb011 100644 --- a/paimon-python/pypaimon/write/table_commit.py +++ b/paimon-python/pypaimon/write/table_commit.py @@ -73,6 +73,22 @@ def abort(self, commit_messages: List[CommitMessage]): def close(self): self.file_store_commit.close() + def row_id_check_conflict(self, row_id_check_from_snapshot): + """Set the snapshot ID from which to check row ID conflicts. + + Follows Java TableCommitImpl.rowIdCheckConflict logic: + forwards the call to FileStoreCommit.row_id_check_conflict(). + + Args: + row_id_check_from_snapshot: The snapshot ID from which to start + checking row ID conflicts, or None to disable. + + Returns: + self for method chaining. 
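+
+        Example (illustrative, following the usage in table_update_test.py;
+        `update`, `commit` and `data` are placeholder names):
+
+            commit_messages, snapshot_id = update.update_by_arrow_with_row_id(data)
+            commit.row_id_check_conflict(snapshot_id).commit(commit_messages)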
+ """ + self.file_store_commit.row_id_check_conflict(row_id_check_from_snapshot) + return self + def _check_committed(self): if self.batch_committed: raise RuntimeError("BatchTableCommit only supports one-time committing.") diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py index 9ad86aa95dcd..596f5c447aea 100644 --- a/paimon-python/pypaimon/write/table_update.py +++ b/paimon-python/pypaimon/write/table_update.py @@ -122,10 +122,10 @@ def new_shard_updator(self, shard_num: int, total_shard_count: int): total_shard_count, ) - def update_by_arrow_with_row_id(self, table: pa.Table) -> List[CommitMessage]: + def update_by_arrow_with_row_id(self, table: pa.Table) -> (List[CommitMessage], int): update_by_row_id = TableUpdateByRowId(self.table, self.commit_user) update_by_row_id.update_columns(table, self.update_cols) - return update_by_row_id.commit_messages + return update_by_row_id.commit_messages, update_by_row_id.snapshot_id class ShardTableUpdator: diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py b/paimon-python/pypaimon/write/table_update_by_row_id.py index b027564ff398..ffee5ddceb2a 100644 --- a/paimon-python/pypaimon/write/table_update_by_row_id.py +++ b/paimon-python/pypaimon/write/table_update_by_row_id.py @@ -47,7 +47,8 @@ def __init__(self, table, commit_user: str): self.commit_user = commit_user # Load existing first_row_ids and build partition map - (self.first_row_ids, + (self.snapshot_id, + self.first_row_ids, self.first_row_id_to_partition_map, self.first_row_id_to_row_count_map, self.total_row_count, @@ -76,7 +77,9 @@ def _load_existing_files_info(self): total_row_count = sum(first_row_id_to_row_count_map.values()) - return (sorted(list(set(first_row_ids))), + snapshot_id = self.table.snapshot_manager().get_latest_snapshot().id + return (snapshot_id, + sorted(list(set(first_row_ids))), first_row_id_to_partition_map, first_row_id_to_row_count_map, total_row_count, From d707c9ee83b8f3e3f2d1f10df7df8c5e64e1c40c Mon Sep 17 00:00:00 2001 From: umi Date: Sat, 28 Feb 2026 14:52:27 +0800 Subject: [PATCH 2/4] rollback --- paimon-python/pypaimon/api/api_request.py | 10 ++ paimon-python/pypaimon/api/resource_paths.py | 4 + paimon-python/pypaimon/api/rest_api.py | 24 +++- paimon-python/pypaimon/catalog/catalog.py | 17 +++ .../pypaimon/catalog/catalog_environment.py | 17 +++ .../pypaimon/catalog/rest/rest_catalog.py | 22 +++ .../pypaimon/catalog/table_rollback.py | 62 ++++++++ paimon-python/pypaimon/table/instant.py | 136 ++++++++++++++++++ .../pypaimon/write/commit_rollback.py | 119 ++------------- .../pypaimon/write/file_store_commit.py | 17 +-- 10 files changed, 310 insertions(+), 118 deletions(-) create mode 100644 paimon-python/pypaimon/catalog/table_rollback.py create mode 100644 paimon-python/pypaimon/table/instant.py diff --git a/paimon-python/pypaimon/api/api_request.py b/paimon-python/pypaimon/api/api_request.py index f2d062f8d005..f191fb03d3da 100644 --- a/paimon-python/pypaimon/api/api_request.py +++ b/paimon-python/pypaimon/api/api_request.py @@ -26,6 +26,7 @@ from pypaimon.schema.schema_change import SchemaChange from pypaimon.snapshot.snapshot import Snapshot from pypaimon.snapshot.snapshot_commit import PartitionStatistics +from pypaimon.table.instant import Instant class RESTRequest(ABC): @@ -84,3 +85,12 @@ class AlterTableRequest(RESTRequest): FIELD_CHANGES = "changes" changes: List[SchemaChange] = json_field(FIELD_CHANGES) + + +@dataclass +class RollbackTableRequest(RESTRequest): + FIELD_INSTANT = 
"instant" + FIELD_FROM_SNAPSHOT = "fromSnapshot" + + instant: Instant = json_field(FIELD_INSTANT) + from_snapshot: Optional[int] = json_field(FIELD_FROM_SNAPSHOT) diff --git a/paimon-python/pypaimon/api/resource_paths.py b/paimon-python/pypaimon/api/resource_paths.py index 79adc0ade3a1..16967d16e23e 100644 --- a/paimon-python/pypaimon/api/resource_paths.py +++ b/paimon-python/pypaimon/api/resource_paths.py @@ -70,3 +70,7 @@ def rename_table(self) -> str: def commit_table(self, database_name: str, table_name: str) -> str: return ("{}/{}/{}/{}/{}/commit".format(self.base_path, self.DATABASES, RESTUtil.encode_string(database_name), self.TABLES, RESTUtil.encode_string(table_name))) + + def rollback_table(self, database_name: str, table_name: str) -> str: + return ("{}/{}/{}/{}/{}/rollback".format(self.base_path, self.DATABASES, RESTUtil.encode_string(database_name), + self.TABLES, RESTUtil.encode_string(table_name))) diff --git a/paimon-python/pypaimon/api/rest_api.py b/paimon-python/pypaimon/api/rest_api.py index dc2c02104b86..069220d68775 100755 --- a/paimon-python/pypaimon/api/rest_api.py +++ b/paimon-python/pypaimon/api/rest_api.py @@ -20,7 +20,8 @@ from pypaimon.api.api_request import (AlterDatabaseRequest, AlterTableRequest, CommitTableRequest, CreateDatabaseRequest, - CreateTableRequest, RenameTableRequest) + CreateTableRequest, RenameTableRequest, + RollbackTableRequest) from pypaimon.api.api_response import (CommitTableResponse, ConfigResponse, GetDatabaseResponse, GetTableResponse, GetTableTokenResponse, @@ -358,6 +359,27 @@ def commit_snapshot( ) return response.is_success() + def rollback_to(self, identifier, instant, from_snapshot=None): + """Rollback table to the given instant. + + Args: + identifier: The table identifier. + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + + Raises: + NoSuchResourceException: If the table, snapshot or tag does not exist. + ForbiddenException: If no permission to access this table. + """ + database_name, table_name = self.__validate_identifier(identifier) + request = RollbackTableRequest(instant=instant, from_snapshot=from_snapshot) + self.client.post( + self.resource_paths.rollback_table(database_name, table_name), + request, + self.rest_auth_function + ) + @staticmethod def __validate_identifier(identifier: Identifier): if not identifier: diff --git a/paimon-python/pypaimon/catalog/catalog.py b/paimon-python/pypaimon/catalog/catalog.py index 729522450ba7..aa914d297825 100644 --- a/paimon-python/pypaimon/catalog/catalog.py +++ b/paimon-python/pypaimon/catalog/catalog.py @@ -95,6 +95,23 @@ def commit_snapshot( """ + def rollback_to(self, identifier, instant, from_snapshot=None): + """Rollback table by the given identifier and instant. + + Args: + identifier: Path of the table (Identifier instance). + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + + Raises: + TableNotExistException: If the table does not exist. + UnsupportedOperationError: If the catalog does not support version management. + """ + raise NotImplementedError( + "rollback_to is not supported by this catalog." 
+ ) + def drop_partitions( self, identifier: Union[str, Identifier], diff --git a/paimon-python/pypaimon/catalog/catalog_environment.py b/paimon-python/pypaimon/catalog/catalog_environment.py index 762a42dd6a7c..b820c7945b3c 100644 --- a/paimon-python/pypaimon/catalog/catalog_environment.py +++ b/paimon-python/pypaimon/catalog/catalog_environment.py @@ -19,6 +19,7 @@ from typing import Optional from pypaimon.catalog.catalog_loader import CatalogLoader +from pypaimon.catalog.table_rollback import TableRollback, CatalogTableRollback from pypaimon.common.identifier import Identifier from pypaimon.snapshot.catalog_snapshot_commit import CatalogSnapshotCommit from pypaimon.snapshot.renaming_snapshot_commit import RenamingSnapshotCommit @@ -60,6 +61,22 @@ def snapshot_commit(self, snapshot_manager) -> Optional[SnapshotCommit]: # to create locks based on the catalog lock context return RenamingSnapshotCommit(snapshot_manager) + def catalog_table_rollback(self): + """Create a TableRollback instance based on the catalog environment configuration. + + If catalog_loader is available and version management is supported, + returns a CatalogTableRollback that delegates to catalog.rollback_to. + Otherwise, returns None. + + Returns: + A TableRollback instance or None. + """ + if self.catalog_loader is not None and self.supports_version_management: + catalog = self.catalog_loader.load() + identifier = self.identifier + return CatalogTableRollback(catalog, identifier) + return None + def copy(self, identifier: Identifier) -> 'CatalogEnvironment': """ Create a copy of this CatalogEnvironment with a different identifier. diff --git a/paimon-python/pypaimon/catalog/rest/rest_catalog.py b/paimon-python/pypaimon/catalog/rest/rest_catalog.py index 9cf138c2543a..dab7633bc91c 100644 --- a/paimon-python/pypaimon/catalog/rest/rest_catalog.py +++ b/paimon-python/pypaimon/catalog/rest/rest_catalog.py @@ -256,6 +256,28 @@ def alter_table( except ForbiddenException as e: raise TableNoPermissionException(identifier) from e + def rollback_to(self, identifier, instant, from_snapshot=None): + """Rollback table by the given identifier and instant. + + Args: + identifier: Path of the table (Identifier or string). + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + + Raises: + TableNotExistException: If the table does not exist. + TableNoPermissionException: If no permission to access this table. + """ + if not isinstance(identifier, Identifier): + identifier = Identifier.from_string(identifier) + try: + self.rest_api.rollback_to(identifier, instant, from_snapshot) + except NoSuchResourceException as e: + raise TableNotExistException(identifier) from e + except ForbiddenException as e: + raise TableNoPermissionException(identifier) from e + def load_table_metadata(self, identifier: Identifier) -> TableMetadata: try: response = self.rest_api.get_table(identifier) diff --git a/paimon-python/pypaimon/catalog/table_rollback.py b/paimon-python/pypaimon/catalog/table_rollback.py new file mode 100644 index 000000000000..70f155ce40e6 --- /dev/null +++ b/paimon-python/pypaimon/catalog/table_rollback.py @@ -0,0 +1,62 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. 
The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from abc import ABC, abstractmethod + + +class TableRollback(ABC): + """Rollback table to instant from snapshot. + """ + + @abstractmethod + def rollback_to(self, instant, from_snapshot=None): + """Rollback table to the given instant. + + Args: + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + """ + + +class CatalogTableRollback(TableRollback): + """ + Internal TableRollback implementation that delegates to catalog.rollback_to. + """ + + def __init__(self, catalog, identifier): + self._catalog = catalog + self._identifier = identifier + + def rollback_to(self, instant, from_snapshot=None): + """Rollback table to the given instant via catalog. + + Args: + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + + Raises: + RuntimeError: If the table does not exist in the catalog. + """ + try: + self._catalog.rollback_to(self._identifier, instant, from_snapshot) + except Exception as e: + raise RuntimeError( + "Failed to rollback table {}: {}".format( + self._identifier, e)) from e diff --git a/paimon-python/pypaimon/table/instant.py b/paimon-python/pypaimon/table/instant.py new file mode 100644 index 000000000000..0cf4383213c4 --- /dev/null +++ b/paimon-python/pypaimon/table/instant.py @@ -0,0 +1,136 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from abc import ABC, abstractmethod + + +class Instant(ABC): + """Table rollback instant, corresponding to Java's Instant interface. + + Supports polymorphic JSON serialization via to_dict/from_dict, + matching Java's Jackson @JsonTypeInfo/@JsonSubTypes annotations. + + Serialization format: + SnapshotInstant: {"type": "snapshot", "snapshotId": 123} + TagInstant: {"type": "tag", "tagName": "test_tag"} + """ + + FIELD_TYPE = "type" + TYPE_SNAPSHOT = "snapshot" + TYPE_TAG = "tag" + + @staticmethod + def snapshot(snapshot_id): + """Create a SnapshotInstant. + + Args: + snapshot_id: The snapshot ID to rollback to. + + Returns: + A SnapshotInstant instance. 
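+
+        Example:
+
+            >>> Instant.snapshot(3).to_dict()
+            {'type': 'snapshot', 'snapshotId': 3}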
+ """ + return SnapshotInstant(snapshot_id) + + @staticmethod + def tag(tag_name): + """Create a TagInstant. + + Args: + tag_name: The tag name to rollback to. + + Returns: + A TagInstant instance. + """ + return TagInstant(tag_name) + + @abstractmethod + def to_dict(self): + """Serialize this Instant to a dictionary for JSON output.""" + + @staticmethod + def from_dict(data): + """Deserialize an Instant from a dictionary. + + Args: + data: A dictionary with a 'type' field indicating the instant type. + + Returns: + A SnapshotInstant or TagInstant instance. + + Raises: + ValueError: If the type field is missing or unknown. + """ + instant_type = data.get(Instant.FIELD_TYPE) + if instant_type == Instant.TYPE_SNAPSHOT: + return SnapshotInstant(data[SnapshotInstant.FIELD_SNAPSHOT_ID]) + elif instant_type == Instant.TYPE_TAG: + return TagInstant(data[TagInstant.FIELD_TAG_NAME]) + else: + raise ValueError("Unknown instant type: {}".format(instant_type)) + + +class SnapshotInstant(Instant): + """Snapshot instant for table rollback.""" + + FIELD_SNAPSHOT_ID = "snapshotId" + + def __init__(self, snapshot_id): + self.snapshot_id = snapshot_id + + def to_dict(self): + return { + Instant.FIELD_TYPE: Instant.TYPE_SNAPSHOT, + self.FIELD_SNAPSHOT_ID: self.snapshot_id, + } + + def __eq__(self, other): + if not isinstance(other, SnapshotInstant): + return False + return self.snapshot_id == other.snapshot_id + + def __hash__(self): + return hash(self.snapshot_id) + + def __repr__(self): + return "SnapshotInstant(snapshot_id={})".format(self.snapshot_id) + + +class TagInstant(Instant): + """Tag instant for table rollback.""" + + FIELD_TAG_NAME = "tagName" + + def __init__(self, tag_name): + self.tag_name = tag_name + + def to_dict(self): + return { + Instant.FIELD_TYPE: Instant.TYPE_TAG, + self.FIELD_TAG_NAME: self.tag_name, + } + + def __eq__(self, other): + if not isinstance(other, TagInstant): + return False + return self.tag_name == other.tag_name + + def __hash__(self): + return hash(self.tag_name) + + def __repr__(self): + return "TagInstant(tag_name={})".format(self.tag_name) diff --git a/paimon-python/pypaimon/write/commit_rollback.py b/paimon-python/pypaimon/write/commit_rollback.py index d612c90850ca..80a440dbf40e 100644 --- a/paimon-python/pypaimon/write/commit_rollback.py +++ b/paimon-python/pypaimon/write/commit_rollback.py @@ -18,43 +18,34 @@ """Commit rollback to rollback 'COMPACT' commits for resolving conflicts. -Follows the design of Java's org.apache.paimon.operation.commit.CommitRollback -and org.apache.paimon.table.RollbackHelper. +Follows the design of Java's org.apache.paimon.operation.commit.CommitRollback. """ -import logging - -logger = logging.getLogger(__name__) +from pypaimon.table.instant import Instant class CommitRollback: """Rollback COMPACT commits to resolve conflicts. When a conflict is detected during commit, if the latest snapshot is a - COMPACT commit, it can be rolled back by cleaning all snapshots and tags - with IDs larger than the retained snapshot, following the logic of Java's - RollbackHelper.cleanLargerThan. + COMPACT commit, it can be rolled back via TableRollback, following the + logic of Java's CommitRollback. """ - def __init__(self, snapshot_manager, tag_manager, file_io): + def __init__(self, table_rollback): """Initialize CommitRollback. Args: - snapshot_manager: Manager for reading snapshot metadata. - tag_manager: Manager for tag operations. - file_io: FileIO instance for file operations. 
+ table_rollback: A TableRollback instance used to perform the rollback. """ - self.snapshot_manager = snapshot_manager - self.tag_manager = tag_manager - self.file_io = file_io + self._table_rollback = table_rollback def try_to_rollback(self, latest_snapshot): """Try to rollback a COMPACT commit to resolve conflicts. - Only rolls back COMPACT type commits. Follows Java - CommitRollback.tryToRollback and RollbackHelper.cleanLargerThan: - cleans all snapshots and tags with IDs larger than the retained - snapshot, then updates the LATEST hint. + Only rolls back COMPACT type commits. Delegates to TableRollback + to rollback to the previous snapshot (latest - 1), passing the + latest snapshot ID as from_snapshot. Args: latest_snapshot: The latest snapshot that may need to be rolled back. @@ -64,94 +55,10 @@ def try_to_rollback(self, latest_snapshot): """ if latest_snapshot.commit_kind == "COMPACT": latest_id = latest_snapshot.id - previous_id = latest_id - 1 try: - previous_snapshot = self.snapshot_manager.get_snapshot_by_id( - previous_id) - if previous_snapshot is None: - logger.warning( - "Cannot rollback: previous snapshot %d does not exist.", - previous_id, - ) - return False - - self._clean_larger_than(previous_snapshot) - logger.info( - "Rolled back COMPACT snapshot %d to snapshot %d " - "to resolve conflict.", - latest_id, previous_id, - ) + self._table_rollback.rollback_to( + Instant.snapshot(latest_id - 1), latest_id) return True except Exception: - logger.warning( - "Failed to rollback COMPACT snapshot %d.", - latest_id, - exc_info=True, - ) + pass return False - - def _clean_larger_than(self, retained_snapshot): - """Clean snapshots and tags whose ID is larger than the retained snapshot. - - Follows Java RollbackHelper.cleanLargerThan logic: - 1. Clean snapshots with ID > retained and update LATEST hint - 2. Clean tags with ID > retained - - Args: - retained_snapshot: The snapshot to retain; all later ones are removed. - """ - self._clean_snapshots(retained_snapshot) - self._clean_tags(retained_snapshot) - - def _clean_snapshots(self, retained_snapshot): - """Clean snapshots with ID larger than the retained snapshot. - - Follows Java RollbackHelper.cleanSnapshots logic: - updates the LATEST hint first, then deletes snapshot files - from latest down to retained + 1. - - Args: - retained_snapshot: The snapshot to retain. - """ - earliest_snapshot = self.snapshot_manager.try_get_earliest_snapshot() - earliest_id = earliest_snapshot.id if earliest_snapshot is not None else 1 - - latest_content = self.snapshot_manager.read_latest_file() - latest_id = int(latest_content) - - # Update the LATEST hint to point to the retained snapshot - self.file_io.overwrite_file_utf8( - self.snapshot_manager.latest_file, str(retained_snapshot.id)) - - # Delete snapshot files from latest down to retained + 1 - lower_bound = max(earliest_id, retained_snapshot.id + 1) - for snapshot_id in range(latest_id, lower_bound - 1, -1): - snapshot_path = self.snapshot_manager.get_snapshot_path(snapshot_id) - if self.file_io.exists(snapshot_path): - self.file_io.delete_quietly(snapshot_path) - - def _clean_tags(self, retained_snapshot): - """Clean tags whose snapshot ID is larger than the retained snapshot. - - Follows Java RollbackHelper.cleanTags logic: - iterates all tags and deletes those pointing to snapshots - with IDs larger than the retained snapshot. - - Args: - retained_snapshot: The snapshot to retain. 
- """ - try: - tag_names = self.tag_manager.list_tags() - except Exception: - return - - if not tag_names: - return - - for tag_name in tag_names: - tag = self.tag_manager.get(tag_name) - if tag is None: - continue - if tag.id > retained_snapshot.id: - tag_path = self.tag_manager.tag_path(tag_name) - self.file_io.delete_quietly(tag_path) diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py index db7a9c19338b..a41124d2dc5e 100644 --- a/paimon-python/pypaimon/write/file_store_commit.py +++ b/paimon-python/pypaimon/write/file_store_commit.py @@ -27,7 +27,6 @@ from pypaimon.manifest.manifest_list_manager import ManifestListManager from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.manifest.schema.manifest_entry import ManifestEntry -from pypaimon.tag.tag_manager import TagManager from pypaimon.write.commit_rollback import CommitRollback from pypaimon.write.conflict_detection import ConflictDetection from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta @@ -102,15 +101,11 @@ def __init__(self, snapshot_commit: SnapshotCommit, table, commit_user: str): manifest_list_manager=self.manifest_list_manager, table=table, ) - self.tag_manager = TagManager( - file_io=table.file_io, - table_path=table.table_path, - ) - self.rollback = CommitRollback( - snapshot_manager=self.snapshot_manager, - tag_manager=self.tag_manager, - file_io=table.file_io, - ) + + self.rollback = None + table_rollback = table.catalog_environment.catalog_table_rollback() + if table_rollback is not None: + self.rollback = CommitRollback(table_rollback) def row_id_check_conflict(self, row_id_check_from_snapshot): """Set the snapshot ID from which to check row ID conflicts.""" @@ -349,7 +344,7 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str latest_snapshot, base_entries, commit_entries, commit_kind) if conflict_exception is not None: - if allow_rollback: + if allow_rollback and self.rollback is not None: if self.rollback.try_to_rollback(latest_snapshot): return RetryResult(latest_snapshot, conflict_exception) raise conflict_exception From 63b38b1d223d12a41bae8ff2c8d70b0516a3e310 Mon Sep 17 00:00:00 2001 From: umi Date: Sat, 28 Feb 2026 17:28:08 +0800 Subject: [PATCH 3/4] range --- .../manifest/manifest_file_manager.py | 15 +- .../pypaimon/manifest/schema/file_entry.py | 129 +++++++++++ .../manifest/schema/manifest_entry.py | 3 +- paimon-python/pypaimon/schema/data_types.py | 69 +++++- .../pypaimon/table/row/generic_row.py | 10 + paimon-python/pypaimon/utils/range_helper.py | 134 +++++++++++ .../pypaimon/write/conflict_detection.py | 218 ++---------------- .../pypaimon/write/file_store_commit.py | 52 +++-- 8 files changed, 404 insertions(+), 226 deletions(-) create mode 100644 paimon-python/pypaimon/manifest/schema/file_entry.py create mode 100644 paimon-python/pypaimon/utils/range_helper.py diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py index 0ed50918253c..5975fcbc9fa9 100644 --- a/paimon-python/pypaimon/manifest/manifest_file_manager.py +++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py @@ -53,17 +53,6 @@ def read_entries_parallel(self, manifest_files: List[ManifestFileMeta], manifest def _process_single_manifest(manifest_file: ManifestFileMeta) -> List[ManifestEntry]: return self.read(manifest_file.file_name, manifest_entry_filter, drop_stats) - def _entry_identifier(e: ManifestEntry) -> tuple: - 
return ( - tuple(e.partition.values), - e.bucket, - e.file.level, - e.file.file_name, - tuple(e.file.extra_files) if e.file.extra_files else (), - e.file.embedded_index, - e.file.external_path, - ) - deleted_entry_keys = set() added_entries = [] with ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -73,11 +62,11 @@ def _entry_identifier(e: ManifestEntry) -> tuple: if entry.kind == 0: # ADD added_entries.append(entry) else: # DELETE - deleted_entry_keys.add(_entry_identifier(entry)) + deleted_entry_keys.add(entry.identifier()) final_entries = [ entry for entry in added_entries - if _entry_identifier(entry) not in deleted_entry_keys + if entry.identifier() not in deleted_entry_keys ] return final_entries diff --git a/paimon-python/pypaimon/manifest/schema/file_entry.py b/paimon-python/pypaimon/manifest/schema/file_entry.py new file mode 100644 index 000000000000..3abab4495c28 --- /dev/null +++ b/paimon-python/pypaimon/manifest/schema/file_entry.py @@ -0,0 +1,129 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Entry representing a file. + +Follows the design of Java's org.apache.paimon.manifest.FileEntry. +""" + + +class FileEntry: + """Entry representing a file. + + The same Identifier indicates that the FileEntry refers to the same data file. + """ + + class Identifier: + """Unique identifier for a file entry. + + Uses partition, bucket, level, fileName, extraFiles, + embeddedIndex and externalPath to identify a file. + """ + + def __init__(self, partition, bucket, level, file_name, + extra_files, embedded_index, external_path): + self.partition = partition + self.bucket = bucket + self.level = level + self.file_name = file_name + self.extra_files = extra_files + self.embedded_index = embedded_index + self.external_path = external_path + self._hash = None + + def __eq__(self, other): + if self is other: + return True + if other is None or not isinstance(other, FileEntry.Identifier): + return False + return (self.bucket == other.bucket + and self.level == other.level + and self.partition == other.partition + and self.file_name == other.file_name + and self.extra_files == other.extra_files + and self.embedded_index == other.embedded_index + and self.external_path == other.external_path) + + def __hash__(self): + if self._hash is None: + self._hash = hash(( + self.partition, + self.bucket, + self.level, + self.file_name, + self.extra_files, + self.embedded_index, + self.external_path, + )) + return self._hash + + def identifier(self): + """Build a unique Identifier for this file entry. + + Returns: + An Identifier instance. 
+ """ + extra_files = (tuple(self.file.extra_files) + if self.file.extra_files else ()) + return FileEntry.Identifier( + partition=self.partition, + bucket=self.bucket, + level=self.file.level, + file_name=self.file.file_name, + extra_files=extra_files, + embedded_index=self.file.embedded_index, + external_path=self.file.external_path, + ) + + @staticmethod + def merge_entries(entries): + """Merge file entries: ADD and DELETE of the same file cancel each other. + + - ADD: if identifier already in map, raise error; otherwise add to map. + - DELETE: if identifier already in map, remove both (cancel); + otherwise add to map. + + Args: + entries: Iterable of FileEntry. + + Returns: + List of merged FileEntry values, preserving insertion order. + + Raises: + RuntimeError: If trying to add a file that is already in the map. + """ + entry_map = {} + + for entry in entries: + entry_identifier = entry.identifier() + if entry.kind == 0: # ADD + if entry_identifier in entry_map: + raise RuntimeError( + "Trying to add file {} which is already added.".format( + entry.file.file_name)) + entry_map[entry_identifier] = entry + elif entry.kind == 1: # DELETE + if entry_identifier in entry_map: + del entry_map[entry_identifier] + else: + entry_map[entry_identifier] = entry + else: + raise RuntimeError( + "Unknown entry kind: {}".format(entry.kind)) + + return list(entry_map.values()) diff --git a/paimon-python/pypaimon/manifest/schema/manifest_entry.py b/paimon-python/pypaimon/manifest/schema/manifest_entry.py index b1fd244dafc0..eba241786387 100644 --- a/paimon-python/pypaimon/manifest/schema/manifest_entry.py +++ b/paimon-python/pypaimon/manifest/schema/manifest_entry.py @@ -20,11 +20,12 @@ from pypaimon.manifest.schema.data_file_meta import (DATA_FILE_META_SCHEMA, DataFileMeta) +from pypaimon.manifest.schema.file_entry import FileEntry from pypaimon.table.row.generic_row import GenericRow @dataclass -class ManifestEntry: +class ManifestEntry(FileEntry): kind: int partition: GenericRow bucket: int diff --git a/paimon-python/pypaimon/schema/data_types.py b/paimon-python/pypaimon/schema/data_types.py index 318ddfe02fcf..6befaa4d4068 100755 --- a/paimon-python/pypaimon/schema/data_types.py +++ b/paimon-python/pypaimon/schema/data_types.py @@ -73,6 +73,16 @@ def __init__(self, type: str, nullable: bool = True): super().__init__(nullable) self.type = type + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, AtomicType): + return False + return self.type == other.type and self.nullable == other.nullable + + def __hash__(self): + return hash((self.type, self.nullable)) + def to_dict(self) -> str: if not self.nullable: return self.type + " NOT NULL" @@ -95,6 +105,16 @@ def __init__(self, nullable: bool, element_type: DataType): super().__init__(nullable) self.element = element_type + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, ArrayType): + return False + return self.element == other.element and self.nullable == other.nullable + + def __hash__(self): + return hash((self.element, self.nullable)) + def to_dict(self) -> Dict[str, Any]: return { "type": "ARRAY" + (" NOT NULL" if not self.nullable else ""), @@ -119,6 +139,16 @@ def __init__(self, nullable: bool, element_type: DataType): super().__init__(nullable) self.element = element_type + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, MultisetType): + return False + return self.element == other.element and self.nullable == other.nullable + + def 
__hash__(self):
+        return hash((self.element, self.nullable))
+
     def to_dict(self) -> Dict[str, Any]:
         return {
             "type": "MULTISET{}{}".format('<' + str(self.element) + '>' if self.element else '',
@@ -150,6 +180,18 @@ def __init__(
         self.key = key_type
         self.value = value_type
 
+    def __eq__(self, other):
+        if self is other:
+            return True
+        if not isinstance(other, MapType):
+            return False
+        return (self.key == other.key
+                and self.value == other.value
+                and self.nullable == other.nullable)
+
+    def __hash__(self):
+        return hash((self.key, self.value, self.nullable))
+
     def to_dict(self) -> Dict[str, Any]:
         return {
             "type": "MAP<{}, {}>".format(self.key, self.value),
@@ -199,6 +241,21 @@ def __init__(
     def from_dict(cls, data: Dict[str, Any]) -> "DataField":
         return DataTypeParser.parse_data_field(data)
 
+    def __eq__(self, other):
+        if self is other:
+            return True
+        if not isinstance(other, DataField):
+            return False
+        return (self.id == other.id
+                and self.name == other.name
+                and self.type == other.type
+                and self.description == other.description
+                and self.default_value == other.default_value)
+
+    def __hash__(self):
+        return hash((self.id, self.name, self.type,
+                     self.description, self.default_value))
+
     def to_dict(self) -> Dict[str, Any]:
         result = {
             self.FIELD_ID: self.id,
@@ -223,6 +280,16 @@ def __init__(self, nullable: bool, fields: List[DataField]):
         super().__init__(nullable)
         self.fields = fields or []
 
+    def __eq__(self, other):
+        if self is other:
+            return True
+        if not isinstance(other, RowType):
+            return False
+        return self.fields == other.fields and self.nullable == other.nullable
+
+    def __hash__(self):
+        return hash((tuple(self.fields), self.nullable))
+
     def to_dict(self) -> Dict[str, Any]:
         return {
             "type": "ROW" + ("" if self.nullable else " NOT NULL"),
@@ -587,7 +654,7 @@ def to_avro_type(field_type: pyarrow.DataType, field_name: str,
                  parent_name: str = "record") -> Union[str, Dict[str, Any]]:
     if pyarrow.types.is_integer(field_type):
         if (pyarrow.types.is_signed_integer(field_type) and field_type.bit_width <= 32) or \
-            (pyarrow.types.is_unsigned_integer(field_type) and field_type.bit_width < 32):
+                (pyarrow.types.is_unsigned_integer(field_type) and field_type.bit_width < 32):
             return "int"
         else:
             return "long"
diff --git a/paimon-python/pypaimon/table/row/generic_row.py b/paimon-python/pypaimon/table/row/generic_row.py
index be5c1ec80f39..4aa740de7219 100644
--- a/paimon-python/pypaimon/table/row/generic_row.py
+++ b/paimon-python/pypaimon/table/row/generic_row.py
@@ -51,6 +51,16 @@ def get_row_kind(self) -> RowKind:
     def __len__(self) -> int:
         return len(self.values)
 
+    def __eq__(self, other):
+        if self is other:
+            return True
+        if not isinstance(other, GenericRow):
+            return False
+        return self.values == other.values and self.row_kind == other.row_kind
+
+    def __hash__(self):
+        return hash((tuple(self.values), self.row_kind))
+
     def __str__(self):
         field_strs = [f"{field.name}={repr(value)}" for field, value in zip(self.fields, self.values)]
         return f"GenericRow(row_kind={self.row_kind.name}, {', '.join(field_strs)})"
diff --git a/paimon-python/pypaimon/utils/range_helper.py b/paimon-python/pypaimon/utils/range_helper.py
new file mode 100644
index 000000000000..5bbc02a55c2d
--- /dev/null
+++ b/paimon-python/pypaimon/utils/range_helper.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""A helper class to handle ranges.
+
+Follows the design of Java's org.apache.paimon.utils.RangeHelper.
+"""
+
+
+class RangeHelper:
+    """A helper class to handle ranges.
+
+    Provides methods to check if all ranges are the same and to merge
+    overlapping ranges into groups, preserving original order within groups.
+
+    Args:
+        range_function: A callable that extracts a Range from an element T.
+    """
+
+    def __init__(self, range_function):
+        self._range_function = range_function
+
+    def are_all_ranges_same(self, items):
+        """Check if all items have the same range.
+
+        Args:
+            items: List of items to check.
+
+        Returns:
+            True if all items have the same range, False otherwise.
+        """
+        if not items:
+            return True
+
+        first = items[0]
+        first_range = self._range_function(first)
+        if first_range is None:
+            return False
+
+        for item in items[1:]:
+            if item is None:
+                return False
+            current_range = self._range_function(item)
+            if current_range is None:
+                return False
+            if current_range.from_ != first_range.from_ or current_range.to != first_range.to:
+                return False
+
+        return True
+
+    def merge_overlapping_ranges(self, items):
+        """Merge items with overlapping ranges into groups.
+
+        Sorts items by range start, then merges overlapping groups.
+        Within each group, items are sorted by their original index.
+
+        Args:
+            items: List of items with non-null ranges.
+
+        Returns:
+            List of groups, where each group is a list of items
+            with overlapping ranges.
+ """ + if not items: + return [] + + # Create indexed values to track original indices + indexed = [] + for original_index, item in enumerate(items): + item_range = self._range_function(item) + if item_range is not None: + indexed.append(_IndexedValue(item, item_range, original_index)) + + if not indexed: + return [] + + # Sort by range start, then by range end + indexed.sort(key=lambda iv: (iv.start(), iv.end())) + + groups = [] + current_group = [indexed[0]] + current_end = indexed[0].end() + + # Iterate through sorted ranges and merge overlapping ones + for i in range(1, len(indexed)): + current = indexed[i] + if current.start() <= current_end: + current_group.append(current) + if current.end() > current_end: + current_end = current.end() + else: + groups.append(current_group) + current_group = [current] + current_end = current.end() + + # Add the last group + groups.append(current_group) + + # Convert groups to result, sorting each group by original index + result = [] + for group in groups: + group.sort(key=lambda iv: iv.original_index) + result.append([iv.value for iv in group]) + + return result + + +class _IndexedValue: + """A helper class to track original indices during range merging.""" + + def __init__(self, value, item_range, original_index): + self.value = value + self.range = item_range + self.original_index = original_index + + def start(self): + return self.range.from_ + + def end(self): + return self.range.to diff --git a/paimon-python/pypaimon/write/conflict_detection.py b/paimon-python/pypaimon/write/conflict_detection.py index 1f1ab0d7c9ac..47c3b5065a78 100644 --- a/paimon-python/pypaimon/write/conflict_detection.py +++ b/paimon-python/pypaimon/write/conflict_detection.py @@ -1,4 +1,3 @@ -################################################################################ # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -7,14 +6,14 @@ # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -################################################################################ +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. """Conflict detection for commit operations. 
@@ -22,13 +21,12 @@
 """
 
 import logging
-from typing import List, Optional, Set
 
 from pypaimon.manifest.schema.data_file_meta import DataFileMeta
-from pypaimon.manifest.schema.manifest_entry import ManifestEntry
+from pypaimon.manifest.schema.file_entry import FileEntry
 from pypaimon.read.scanner.file_scanner import FileScanner
-from pypaimon.snapshot.snapshot import Snapshot
 from pypaimon.utils.range import Range
+from pypaimon.utils.range_helper import RangeHelper
 
 logger = logging.getLogger(__name__)
 
@@ -61,29 +59,19 @@ def set_row_id_check_from_snapshot(self, row_id_check_from_snapshot):
         """Set the snapshot ID from which to check row ID conflicts."""
         self._row_id_check_from_snapshot = row_id_check_from_snapshot
 
-    def should_be_overwrite_commit(self, commit_entries):
+    def should_be_overwrite_commit(self):
         """Check if the commit should be treated as an overwrite commit.
 
-        Follows Java ConflictDetection.shouldBeOverwriteCommit logic:
-        returns True if any entry is a DELETE (kind=1), or if
-        rowIdCheckFromSnapshot is set.
-
-        Args:
-            commit_entries: The entries being committed.
+        Returns True if row_id_check_from_snapshot is set.
 
         Returns:
             True if the commit should be treated as OVERWRITE.
         """
-        for entry in commit_entries:
-            if entry.kind == 1:
-                return True
-        print(self._row_id_check_from_snapshot)
         return self._row_id_check_from_snapshot is not None
 
     def check_conflicts(self, latest_snapshot, base_entries, delta_entries, commit_kind):
         """Run all conflict checks and return the first detected conflict.
 
-        Follows Java ConflictDetection.checkConflicts logic:
-        merges base_entries and delta_entries, then runs conflict
-        checks on the merged result.
+        Merges base_entries and delta_entries, then runs conflict
+        checks on the merged result.
 
@@ -99,11 +87,11 @@ def check_conflicts(self, latest_snapshot, base_entries, delta_entries, commit_k
         all_entries = list(base_entries) + list(delta_entries)
 
         try:
-            merged_entries = merge_entries(all_entries)
+            merged_entries = FileEntry.merge_entries(all_entries)
         except Exception as e:
             return RuntimeError(
                 "File deletion conflicts detected! Give up committing. " + str(e))
-        print("hello1")
+
         conflict = self.check_row_id_range_conflicts(commit_kind, merged_entries)
         if conflict is not None:
             return conflict
@@ -113,7 +101,6 @@
     def check_row_id_range_conflicts(self, commit_kind, commit_entries):
         """Check for row ID range conflicts among merged entries.
 
-        Follows Java ConflictDetection.checkRowIdRangeConflicts logic:
-        only enabled when data evolution is active, and checks that
-        overlapping row ID ranges in non-blob data files are identical.
+        Only enabled when data evolution is active; checks that
+        overlapping row ID ranges in non-blob data files are identical.
@@ -137,14 +124,15 @@
         if not entries_with_row_id:
             return None
 
-        merged_groups = _merge_overlapping_row_id_ranges(entries_with_row_id)
+        range_helper = RangeHelper(lambda entry: entry.file.row_id_range())
+        merged_groups = range_helper.merge_overlapping_ranges(entries_with_row_id)
 
         for group in merged_groups:
             data_files = [
                 entry for entry in group
                 if not DataFileMeta.is_blob_file(entry.file.file_name)
             ]
-            if not _are_all_row_id_ranges_same(data_files):
+            if not range_helper.are_all_ranges_same(data_files):
                 file_descriptions = [
                     "{name}(rowId={row_id}, count={count})".format(
                         name=entry.file.file_name,
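
Since check_conflicts now routes through FileEntry.merge_entries, here is a compact sketch of its cancellation semantics. The SimpleNamespace objects are hand-rolled stand-ins, not real ManifestEntry instances; only the .kind, .identifier() and .file.file_name members that merge_entries touches are mimicked.

    from types import SimpleNamespace

    from pypaimon.manifest.schema.file_entry import FileEntry

    def fake_entry(kind, name):
        # kind 0 = ADD, kind 1 = DELETE; the file name doubles as the identity.
        entry = SimpleNamespace(kind=kind, file=SimpleNamespace(file_name=name))
        entry.identifier = lambda name=name: name
        return entry

    merged = FileEntry.merge_entries(
        [fake_entry(0, "a"), fake_entry(0, "b"), fake_entry(1, "a")])
    # The ADD and DELETE of "a" cancel out; adding "a" twice would raise.
    assert [e.file.file_name for e in merged] == ["b"]
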
@@ -163,7 +151,6 @@
     def check_for_row_id_from_snapshot(self, latest_snapshot, commit_entries):
         """Check for row ID conflicts from a specific snapshot onwards.
 
-        Follows Java ConflictDetection.checkForRowIdFromSnapshot logic:
-        collects row ID ranges from delta entries, then checks if any
-        incremental changes between the check snapshot and latest snapshot
-        have overlapping row ID ranges.
+        Collects row ID ranges from delta entries, then checks if any
+        incremental changes between the check snapshot and the latest
+        snapshot have overlapping row ID ranges.
@@ -180,7 +167,10 @@
         if self._row_id_check_from_snapshot is None:
             return None
 
-        changed_parts = changed_partitions(commit_entries)
+        changed_partitions = set()
+        for entry in commit_entries:
+            partition_key = tuple(entry.partition.values)
+            changed_partitions.add(partition_key)
 
         history_id_ranges = []
         for entry in commit_entries:
@@ -208,7 +198,7 @@
             incremental_entries = self._read_incremental_entries(
-                snapshot, changed_parts)
+                snapshot, changed_partitions)
             for entry in incremental_entries:
                 file_range = entry.file.row_id_range()
                 if file_range is None:
                     continue
@@ -228,7 +218,6 @@
     def _read_incremental_entries(self, snapshot, partition_filter):
         """Read incremental manifest entries from a snapshot's delta manifest list.
 
-        Follows Java CommitScanner.readIncrementalEntries logic:
-        reads the delta manifest list and filters entries by partition.
+        Reads the delta manifest list and filters entries by partition.
 
         Args:
@@ -253,168 +242,3 @@
             entry for entry in all_entries
             if tuple(entry.partition.values) in partition_filter
         ]
-
-
-def changed_partitions(commit_entries):
-    """Extract unique changed partitions from commit entries.
-
-    Follows Java ManifestEntryChanges.changedPartitions logic.
-
-    Args:
-        commit_entries: List of ManifestEntry to extract partitions from.
-
-    Returns:
-        Set of partition tuples.
-    """
-    partitions = set()
-    for entry in commit_entries:
-        partition_key = tuple(entry.partition.values)
-        partitions.add(partition_key)
-    return partitions
-
-
-def _merge_overlapping_row_id_ranges(entries):
-    """Merge entries with overlapping row ID ranges into groups.
-
-    Follows Java RangeHelper.mergeOverlappingRanges logic:
-    sorts entries by row ID range start, then merges overlapping groups.
-
-    Args:
-        entries: List of ManifestEntry with non-null first_row_id.
-
-    Returns:
-        List of groups, where each group is a list of entries
-        with overlapping row ID ranges.
-    """
-    if not entries:
-        return []
-
-    indexed = []
-    for i, entry in enumerate(entries):
-        row_range = entry.file.row_id_range()
-        if row_range is not None:
-            indexed.append((entry, row_range, i))
-
-    if not indexed:
-        return []
-
-    indexed.sort(key=lambda item: (item[1].from_, item[1].to))
-
-    groups = []
-    current_group = [indexed[0]]
-    current_end = indexed[0][1].to
-
-    for i in range(1, len(indexed)):
-        entry, row_range, original_index = indexed[i]
-        if row_range.from_ <= current_end:
-            current_group.append(indexed[i])
-            if row_range.to > current_end:
-                current_end = row_range.to
-        else:
-            groups.append(current_group)
-            current_group = [indexed[i]]
-            current_end = row_range.to
-
-    groups.append(current_group)
-
-    result = []
-    for group in groups:
-        group.sort(key=lambda item: item[2])
-        result.append([item[0] for item in group])
-
-    return result
-
-
-def _are_all_row_id_ranges_same(entries):
-    """Check if all entries have the same row ID range.
-
-    Follows Java RangeHelper.areAllRangesSame logic.
-
-    Args:
-        entries: List of ManifestEntry to check.
-
-    Returns:
-        True if all entries have the same row ID range, False otherwise.
- """ - if not entries: - return True - - first_range = entries[0].file.row_id_range() - if first_range is None: - return False - - for entry in entries[1:]: - entry_range = entry.file.row_id_range() - if entry_range is None: - return False - if entry_range.from_ != first_range.from_ or entry_range.to != first_range.to: - return False - - return True - - -def _entry_identifier(entry): - """Build a unique identifier tuple for a ManifestEntry. - - Follows Java FileEntry.Identifier logic: uses partition, bucket, level, - fileName, extraFiles, embeddedIndex and externalPath to identify a file. - - Args: - entry: A ManifestEntry instance. - - Returns: - A hashable identifier tuple. - """ - partition_key = tuple(entry.partition.values) - extra_files = tuple(entry.file.extra_files) if entry.file.extra_files else () - embedded_index = (bytes(entry.file.embedded_index) - if entry.file.embedded_index is not None else None) - return ( - partition_key, - entry.bucket, - entry.file.level, - entry.file.file_name, - extra_files, - embedded_index, - entry.file.external_path, - ) - - -def merge_entries(entries): - """Merge manifest entries: ADD and DELETE of the same file cancel each other. - - Follows Java FileEntry.mergeEntries logic: - - ADD: if identifier already in map, raise error; otherwise add to map - - DELETE: if identifier already in map, remove both (cancel); otherwise add to map - - Args: - entries: Iterable of ManifestEntry. - - Returns: - List of merged ManifestEntry values. - - Raises: - RuntimeError: If trying to add a file that is already in the map. - """ - entry_map = {} - insertion_order = [] - - for entry in entries: - identifier = _entry_identifier(entry) - if entry.kind == 0: # ADD - if identifier in entry_map: - raise RuntimeError( - "Trying to add file {} which is already added.".format( - entry.file.file_name)) - entry_map[identifier] = entry - insertion_order.append(identifier) - elif entry.kind == 1: # DELETE - if identifier in entry_map: - del entry_map[identifier] - else: - entry_map[identifier] = entry - insertion_order.append(identifier) - else: - raise RuntimeError("Unknown entry kind: {}".format(entry.kind)) - - return [entry_map[key] for key in insertion_order if key in entry_map] diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py index a41124d2dc5e..423abe64dfc1 100644 --- a/paimon-python/pypaimon/write/file_store_commit.py +++ b/paimon-python/pypaimon/write/file_store_commit.py @@ -140,8 +140,8 @@ def commit(self, commit_messages: List[CommitMessage], commit_identifier: int): commit_kind = "APPEND" detect_conflicts = False allow_rollback = False - if self.conflict_detection.should_be_overwrite_commit(commit_entries): - # commit_kind = "OVERWRITE" + if self.conflict_detection.should_be_overwrite_commit(): + commit_kind = "OVERWRITE" detect_conflicts = True allow_rollback = True @@ -507,9 +507,9 @@ def _is_duplicate_commit(self, retry_result, latest_snapshot, commit_identifier, def _read_all_entries_from_changed_partitions(self, latest_snapshot, delta_entries): """Read all entries from the latest snapshot for partitions that are changed. - Follows Java CommitScanner.readAllEntriesFromChangedPartitions logic: - extracts changed partitions from delta entries, then reads all entries - from the latest snapshot filtered by those partitions. 
+        Builds a partition predicate from delta entries and passes it to FileScanner,
+        so that manifest files and entries are filtered during reading rather than
+        after a full scan.
 
         Args:
             latest_snapshot: The latest snapshot to read entries from.
@@ -522,19 +522,43 @@
         if latest_snapshot is None:
             return []
 
-        changed_partition_set = set()
-        for entry in delta_entries:
-            changed_partition_set.add(tuple(entry.partition.values))
+        partition_filter = self._build_partition_filter_from_entries(delta_entries)
 
         all_manifests = self.manifest_list_manager.read_all(latest_snapshot)
-        all_entries = FileScanner(
-            self.table, lambda: [], None
+        return FileScanner(
+            self.table, lambda: [], partition_filter
         ).read_manifest_entries(all_manifests)
 
-        return [
-            entry for entry in all_entries
-            if tuple(entry.partition.values) in changed_partition_set
-        ]
+    def _build_partition_filter_from_entries(self, entries):
+        """Build a partition predicate matching every partition present in the entries.
+
+        Args:
+            entries: List of ManifestEntry whose partitions should be matched.
+
+        Returns:
+            A Predicate matching any of the changed partitions, or None if
+            the table has no partition keys or no partitions changed.
+        """
+        partition_keys = self.table.partition_keys
+        if not partition_keys:
+            return None
+
+        changed_partitions = set()
+        for entry in entries:
+            changed_partitions.add(tuple(entry.partition.values))
+
+        if not changed_partitions:
+            return None
+
+        predicate_builder = PredicateBuilder(self.table.fields)
+        partition_predicates = []
+        for partition_values in changed_partitions:
+            sub_predicates = []
+            for i, key in enumerate(partition_keys):
+                sub_predicates.append(predicate_builder.equal(key, partition_values[i]))
+            partition_predicates.append(predicate_builder.and_predicates(sub_predicates))
+
+        return predicate_builder.or_predicates(partition_predicates)
 
     def _generate_overwrite_entries(self, latestSnapshot, partition_filter, commit_messages):
         """Generate commit entries for OVERWRITE mode based on latest snapshot."""
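
To see the shape of the filter _build_partition_filter_from_entries produces, here is the same OR-of-ANDs evaluated in plain Python for a hypothetical table partitioned by (dt, hr). The real code builds it through PredicateBuilder instead, and the partition values below are made up.

    def matches_changed_partition(partition_values, changed_partitions):
        # (dt = d1 AND hr = h1) OR (dt = d2 AND hr = h2) OR ...
        return any(
            all(value == expected
                for value, expected in zip(partition_values, changed))
            for changed in changed_partitions)

    changed = {("2026-02-27", "00"), ("2026-02-28", "12")}
    assert matches_changed_partition(("2026-02-27", "00"), changed)
    assert not matches_changed_partition(("2026-02-27", "12"), changed)
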
From 2a1a0bddb62e8b287f85ffe2f4a084f25b1ee308 Mon Sep 17 00:00:00 2001
From: umi
Date: Sat, 28 Feb 2026 17:55:18 +0800
Subject: [PATCH 4/4] commitMessage

---
 .../pypaimon/catalog/catalog_environment.py   |  2 +-
 .../pypaimon/tests/table_update_test.py       |  6 ++----
 .../pypaimon/write/commit_message.py          |  1 +
 .../pypaimon/write/conflict_detection.py      |  4 ----
 .../pypaimon/write/file_store_commit.py       | 21 ++++++++++++++------
 paimon-python/pypaimon/write/table_commit.py  | 16 ---------------
 paimon-python/pypaimon/write/table_update.py  |  4 ++--
 .../pypaimon/write/table_update_by_row_id.py  |  3 ++-
 8 files changed, 23 insertions(+), 34 deletions(-)

diff --git a/paimon-python/pypaimon/catalog/catalog_environment.py b/paimon-python/pypaimon/catalog/catalog_environment.py
index b820c7945b3c..61e480bd608f 100644
--- a/paimon-python/pypaimon/catalog/catalog_environment.py
+++ b/paimon-python/pypaimon/catalog/catalog_environment.py
@@ -19,7 +19,7 @@
 from typing import Optional
 
 from pypaimon.catalog.catalog_loader import CatalogLoader
-from pypaimon.catalog.table_rollback import TableRollback, CatalogTableRollback
+from pypaimon.catalog.table_rollback import CatalogTableRollback
 from pypaimon.common.identifier import Identifier
 from pypaimon.snapshot.catalog_snapshot_commit import CatalogSnapshotCommit
 from pypaimon.snapshot.renaming_snapshot_commit import RenamingSnapshotCommit
diff --git a/paimon-python/pypaimon/tests/table_update_test.py b/paimon-python/pypaimon/tests/table_update_test.py
index 5dc161e06813..8006cd8ebb8b 100644
--- a/paimon-python/pypaimon/tests/table_update_test.py
+++ b/paimon-python/pypaimon/tests/table_update_test.py
@@ -15,7 +15,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-import logging
 import os
 import shutil
 import tempfile
@@ -899,9 +898,9 @@ def update_worker(thread_index, update_spec):
                     'age': update_spec['ages'],
                 })
 
-                commit_messages, snapshot_id = table_update.update_by_arrow_with_row_id(update_data)
+                commit_messages = table_update.update_by_arrow_with_row_id(update_data)
 
-                table_commit = write_builder.new_commit().row_id_check_conflict(snapshot_id)
+                table_commit = write_builder.new_commit()
                 table_commit.commit(commit_messages)
                 table_commit.close()
 
@@ -959,7 +958,6 @@ def test_concurrent_updates_same_rows_with_retry(self):
         successful updates for each row (last-writer-wins).
         """
         import threading
-        import traceback
 
         table = self._create_table()
 
diff --git a/paimon-python/pypaimon/write/commit_message.py b/paimon-python/pypaimon/write/commit_message.py
index b36a1b1bbf4f..32c4556a90c5 100644
--- a/paimon-python/pypaimon/write/commit_message.py
+++ b/paimon-python/pypaimon/write/commit_message.py
@@ -27,6 +27,7 @@ class CommitMessage:
     partition: Tuple
     bucket: int
     new_files: List[DataFileMeta]
+    snapshot_for_update: int = -1  # base snapshot of a row-id update; -1 if unset
 
     def is_empty(self):
         return not self.new_files
diff --git a/paimon-python/pypaimon/write/conflict_detection.py b/paimon-python/pypaimon/write/conflict_detection.py
index 47c3b5065a78..0d8efc47ac86 100644
--- a/paimon-python/pypaimon/write/conflict_detection.py
+++ b/paimon-python/pypaimon/write/conflict_detection.py
@@ -55,10 +55,6 @@ def __init__(self, data_evolution_enabled, snapshot_manager,
         self.table = table
         self._row_id_check_from_snapshot = None
 
-    def set_row_id_check_from_snapshot(self, row_id_check_from_snapshot):
-        """Set the snapshot ID from which to check row ID conflicts."""
-        self._row_id_check_from_snapshot = row_id_check_from_snapshot
-
     def should_be_overwrite_commit(self):
         """Check if the commit should be treated as an overwrite commit.
 
diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py
index 423abe64dfc1..abad687b1f1d 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -107,17 +107,14 @@ def __init__(self, snapshot_commit: SnapshotCommit, table, commit_user: str):
         if table_rollback is not None:
             self.rollback = CommitRollback(table_rollback)
 
-    def row_id_check_conflict(self, row_id_check_from_snapshot):
-        """Set the snapshot ID from which to check row ID conflicts."""
-        self.conflict_detection.set_row_id_check_from_snapshot(
-            row_id_check_from_snapshot)
-        return self
-
     def commit(self, commit_messages: List[CommitMessage], commit_identifier: int):
         """Commit the given commit messages in normal append mode."""
         if not commit_messages:
             return
 
+        # Pick up snapshot_for_update from the messages for row-id conflict checking
+        self._apply_row_id_check(commit_messages)
+
         logger.info(
@@ -151,6 +148,18 @@
             detect_conflicts=detect_conflicts,
             allow_rollback=allow_rollback)
 
+    def _apply_row_id_check(self, commit_messages: List[CommitMessage]):
+        """Extract snapshot_for_update from commit messages and apply to conflict detection.
+        If any commit message has a snapshot_for_update != -1, set it on the
+        conflict detection instance for row ID conflict checking.
+        """
+        for msg in commit_messages:
+            if msg.snapshot_for_update != -1:
+                # The public setter was removed; set the internal field directly.
+                self.conflict_detection._row_id_check_from_snapshot = msg.snapshot_for_update
+                return
+
     def overwrite(self, overwrite_partition, commit_messages: List[CommitMessage], commit_identifier: int):
         """Commit the given commit messages in overwrite mode."""
         if not commit_messages:
diff --git a/paimon-python/pypaimon/write/table_commit.py b/paimon-python/pypaimon/write/table_commit.py
index 9c37a27eb011..1eafafefc09f 100644
--- a/paimon-python/pypaimon/write/table_commit.py
+++ b/paimon-python/pypaimon/write/table_commit.py
@@ -73,22 +73,6 @@ def abort(self, commit_messages: List[CommitMessage]):
     def close(self):
         self.file_store_commit.close()
 
-    def row_id_check_conflict(self, row_id_check_from_snapshot):
-        """Set the snapshot ID from which to check row ID conflicts.
-
-        Follows Java TableCommitImpl.rowIdCheckConflict logic:
-        forwards the call to FileStoreCommit.row_id_check_conflict().
-
-        Args:
-            row_id_check_from_snapshot: The snapshot ID from which to start
-                checking row ID conflicts, or None to disable.
-
-        Returns:
-            self for method chaining.
-        """
-        self.file_store_commit.row_id_check_conflict(row_id_check_from_snapshot)
-        return self
-
     def _check_committed(self):
         if self.batch_committed:
             raise RuntimeError("BatchTableCommit only supports one-time committing.")
diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py
index 596f5c447aea..9ad86aa95dcd 100644
--- a/paimon-python/pypaimon/write/table_update.py
+++ b/paimon-python/pypaimon/write/table_update.py
@@ -122,10 +122,10 @@ def new_shard_updator(self, shard_num: int, total_shard_count: int):
             total_shard_count,
         )
 
-    def update_by_arrow_with_row_id(self, table: pa.Table) -> (List[CommitMessage], int):
+    def update_by_arrow_with_row_id(self, table: pa.Table) -> List[CommitMessage]:
         update_by_row_id = TableUpdateByRowId(self.table, self.commit_user)
         update_by_row_id.update_columns(table, self.update_cols)
-        return update_by_row_id.commit_messages, update_by_row_id.snapshot_id
+        return update_by_row_id.commit_messages
 
 
 class ShardTableUpdator:
diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py b/paimon-python/pypaimon/write/table_update_by_row_id.py
index ffee5ddceb2a..a0e737aec12c 100644
--- a/paimon-python/pypaimon/write/table_update_by_row_id.py
+++ b/paimon-python/pypaimon/write/table_update_by_row_id.py
@@ -309,8 +309,9 @@ def _write_group(self, partition: GenericRow, first_row_id: int,
         # Prepare commit and assign first_row_id
         commit_messages = file_store_write.prepare_commit(BATCH_COMMIT_IDENTIFIER)
 
-        # Assign first_row_id to the new files
+        # Assign first_row_id to the new files and snapshot_for_update
        for msg in commit_messages:
+            msg.snapshot_for_update = self.snapshot_id
             for file in msg.new_files:
                 # Assign the same first_row_id as the original file
                 file.first_row_id = first_row_id
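
Taken together, the series moves the row-id conflict handshake out of the commit API and into CommitMessage itself. Below is a condensed sketch of a writer retrying under this scheme, modelled on the tests above; `table` is assumed to be an existing data-evolution table, and the 'age' column and retry count are illustrative.

    import pyarrow as pa

    def update_with_retry(table, row_ids, ages, max_retries=20):
        """Retry a row-id update until it commits without conflict."""
        for _ in range(max_retries):
            write_builder = table.new_batch_write_builder()
            table_update = write_builder.new_update().with_update_type(['age'])
            data = pa.Table.from_pydict({'_ROW_ID': row_ids, 'age': ages})
            # snapshot_for_update now rides along inside each CommitMessage,
            # so no explicit row_id_check_conflict() call is needed.
            commit_messages = table_update.update_by_arrow_with_row_id(data)
            table_commit = write_builder.new_commit()
            try:
                table_commit.commit(commit_messages)
                return
            except Exception:
                # Conflict detected; a blocking COMPACT snapshot may have been
                # rolled back by CommitRollback. Rebuild the update and retry.
                continue
            finally:
                table_commit.close()
        raise RuntimeError("Update did not succeed after %d retries" % max_retries)
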