diff --git a/paimon-python/pypaimon/api/api_request.py b/paimon-python/pypaimon/api/api_request.py index f2d062f8d005..f191fb03d3da 100644 --- a/paimon-python/pypaimon/api/api_request.py +++ b/paimon-python/pypaimon/api/api_request.py @@ -26,6 +26,7 @@ from pypaimon.schema.schema_change import SchemaChange from pypaimon.snapshot.snapshot import Snapshot from pypaimon.snapshot.snapshot_commit import PartitionStatistics +from pypaimon.table.instant import Instant class RESTRequest(ABC): @@ -84,3 +85,12 @@ class AlterTableRequest(RESTRequest): FIELD_CHANGES = "changes" changes: List[SchemaChange] = json_field(FIELD_CHANGES) + + +@dataclass +class RollbackTableRequest(RESTRequest): + FIELD_INSTANT = "instant" + FIELD_FROM_SNAPSHOT = "fromSnapshot" + + instant: Instant = json_field(FIELD_INSTANT) + from_snapshot: Optional[int] = json_field(FIELD_FROM_SNAPSHOT) diff --git a/paimon-python/pypaimon/api/resource_paths.py b/paimon-python/pypaimon/api/resource_paths.py index 79adc0ade3a1..16967d16e23e 100644 --- a/paimon-python/pypaimon/api/resource_paths.py +++ b/paimon-python/pypaimon/api/resource_paths.py @@ -70,3 +70,7 @@ def rename_table(self) -> str: def commit_table(self, database_name: str, table_name: str) -> str: return ("{}/{}/{}/{}/{}/commit".format(self.base_path, self.DATABASES, RESTUtil.encode_string(database_name), self.TABLES, RESTUtil.encode_string(table_name))) + + def rollback_table(self, database_name: str, table_name: str) -> str: + return ("{}/{}/{}/{}/{}/rollback".format(self.base_path, self.DATABASES, RESTUtil.encode_string(database_name), + self.TABLES, RESTUtil.encode_string(table_name))) diff --git a/paimon-python/pypaimon/api/rest_api.py b/paimon-python/pypaimon/api/rest_api.py index dc2c02104b86..069220d68775 100755 --- a/paimon-python/pypaimon/api/rest_api.py +++ b/paimon-python/pypaimon/api/rest_api.py @@ -20,7 +20,8 @@ from pypaimon.api.api_request import (AlterDatabaseRequest, AlterTableRequest, CommitTableRequest, CreateDatabaseRequest, - 
CreateTableRequest, RenameTableRequest) + CreateTableRequest, RenameTableRequest, + RollbackTableRequest) from pypaimon.api.api_response import (CommitTableResponse, ConfigResponse, GetDatabaseResponse, GetTableResponse, GetTableTokenResponse, @@ -358,6 +359,27 @@ def commit_snapshot( ) return response.is_success() + def rollback_to(self, identifier, instant, from_snapshot=None): + """Rollback table to the given instant. + + Args: + identifier: The table identifier. + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + + Raises: + NoSuchResourceException: If the table, snapshot or tag does not exist. + ForbiddenException: If no permission to access this table. + """ + database_name, table_name = self.__validate_identifier(identifier) + request = RollbackTableRequest(instant=instant, from_snapshot=from_snapshot) + self.client.post( + self.resource_paths.rollback_table(database_name, table_name), + request, + self.rest_auth_function + ) + @staticmethod def __validate_identifier(identifier: Identifier): if not identifier: diff --git a/paimon-python/pypaimon/catalog/catalog.py b/paimon-python/pypaimon/catalog/catalog.py index 729522450ba7..aa914d297825 100644 --- a/paimon-python/pypaimon/catalog/catalog.py +++ b/paimon-python/pypaimon/catalog/catalog.py @@ -95,6 +95,23 @@ def commit_snapshot( """ + def rollback_to(self, identifier, instant, from_snapshot=None): + """Rollback table by the given identifier and instant. + + Args: + identifier: Path of the table (Identifier instance). + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + + Raises: + TableNotExistException: If the table does not exist. + UnsupportedOperationError: If the catalog does not support version management. 
+ """ + raise NotImplementedError( + "rollback_to is not supported by this catalog." + ) + def drop_partitions( self, identifier: Union[str, Identifier], diff --git a/paimon-python/pypaimon/catalog/catalog_environment.py b/paimon-python/pypaimon/catalog/catalog_environment.py index 762a42dd6a7c..61e480bd608f 100644 --- a/paimon-python/pypaimon/catalog/catalog_environment.py +++ b/paimon-python/pypaimon/catalog/catalog_environment.py @@ -19,6 +19,7 @@ from typing import Optional from pypaimon.catalog.catalog_loader import CatalogLoader +from pypaimon.catalog.table_rollback import CatalogTableRollback from pypaimon.common.identifier import Identifier from pypaimon.snapshot.catalog_snapshot_commit import CatalogSnapshotCommit from pypaimon.snapshot.renaming_snapshot_commit import RenamingSnapshotCommit @@ -60,6 +61,22 @@ def snapshot_commit(self, snapshot_manager) -> Optional[SnapshotCommit]: # to create locks based on the catalog lock context return RenamingSnapshotCommit(snapshot_manager) + def catalog_table_rollback(self): + """Create a TableRollback instance based on the catalog environment configuration. + + If catalog_loader is available and version management is supported, + returns a CatalogTableRollback that delegates to catalog.rollback_to. + Otherwise, returns None. + + Returns: + A TableRollback instance or None. + """ + if self.catalog_loader is not None and self.supports_version_management: + catalog = self.catalog_loader.load() + identifier = self.identifier + return CatalogTableRollback(catalog, identifier) + return None + def copy(self, identifier: Identifier) -> 'CatalogEnvironment': """ Create a copy of this CatalogEnvironment with a different identifier. 
diff --git a/paimon-python/pypaimon/catalog/rest/rest_catalog.py b/paimon-python/pypaimon/catalog/rest/rest_catalog.py index 9cf138c2543a..dab7633bc91c 100644 --- a/paimon-python/pypaimon/catalog/rest/rest_catalog.py +++ b/paimon-python/pypaimon/catalog/rest/rest_catalog.py @@ -256,6 +256,28 @@ def alter_table( except ForbiddenException as e: raise TableNoPermissionException(identifier) from e + def rollback_to(self, identifier, instant, from_snapshot=None): + """Rollback table by the given identifier and instant. + + Args: + identifier: Path of the table (Identifier or string). + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + + Raises: + TableNotExistException: If the table does not exist. + TableNoPermissionException: If no permission to access this table. + """ + if not isinstance(identifier, Identifier): + identifier = Identifier.from_string(identifier) + try: + self.rest_api.rollback_to(identifier, instant, from_snapshot) + except NoSuchResourceException as e: + raise TableNotExistException(identifier) from e + except ForbiddenException as e: + raise TableNoPermissionException(identifier) from e + def load_table_metadata(self, identifier: Identifier) -> TableMetadata: try: response = self.rest_api.get_table(identifier) diff --git a/paimon-python/pypaimon/catalog/table_rollback.py b/paimon-python/pypaimon/catalog/table_rollback.py new file mode 100644 index 000000000000..70f155ce40e6 --- /dev/null +++ b/paimon-python/pypaimon/catalog/table_rollback.py @@ -0,0 +1,62 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. 
You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from abc import ABC, abstractmethod + + +class TableRollback(ABC): + """Rollback table to instant from snapshot. + """ + + @abstractmethod + def rollback_to(self, instant, from_snapshot=None): + """Rollback table to the given instant. + + Args: + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + """ + + +class CatalogTableRollback(TableRollback): + """ + Internal TableRollback implementation that delegates to catalog.rollback_to. + """ + + def __init__(self, catalog, identifier): + self._catalog = catalog + self._identifier = identifier + + def rollback_to(self, instant, from_snapshot=None): + """Rollback table to the given instant via catalog. + + Args: + instant: The Instant (SnapshotInstant or TagInstant) to rollback to. + from_snapshot: Optional snapshot ID. Success only occurs when the + latest snapshot is this snapshot. + + Raises: + RuntimeError: If the table does not exist in the catalog. 
+ """ + try: + self._catalog.rollback_to(self._identifier, instant, from_snapshot) + except Exception as e: + raise RuntimeError( + "Failed to rollback table {}: {}".format( + self._identifier, e)) from e diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py index 0ed50918253c..5975fcbc9fa9 100644 --- a/paimon-python/pypaimon/manifest/manifest_file_manager.py +++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py @@ -53,17 +53,6 @@ def read_entries_parallel(self, manifest_files: List[ManifestFileMeta], manifest def _process_single_manifest(manifest_file: ManifestFileMeta) -> List[ManifestEntry]: return self.read(manifest_file.file_name, manifest_entry_filter, drop_stats) - def _entry_identifier(e: ManifestEntry) -> tuple: - return ( - tuple(e.partition.values), - e.bucket, - e.file.level, - e.file.file_name, - tuple(e.file.extra_files) if e.file.extra_files else (), - e.file.embedded_index, - e.file.external_path, - ) - deleted_entry_keys = set() added_entries = [] with ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -73,11 +62,11 @@ def _entry_identifier(e: ManifestEntry) -> tuple: if entry.kind == 0: # ADD added_entries.append(entry) else: # DELETE - deleted_entry_keys.add(_entry_identifier(entry)) + deleted_entry_keys.add(entry.identifier()) final_entries = [ entry for entry in added_entries - if _entry_identifier(entry) not in deleted_entry_keys + if entry.identifier() not in deleted_entry_keys ] return final_entries diff --git a/paimon-python/pypaimon/manifest/schema/file_entry.py b/paimon-python/pypaimon/manifest/schema/file_entry.py new file mode 100644 index 000000000000..3abab4495c28 --- /dev/null +++ b/paimon-python/pypaimon/manifest/schema/file_entry.py @@ -0,0 +1,129 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +"""Entry representing a file. + +Follows the design of Java's org.apache.paimon.manifest.FileEntry. +""" + + +class FileEntry: + """Entry representing a file. + + The same Identifier indicates that the FileEntry refers to the same data file. + """ + + class Identifier: + """Unique identifier for a file entry. + + Uses partition, bucket, level, fileName, extraFiles, + embeddedIndex and externalPath to identify a file. 
+ """ + + def __init__(self, partition, bucket, level, file_name, + extra_files, embedded_index, external_path): + self.partition = partition + self.bucket = bucket + self.level = level + self.file_name = file_name + self.extra_files = extra_files + self.embedded_index = embedded_index + self.external_path = external_path + self._hash = None + + def __eq__(self, other): + if self is other: + return True + if other is None or not isinstance(other, FileEntry.Identifier): + return False + return (self.bucket == other.bucket + and self.level == other.level + and self.partition == other.partition + and self.file_name == other.file_name + and self.extra_files == other.extra_files + and self.embedded_index == other.embedded_index + and self.external_path == other.external_path) + + def __hash__(self): + if self._hash is None: + self._hash = hash(( + self.partition, + self.bucket, + self.level, + self.file_name, + self.extra_files, + self.embedded_index, + self.external_path, + )) + return self._hash + + def identifier(self): + """Build a unique Identifier for this file entry. + + Returns: + An Identifier instance. + """ + extra_files = (tuple(self.file.extra_files) + if self.file.extra_files else ()) + return FileEntry.Identifier( + partition=self.partition, + bucket=self.bucket, + level=self.file.level, + file_name=self.file.file_name, + extra_files=extra_files, + embedded_index=self.file.embedded_index, + external_path=self.file.external_path, + ) + + @staticmethod + def merge_entries(entries): + """Merge file entries: ADD and DELETE of the same file cancel each other. + + - ADD: if identifier already in map, raise error; otherwise add to map. + - DELETE: if identifier already in map, remove both (cancel); + otherwise add to map. + + Args: + entries: Iterable of FileEntry. + + Returns: + List of merged FileEntry values, preserving insertion order. + + Raises: + RuntimeError: If trying to add a file that is already in the map. 
+ """ + entry_map = {} + + for entry in entries: + entry_identifier = entry.identifier() + if entry.kind == 0: # ADD + if entry_identifier in entry_map: + raise RuntimeError( + "Trying to add file {} which is already added.".format( + entry.file.file_name)) + entry_map[entry_identifier] = entry + elif entry.kind == 1: # DELETE + if entry_identifier in entry_map: + del entry_map[entry_identifier] + else: + entry_map[entry_identifier] = entry + else: + raise RuntimeError( + "Unknown entry kind: {}".format(entry.kind)) + + return list(entry_map.values()) diff --git a/paimon-python/pypaimon/manifest/schema/manifest_entry.py b/paimon-python/pypaimon/manifest/schema/manifest_entry.py index b1fd244dafc0..eba241786387 100644 --- a/paimon-python/pypaimon/manifest/schema/manifest_entry.py +++ b/paimon-python/pypaimon/manifest/schema/manifest_entry.py @@ -20,11 +20,12 @@ from pypaimon.manifest.schema.data_file_meta import (DATA_FILE_META_SCHEMA, DataFileMeta) +from pypaimon.manifest.schema.file_entry import FileEntry from pypaimon.table.row.generic_row import GenericRow @dataclass -class ManifestEntry: +class ManifestEntry(FileEntry): kind: int partition: GenericRow bucket: int diff --git a/paimon-python/pypaimon/schema/data_types.py b/paimon-python/pypaimon/schema/data_types.py index 318ddfe02fcf..6befaa4d4068 100755 --- a/paimon-python/pypaimon/schema/data_types.py +++ b/paimon-python/pypaimon/schema/data_types.py @@ -73,6 +73,16 @@ def __init__(self, type: str, nullable: bool = True): super().__init__(nullable) self.type = type + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, AtomicType): + return False + return self.type == other.type and self.nullable == other.nullable + + def __hash__(self): + return hash((self.type, self.nullable)) + def to_dict(self) -> str: if not self.nullable: return self.type + " NOT NULL" @@ -95,6 +105,16 @@ def __init__(self, nullable: bool, element_type: DataType): super().__init__(nullable) self.element = 
element_type + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, ArrayType): + return False + return self.element == other.element and self.nullable == other.nullable + + def __hash__(self): + return hash((self.element, self.nullable)) + def to_dict(self) -> Dict[str, Any]: return { "type": "ARRAY" + (" NOT NULL" if not self.nullable else ""), @@ -119,6 +139,16 @@ def __init__(self, nullable: bool, element_type: DataType): super().__init__(nullable) self.element = element_type + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, MultisetType): + return False + return self.element == other.element and self.nullable == other.nullable + + def __hash__(self): + return hash((self.element, self.nullable)) + def to_dict(self) -> Dict[str, Any]: return { "type": "MULTISET{}{}".format('<' + str(self.element) + '>' if self.element else '', @@ -150,6 +180,18 @@ def __init__( self.key = key_type self.value = value_type + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, MapType): + return False + return (self.key == other.key + and self.value == other.value + and self.nullable == other.nullable) + + def __hash__(self): + return hash((self.key, self.value, self.nullable)) + def to_dict(self) -> Dict[str, Any]: return { "type": "MAP<{}, {}>".format(self.key, self.value), @@ -199,6 +241,21 @@ def __init__( def from_dict(cls, data: Dict[str, Any]) -> "DataField": return DataTypeParser.parse_data_field(data) + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, DataField): + return False + return (self.id == other.id + and self.name == other.name + and self.type == other.type + and self.description == other.description + and self.default_value == other.default_value) + + def __hash__(self): + return hash((self.id, self.name, self.type, + self.description, self.default_value)) + def to_dict(self) -> Dict[str, Any]: result = { self.FIELD_ID: 
self.id, @@ -223,6 +280,16 @@ def __init__(self, nullable: bool, fields: List[DataField]): super().__init__(nullable) self.fields = fields or [] + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, RowType): + return False + return self.fields == other.fields and self.nullable == other.nullable + + def __hash__(self): + return hash((tuple(self.fields), self.nullable)) + def to_dict(self) -> Dict[str, Any]: return { "type": "ROW" + ("" if self.nullable else " NOT NULL"), @@ -587,7 +654,7 @@ def to_avro_type(field_type: pyarrow.DataType, field_name: str, parent_name: str = "record") -> Union[str, Dict[str, Any]]: if pyarrow.types.is_integer(field_type): if (pyarrow.types.is_signed_integer(field_type) and field_type.bit_width <= 32) or \ - (pyarrow.types.is_unsigned_integer(field_type) and field_type.bit_width < 32): + (pyarrow.types.is_unsigned_integer(field_type) and field_type.bit_width < 32): return "int" else: return "long" diff --git a/paimon-python/pypaimon/table/instant.py b/paimon-python/pypaimon/table/instant.py new file mode 100644 index 000000000000..0cf4383213c4 --- /dev/null +++ b/paimon-python/pypaimon/table/instant.py @@ -0,0 +1,136 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from abc import ABC, abstractmethod + + +class Instant(ABC): + """Table rollback instant, corresponding to Java's Instant interface. + + Supports polymorphic JSON serialization via to_dict/from_dict, + matching Java's Jackson @JsonTypeInfo/@JsonSubTypes annotations. + + Serialization format: + SnapshotInstant: {"type": "snapshot", "snapshotId": 123} + TagInstant: {"type": "tag", "tagName": "test_tag"} + """ + + FIELD_TYPE = "type" + TYPE_SNAPSHOT = "snapshot" + TYPE_TAG = "tag" + + @staticmethod + def snapshot(snapshot_id): + """Create a SnapshotInstant. + + Args: + snapshot_id: The snapshot ID to rollback to. + + Returns: + A SnapshotInstant instance. + """ + return SnapshotInstant(snapshot_id) + + @staticmethod + def tag(tag_name): + """Create a TagInstant. + + Args: + tag_name: The tag name to rollback to. + + Returns: + A TagInstant instance. + """ + return TagInstant(tag_name) + + @abstractmethod + def to_dict(self): + """Serialize this Instant to a dictionary for JSON output.""" + + @staticmethod + def from_dict(data): + """Deserialize an Instant from a dictionary. + + Args: + data: A dictionary with a 'type' field indicating the instant type. + + Returns: + A SnapshotInstant or TagInstant instance. + + Raises: + ValueError: If the type field is missing or unknown. 
+ """ + instant_type = data.get(Instant.FIELD_TYPE) + if instant_type == Instant.TYPE_SNAPSHOT: + return SnapshotInstant(data[SnapshotInstant.FIELD_SNAPSHOT_ID]) + elif instant_type == Instant.TYPE_TAG: + return TagInstant(data[TagInstant.FIELD_TAG_NAME]) + else: + raise ValueError("Unknown instant type: {}".format(instant_type)) + + +class SnapshotInstant(Instant): + """Snapshot instant for table rollback.""" + + FIELD_SNAPSHOT_ID = "snapshotId" + + def __init__(self, snapshot_id): + self.snapshot_id = snapshot_id + + def to_dict(self): + return { + Instant.FIELD_TYPE: Instant.TYPE_SNAPSHOT, + self.FIELD_SNAPSHOT_ID: self.snapshot_id, + } + + def __eq__(self, other): + if not isinstance(other, SnapshotInstant): + return False + return self.snapshot_id == other.snapshot_id + + def __hash__(self): + return hash(self.snapshot_id) + + def __repr__(self): + return "SnapshotInstant(snapshot_id={})".format(self.snapshot_id) + + +class TagInstant(Instant): + """Tag instant for table rollback.""" + + FIELD_TAG_NAME = "tagName" + + def __init__(self, tag_name): + self.tag_name = tag_name + + def to_dict(self): + return { + Instant.FIELD_TYPE: Instant.TYPE_TAG, + self.FIELD_TAG_NAME: self.tag_name, + } + + def __eq__(self, other): + if not isinstance(other, TagInstant): + return False + return self.tag_name == other.tag_name + + def __hash__(self): + return hash(self.tag_name) + + def __repr__(self): + return "TagInstant(tag_name={})".format(self.tag_name) diff --git a/paimon-python/pypaimon/table/row/generic_row.py b/paimon-python/pypaimon/table/row/generic_row.py index be5c1ec80f39..4aa740de7219 100644 --- a/paimon-python/pypaimon/table/row/generic_row.py +++ b/paimon-python/pypaimon/table/row/generic_row.py @@ -51,6 +51,16 @@ def get_row_kind(self) -> RowKind: def __len__(self) -> int: return len(self.values) + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, GenericRow): + return False + return self.values == other.values and 
self.row_kind == other.row_kind + + def __hash__(self): + return hash((tuple(self.values), tuple(self.fields), self.row_kind)) + def __str__(self): field_strs = [f"{field.name}={repr(value)}" for field, value in zip(self.fields, self.values)] return f"GenericRow(row_kind={self.row_kind.name}, {', '.join(field_strs)})" diff --git a/paimon-python/pypaimon/tests/table_update_test.py b/paimon-python/pypaimon/tests/table_update_test.py index ad9158e9febf..8006cd8ebb8b 100644 --- a/paimon-python/pypaimon/tests/table_update_test.py +++ b/paimon-python/pypaimon/tests/table_update_test.py @@ -859,6 +859,196 @@ def test_update_partial_rows_across_two_files(self): 'Seattle', 'Boston', 'Denver', 'Miami', 'Atlanta'] self.assertEqual(expected_cities, cities, "Cities should remain unchanged") + def test_concurrent_updates_with_retry(self): + """Test data evolution with multiple threads performing concurrent updates. + + Each thread updates different rows of the same column. If a conflict occurs, + the thread retries until the update succeeds. After all threads complete, + the final result is verified to ensure all updates were applied correctly. 
+ """ + import threading + import traceback + table = self._create_table() + + # Table has 5 rows (row_id 0-4) after _create_table: + # row 0: age=25, row 1: age=30, row 2: age=35, row 3: age=40, row 4: age=45 + + # Thread 1 updates rows 0 and 1 + # Thread 2 updates rows 2 and 3 + # Thread 3 updates row 4 + thread_updates = [ + {'row_ids': [0, 1], 'ages': [100, 200]}, + {'row_ids': [2, 3], 'ages': [300, 400]}, + {'row_ids': [4], 'ages': [500]}, + ] + + errors = [] + success_counts = [0] * len(thread_updates) + + def update_worker(thread_index, update_spec): + max_retries = 20 + for attempt in range(max_retries): + try: + print("hello0") + write_builder = table.new_batch_write_builder() + table_update = write_builder.new_update().with_update_type(['age']) + + update_data = pa.Table.from_pydict({ + '_ROW_ID': update_spec['row_ids'], + 'age': update_spec['ages'], + }) + + commit_messages = table_update.update_by_arrow_with_row_id(update_data) + + table_commit = write_builder.new_commit() + table_commit.commit(commit_messages) + table_commit.close() + + success_counts[thread_index] = attempt + 1 + return + except Exception as e: + print( + "Thread-{} attempt {} failed: {}\n{}".format( + thread_index, attempt + 1, e, traceback.format_exc() + ) + ) + if attempt == max_retries - 1: + errors.append( + "Thread-{} failed after {} retries: {}".format( + thread_index, max_retries, e + ) + ) + + threads = [] + for idx, spec in enumerate(thread_updates): + thread = threading.Thread(target=update_worker, args=(idx, spec)) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join(timeout=120) + + if errors: + self.fail("Some threads failed:\n" + "\n".join(errors)) + + for idx, count in enumerate(success_counts): + self.assertGreater( + count, 0, + "Thread-{} did not succeed".format(idx) + ) + + # Verify the final data + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = 
read_builder.new_scan().plan().splits() + result = table_read.to_arrow(splits) + + ages = result['age'].to_pylist() + expected_ages = [100, 200, 300, 400, 500] + self.assertEqual(expected_ages, ages, + "Concurrent updates did not produce correct final result") + + def test_concurrent_updates_same_rows_with_retry(self): + """Test data evolution with multiple threads updating overlapping rows. + + Multiple threads compete to update the same rows. Each thread retries + on conflict until success. The final result should reflect one of the + successful updates for each row (last-writer-wins). + """ + import threading + + table = self._create_table() + + # All threads update the same rows but with different values + thread_updates = [ + {'row_ids': [0, 1, 2], 'ages': [101, 201, 301], 'thread_name': 'A'}, + {'row_ids': [0, 1, 2], 'ages': [102, 202, 302], 'thread_name': 'B'}, + {'row_ids': [0, 1, 2], 'ages': [103, 203, 303], 'thread_name': 'C'}, + ] + + errors = [] + completion_order = [] + order_lock = threading.Lock() + + def update_worker(thread_index, update_spec): + max_retries = 30 + for attempt in range(max_retries): + try: + write_builder = table.new_batch_write_builder() + table_update = write_builder.new_update().with_update_type(['age']) + + update_data = pa.Table.from_pydict({ + '_ROW_ID': update_spec['row_ids'], + 'age': update_spec['ages'], + }) + + commit_messages = table_update.update_by_arrow_with_row_id(update_data) + + table_commit = write_builder.new_commit() + table_commit.commit(commit_messages) + table_commit.close() + + with order_lock: + completion_order.append(thread_index) + return + except Exception as e: + print( + "Thread-{} ({}) attempt {} failed: {}".format( + thread_index, update_spec['thread_name'], + attempt + 1, e + ) + ) + if attempt == max_retries - 1: + errors.append( + "Thread-{} ({}) failed after {} retries: {}".format( + thread_index, update_spec['thread_name'], + max_retries, e + ) + ) + + threads = [] + for idx, spec in 
enumerate(thread_updates): + thread = threading.Thread(target=update_worker, args=(idx, spec)) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join(timeout=120) + + if errors: + self.fail("Some threads failed:\n" + "\n".join(errors)) + + self.assertEqual( + len(completion_order), len(thread_updates), + "Not all threads completed successfully" + ) + + # Verify the final data: the last thread to commit wins + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + result = table_read.to_arrow(splits) + + ages = result['age'].to_pylist() + + # The last thread to successfully commit determines rows 0-2 + last_winner = completion_order[-1] + winner_ages = thread_updates[last_winner]['ages'] + self.assertEqual(winner_ages[0], ages[0], + "Row 0 should reflect last writer's value") + self.assertEqual(winner_ages[1], ages[1], + "Row 1 should reflect last writer's value") + self.assertEqual(winner_ages[2], ages[2], + "Row 2 should reflect last writer's value") + + # Rows 3 and 4 should remain unchanged + self.assertEqual(40, ages[3], "Row 3 should remain unchanged") + self.assertEqual(45, ages[4], "Row 4 should remain unchanged") + if __name__ == '__main__': unittest.main() diff --git a/paimon-python/pypaimon/utils/range_helper.py b/paimon-python/pypaimon/utils/range_helper.py new file mode 100644 index 000000000000..5bbc02a55c2d --- /dev/null +++ b/paimon-python/pypaimon/utils/range_helper.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
class RangeHelper:
    """A helper class to handle ranges.

    Provides methods to check if all ranges are the same and to merge
    overlapping ranges into groups, preserving original order within groups.

    Follows the design of Java's org.apache.paimon.utils.RangeHelper.

    Args:
        range_function: A callable that extracts a Range (an object with
            ``from_`` and ``to`` attributes, both inclusive) from an element.
    """

    def __init__(self, range_function):
        self._range_function = range_function

    def are_all_ranges_same(self, items):
        """Check if all items have the same range.

        Args:
            items: List of items to check.

        Returns:
            True if the list is empty or every item shares one identical
            range; False if any item (or its extracted range) is None, or
            if any two ranges differ.
        """
        if not items:
            return True

        # A None first element cannot have a range; treat it exactly like
        # any other None element instead of handing it to range_function
        # (which would crash for extractors that dereference the item).
        first = items[0]
        if first is None:
            return False
        first_range = self._range_function(first)
        if first_range is None:
            return False

        for item in items[1:]:
            if item is None:
                return False
            current_range = self._range_function(item)
            if current_range is None:
                return False
            if current_range.from_ != first_range.from_ or current_range.to != first_range.to:
                return False

        return True

    def merge_overlapping_ranges(self, items):
        """Merge items with overlapping ranges into groups.

        Sorts items by range start (then end), merges overlapping runs into
        groups, and finally sorts each group by the items' original input
        indices so the caller's ordering is preserved within a group.
        Items whose extracted range is None are silently dropped.

        Args:
            items: List of items with non-null ranges.

        Returns:
            List of groups, where each group is a list of items with
            overlapping ranges, in original input order within the group.
        """
        if not items:
            return []

        # Track original indices so each group can be restored to input order.
        indexed = []
        for original_index, item in enumerate(items):
            item_range = self._range_function(item)
            if item_range is not None:
                indexed.append(_IndexedValue(item, item_range, original_index))

        if not indexed:
            return []

        # Sort by range start, then by range end.
        indexed.sort(key=lambda iv: (iv.start(), iv.end()))

        groups = []
        current_group = [indexed[0]]
        current_end = indexed[0].end()

        # Sweep through the sorted ranges, merging any range that starts at
        # or before the running end of the current group.
        for i in range(1, len(indexed)):
            current = indexed[i]
            if current.start() <= current_end:
                current_group.append(current)
                if current.end() > current_end:
                    current_end = current.end()
            else:
                groups.append(current_group)
                current_group = [current]
                current_end = current.end()

        # Add the last group.
        groups.append(current_group)

        # Convert groups to the result, sorting each group by original index.
        result = []
        for group in groups:
            group.sort(key=lambda iv: iv.original_index)
            result.append([iv.value for iv in group])

        return result


class _IndexedValue:
    """A helper class to track original indices during range merging."""

    def __init__(self, value, item_range, original_index):
        self.value = value
        self.range = item_range
        self.original_index = original_index

    def start(self):
        # Inclusive range start.
        return self.range.from_

    def end(self):
        # Inclusive range end.
        return self.range.to
class CommitRollback:
    """Rollback COMPACT commits to resolve conflicts.

    When a conflict is detected during commit, if the latest snapshot is a
    COMPACT commit, it can be rolled back via TableRollback, following the
    logic of Java's org.apache.paimon.operation.commit.CommitRollback.
    """

    def __init__(self, table_rollback):
        """Initialize CommitRollback.

        Args:
            table_rollback: A TableRollback instance used to perform the rollback.
        """
        self._table_rollback = table_rollback

    def try_to_rollback(self, latest_snapshot):
        """Try to rollback a COMPACT commit to resolve conflicts.

        Only COMPACT commits are eligible. Delegates to TableRollback to
        roll back to the previous snapshot (latest - 1), passing the latest
        snapshot ID as ``from_snapshot`` so the server-side rollback only
        succeeds while that snapshot is still the latest.

        Args:
            latest_snapshot: The latest snapshot that may need to be rolled back.

        Returns:
            True if rollback succeeded, False otherwise.
        """
        # Guard clause: anything other than a COMPACT commit is never
        # rolled back here.
        if latest_snapshot.commit_kind != "COMPACT":
            return False

        snapshot_id = latest_snapshot.id
        try:
            self._table_rollback.rollback_to(
                Instant.snapshot(snapshot_id - 1), snapshot_id)
        except Exception:
            # Best-effort by design: a failed rollback simply means the
            # conflict cannot be resolved this way; the caller handles it.
            return False
        return True
+""" + +import logging + +from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.manifest.schema.file_entry import FileEntry +from pypaimon.read.scanner.file_scanner import FileScanner +from pypaimon.utils.range import Range +from pypaimon.utils.range_helper import RangeHelper + +logger = logging.getLogger(__name__) + + +class ConflictDetection: + """Detects conflicts between base and delta files during commit. + + Follows the design of Java's ConflictDetection class, providing + row ID range conflict checks and row ID from snapshot conflict checks + for Data Evolution tables. + """ + + def __init__(self, data_evolution_enabled, snapshot_manager, + manifest_list_manager, table): + """Initialize ConflictDetection. + + Args: + data_evolution_enabled: Whether data evolution feature is enabled. + snapshot_manager: Manager for reading snapshot metadata. + manifest_list_manager: Manager for reading manifest lists. + table: The FileStoreTable instance. + """ + self.data_evolution_enabled = data_evolution_enabled + self.snapshot_manager = snapshot_manager + self.manifest_list_manager = manifest_list_manager + self.table = table + self._row_id_check_from_snapshot = None + + def should_be_overwrite_commit(self): + """Check if the commit should be treated as an overwrite commit. + + returns True if rowIdCheckFromSnapshot is set. + + Returns: + True if the commit should be treated as OVERWRITE. + """ + return self._row_id_check_from_snapshot is not None + + def check_conflicts(self, latest_snapshot, base_entries, delta_entries, commit_kind): + """Run all conflict checks and return the first detected conflict. + + merges base_entries and delta_entries, then runs conflict checks + on the merged result. + + Args: + latest_snapshot: The latest snapshot at commit time. + base_entries: All entries read from the latest snapshot. + delta_entries: The delta entries being committed. + commit_kind: The kind of commit (e.g. "APPEND", "COMPACT", "OVERWRITE"). 
+ + Returns: + A RuntimeError if a conflict is detected, otherwise None. + """ + all_entries = list(base_entries) + list(delta_entries) + + try: + merged_entries = FileEntry.merge_entries(all_entries) + except Exception as e: + return RuntimeError( + "File deletion conflicts detected! Give up committing. " + str(e)) + + conflict = self.check_row_id_range_conflicts(commit_kind, merged_entries) + if conflict is not None: + return conflict + + return self.check_for_row_id_from_snapshot(latest_snapshot, delta_entries) + + def check_row_id_range_conflicts(self, commit_kind, commit_entries): + """Check for row ID range conflicts among merged entries. + + only enabled when data evolution is active, and checks that + overlapping row ID ranges in non-blob data files are identical. + + Args: + commit_kind: The kind of commit (e.g. "APPEND", "COMPACT"). + commit_entries: The entries being committed. + + Returns: + A RuntimeError if conflict is detected, otherwise None. + """ + if not self.data_evolution_enabled: + return None + if self._row_id_check_from_snapshot is None and commit_kind != "COMPACT": + return None + + entries_with_row_id = [ + entry for entry in commit_entries + if entry.file.first_row_id is not None + ] + + if not entries_with_row_id: + return None + + range_helper = RangeHelper(lambda entry: entry.file.row_id_range()) + merged_groups = range_helper.merge_overlapping_ranges(entries_with_row_id) + + for group in merged_groups: + data_files = [ + entry for entry in group + if not DataFileMeta.is_blob_file(entry.file.file_name) + ] + if not range_helper.are_all_ranges_same(data_files): + file_descriptions = [ + "{name}(rowId={row_id}, count={count})".format( + name=entry.file.file_name, + row_id=entry.file.first_row_id, + count=entry.file.row_count, + ) + for entry in data_files + ] + return RuntimeError( + "For Data Evolution table, multiple 'MERGE INTO' and 'COMPACT' " + "operations have encountered conflicts, data files: " + + str(file_descriptions)) + + 
return None + + def check_for_row_id_from_snapshot(self, latest_snapshot, commit_entries): + """Check for row ID conflicts from a specific snapshot onwards. + + collects row ID ranges from delta entries, then checks if any + incremental changes between the check snapshot and latest snapshot + have overlapping row ID ranges. + + Args: + latest_snapshot: The latest snapshot at commit time. + commit_entries: The delta entries being committed. + + Returns: + A RuntimeError if conflict is detected, otherwise None. + """ + if not self.data_evolution_enabled: + return None + if self._row_id_check_from_snapshot is None: + return None + + changed_partitions = set() + for entry in commit_entries: + partition_key = tuple(entry.partition.values) + changed_partitions.add(partition_key) + + history_id_ranges = [] + for entry in commit_entries: + first_row_id = entry.file.first_row_id + row_count = entry.file.row_count + if first_row_id is not None: + history_id_ranges.append( + Range(first_row_id, first_row_id + row_count - 1)) + + check_snapshot = self.snapshot_manager.get_snapshot_by_id( + self._row_id_check_from_snapshot) + if check_snapshot is None or check_snapshot.next_row_id is None: + raise RuntimeError( + "Next row id cannot be null for snapshot " + "{snapshot}.".format(snapshot=self._row_id_check_from_snapshot)) + check_next_row_id = check_snapshot.next_row_id + + for snapshot_id in range( + self._row_id_check_from_snapshot + 1, + latest_snapshot.id + 1): + snapshot = self.snapshot_manager.get_snapshot_by_id(snapshot_id) + if snapshot is None: + continue + if snapshot.commit_kind == "COMPACT": + continue + + incremental_entries = self._read_incremental_entries( + snapshot, changed_partitions) + for entry in incremental_entries: + file_range = entry.file.row_id_range() + if file_range is None: + continue + if file_range.from_ < check_next_row_id: + for history_range in history_id_ranges: + if history_range.overlaps(file_range): + print("conflict2") + return 
RuntimeError( + "For Data Evolution table, multiple 'MERGE INTO' " + "operations have encountered conflicts, updating " + "the same file, which can render some updates " + "ineffective.") + + return None + + def _read_incremental_entries(self, snapshot, partition_filter): + """Read incremental manifest entries from a snapshot's delta manifest list. + + reads the delta manifest list and filters entries by partition. + + Args: + snapshot: The snapshot to read incremental entries from. + partition_filter: Set of partition tuples to filter by. + + Returns: + List of ManifestEntry matching the partition filter. + """ + delta_manifests = self.manifest_list_manager.read_delta(snapshot) + if not delta_manifests: + return [] + + all_entries = FileScanner( + self.table, lambda: [], None + ).read_manifest_entries(delta_manifests) + + if not partition_filter: + return all_entries + + return [ + entry for entry in all_entries + if tuple(entry.partition.values) in partition_filter + ] diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py index 80eb858087ef..abad687b1f1d 100644 --- a/paimon-python/pypaimon/write/file_store_commit.py +++ b/paimon-python/pypaimon/write/file_store_commit.py @@ -27,6 +27,8 @@ from pypaimon.manifest.manifest_list_manager import ManifestListManager from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.manifest.schema.manifest_entry import ManifestEntry +from pypaimon.write.commit_rollback import CommitRollback +from pypaimon.write.conflict_detection import ConflictDetection from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta from pypaimon.manifest.schema.simple_stats import SimpleStats from pypaimon.read.scanner.file_scanner import FileScanner @@ -93,11 +95,26 @@ def __init__(self, snapshot_commit: SnapshotCommit, table, commit_user: str): self.commit_min_retry_wait = table.options.commit_min_retry_wait() self.commit_max_retry_wait = 
table.options.commit_max_retry_wait() + self.conflict_detection = ConflictDetection( + data_evolution_enabled=table.options.data_evolution_enabled(), + snapshot_manager=self.snapshot_manager, + manifest_list_manager=self.manifest_list_manager, + table=table, + ) + + self.rollback = None + table_rollback = table.catalog_environment.catalog_table_rollback() + if table_rollback is not None: + self.rollback = CommitRollback(table_rollback) + def commit(self, commit_messages: List[CommitMessage], commit_identifier: int): """Commit the given commit messages in normal append mode.""" if not commit_messages: return + # Extract snapshot_for_update from commit messages + self._apply_row_id_check(commit_messages) + logger.info( "Ready to commit to table %s, number of commit messages: %d", self.table.identifier, @@ -116,9 +133,31 @@ def commit(self, commit_messages: List[CommitMessage], commit_identifier: int): )) logger.info("Finished collecting changes, including: %d entries", len(commit_entries)) - self._try_commit(commit_kind="APPEND", + + commit_kind = "APPEND" + detect_conflicts = False + allow_rollback = False + if self.conflict_detection.should_be_overwrite_commit(): + commit_kind = "OVERWRITE" + detect_conflicts = True + allow_rollback = True + + self._try_commit(commit_kind=commit_kind, commit_identifier=commit_identifier, - commit_entries_plan=lambda snapshot: commit_entries) + commit_entries_plan=lambda snapshot: commit_entries, + detect_conflicts=detect_conflicts, + allow_rollback=allow_rollback) + + def _apply_row_id_check(self, commit_messages: List[CommitMessage]): + """Extract snapshot_for_update from commit messages and apply to conflict detection. + + If any commit message has a snapshot_for_update != -1, set it on the + conflict detection instance for row ID conflict checking. 
+ """ + for msg in commit_messages: + if msg.snapshot_for_update != -1: + self.conflict_detection._row_id_check_from_snapshot = msg.snapshot_for_update + return def overwrite(self, overwrite_partition, commit_messages: List[CommitMessage], commit_identifier: int): """Commit the given commit messages in overwrite mode.""" @@ -149,7 +188,9 @@ def overwrite(self, overwrite_partition, commit_messages: List[CommitMessage], c commit_kind="OVERWRITE", commit_identifier=commit_identifier, commit_entries_plan=lambda snapshot: self._generate_overwrite_entries( - snapshot, partition_filter, commit_messages) + snapshot, partition_filter, commit_messages), + detect_conflicts=True, + allow_rollback=True, ) def drop_partitions(self, partitions: List[Dict[str, str]], commit_identifier: int) -> None: @@ -187,10 +228,13 @@ def drop_partitions(self, partitions: List[Dict[str, str]], commit_identifier: i commit_kind="OVERWRITE", commit_identifier=commit_identifier, commit_entries_plan=lambda snapshot: self._generate_overwrite_entries( - snapshot, partition_filter, []) + snapshot, partition_filter, []), + detect_conflicts=True, + allow_rollback=True, ) - def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan): + def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan, + detect_conflicts=False, allow_rollback=False): import threading retry_count = 0 @@ -211,7 +255,9 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan): commit_kind=commit_kind, commit_entries=commit_entries, commit_identifier=commit_identifier, - latest_snapshot=latest_snapshot + latest_snapshot=latest_snapshot, + detect_conflicts=detect_conflicts, + allow_rollback=allow_rollback, ) if result.is_success(): @@ -267,11 +313,13 @@ def _try_commit(self, commit_kind, commit_identifier, commit_entries_plan): def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str, commit_entries: List[ManifestEntry], commit_identifier: int, - 
latest_snapshot: Optional[Snapshot]) -> CommitResult: + latest_snapshot: Optional[Snapshot], + detect_conflicts: bool = False, + allow_rollback: bool = False) -> CommitResult: start_millis = int(time.time() * 1000) if self._is_duplicate_commit(retry_result, latest_snapshot, commit_identifier, commit_kind): return SuccessResult() - + unique_id = uuid.uuid4() base_manifest_list = f"manifest-list-{unique_id}-0" delta_manifest_list = f"manifest-list-{unique_id}-1" @@ -296,8 +344,20 @@ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str # Assign row IDs to new files and get the next row ID for the snapshot commit_entries, next_row_id = self._assign_row_tracking_meta(first_row_id_start, commit_entries) + # Conflict detection: read base entries from latest snapshot, then check conflicts + if detect_conflicts and latest_snapshot is not None: + base_entries = self._read_all_entries_from_changed_partitions( + latest_snapshot, commit_entries) + conflict_exception = self.conflict_detection.check_conflicts( + latest_snapshot, base_entries, commit_entries, commit_kind) + + if conflict_exception is not None: + if allow_rollback and self.rollback is not None: + if self.rollback.try_to_rollback(latest_snapshot): + return RetryResult(latest_snapshot, conflict_exception) + raise conflict_exception + try: - # TODO: implement noConflictsOrFail logic new_manifest_file_meta = self._write_manifest_file(commit_entries, new_manifest_file) self.manifest_list_manager.write(delta_manifest_list, [new_manifest_file_meta]) @@ -452,6 +512,62 @@ def _is_duplicate_commit(self, retry_result, latest_snapshot, commit_identifier, return True return False + def _read_all_entries_from_changed_partitions(self, latest_snapshot, delta_entries): + """Read all entries from the latest snapshot for partitions that are changed. 
+ + Builds a partition predicate from delta entries and passes it to FileScanner, + so that manifest files and entries are filtered during reading rather than + after a full scan. + + Args: + latest_snapshot: The latest snapshot to read entries from. + delta_entries: The delta entries being committed, used to determine + which partitions have changed. + + Returns: + List of ManifestEntry from the latest snapshot for changed partitions. + """ + if latest_snapshot is None: + return [] + + partition_filter = self._build_partition_filter_from_entries(delta_entries) + + all_manifests = self.manifest_list_manager.read_all(latest_snapshot) + return FileScanner( + self.table, lambda: [], partition_filter + ).read_manifest_entries(all_manifests) + + def _build_partition_filter_from_entries(self, entries): + """Build a partition predicate that matches all partitions present in the given entries. + + Args: + entries: List of ManifestEntry whose partitions should be matched. + + Returns: + A Predicate matching any of the changed partitions, or None if + partition keys are empty. 
+ """ + partition_keys = self.table.partition_keys + if not partition_keys: + return None + + changed_partitions = set() + for entry in entries: + changed_partitions.add(tuple(entry.partition.values)) + + if not changed_partitions: + return None + + predicate_builder = PredicateBuilder(self.table.fields) + partition_predicates = [] + for partition_values in changed_partitions: + sub_predicates = [] + for i, key in enumerate(partition_keys): + sub_predicates.append(predicate_builder.equal(key, partition_values[i])) + partition_predicates.append(predicate_builder.and_predicates(sub_predicates)) + + return predicate_builder.or_predicates(partition_predicates) + def _generate_overwrite_entries(self, latestSnapshot, partition_filter, commit_messages): """Generate commit entries for OVERWRITE mode based on latest snapshot.""" entries = [] diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py b/paimon-python/pypaimon/write/table_update_by_row_id.py index b027564ff398..a0e737aec12c 100644 --- a/paimon-python/pypaimon/write/table_update_by_row_id.py +++ b/paimon-python/pypaimon/write/table_update_by_row_id.py @@ -47,7 +47,8 @@ def __init__(self, table, commit_user: str): self.commit_user = commit_user # Load existing first_row_ids and build partition map - (self.first_row_ids, + (self.snapshot_id, + self.first_row_ids, self.first_row_id_to_partition_map, self.first_row_id_to_row_count_map, self.total_row_count, @@ -76,7 +77,9 @@ def _load_existing_files_info(self): total_row_count = sum(first_row_id_to_row_count_map.values()) - return (sorted(list(set(first_row_ids))), + snapshot_id = self.table.snapshot_manager().get_latest_snapshot().id + return (snapshot_id, + sorted(list(set(first_row_ids))), first_row_id_to_partition_map, first_row_id_to_row_count_map, total_row_count, @@ -306,8 +309,9 @@ def _write_group(self, partition: GenericRow, first_row_id: int, # Prepare commit and assign first_row_id commit_messages = 
file_store_write.prepare_commit(BATCH_COMMIT_IDENTIFIER) - # Assign first_row_id to the new files + # Assign first_row_id to the new files and snapshot_for_update for msg in commit_messages: + msg.snapshot_for_update = self.snapshot_id for file in msg.new_files: # Assign the same first_row_id as the original file file.first_row_id = first_row_id