diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 597f62632f..0fe4b1a6a7 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -50,7 +50,7 @@ from pyiceberg.table.maintenance import MaintenanceTable from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata from pyiceberg.table.name_mapping import NameMapping -from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.table.update import ( @@ -1219,6 +1219,9 @@ def scan( snapshot_id: int | None = None, options: Properties = EMPTY_DICT, limit: int | None = None, + *, + branch: str | None = None, + tag: str | None = None, ) -> DataScan: """Fetch a DataScan based on the table's current metadata. @@ -1245,10 +1248,24 @@ def scan( An integer representing the number of rows to return in the scan result. If None, fetches all matching rows. + branch: + Optional branch name to scan. If provided, the branch + is resolved to its referenced snapshot ID. + tag: + Optional tag name to scan. If provided, the tag is + resolved to its referenced snapshot ID. Returns: A DataScan based on the table's current metadata. """ + if sum(ref is not None for ref in (snapshot_id, branch, tag)) > 1: + raise ValueError("Cannot specify more than one of snapshot_id, branch, and tag") + + if branch is not None: + snapshot_id = self._scan_ref_snapshot_id(branch, SnapshotRefType.BRANCH) + elif tag is not None: + snapshot_id = self._scan_ref_snapshot_id(tag, SnapshotRefType.TAG) + return DataScan( table_metadata=self.metadata, io=self.io, @@ -1262,6 +1279,18 @@ def scan( table_identifier=self._identifier, ) + def _scan_ref_snapshot_id(self, name: str, snapshot_ref_type: SnapshotRefType) -> int: + ref_type_name = snapshot_ref_type.value + ref = self.metadata.refs.get(name) + if ref is None: + raise ValueError(f"Cannot scan unknown {ref_type_name}={name}") + if ref.snapshot_ref_type != snapshot_ref_type: + raise ValueError(f"Ref {name} is not a {ref_type_name}") + if snapshot := self.metadata.snapshot_by_name(name): + return snapshot.snapshot_id + + raise ValueError(f"Cannot scan unknown {ref_type_name}={name}") + @property def format_version(self) -> TableVersion: return self.metadata.format_version @@ -1775,6 +1804,9 @@ def scan( snapshot_id: int | None = None, options: Properties = EMPTY_DICT, limit: int | None = None, + *, + branch: str | None = None, + tag: str | None = None, ) -> DataScan: raise ValueError("Cannot scan a staged table") diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 7931edacdd..2a93367964 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -1040,18 +1040,22 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): _updates: tuple[TableUpdate, ...] _requirements: tuple[TableRequirement, ...] _snapshot_ids_to_expire: set[int] + _retain_last: int | None + _has_other_selectors: bool def __init__(self, transaction: Transaction) -> None: super().__init__(transaction) self._updates = () self._requirements = () self._snapshot_ids_to_expire = set() + self._retain_last = None + self._has_other_selectors = False def _commit(self) -> UpdatesAndRequirements: """ Commit the staged updates and requirements. - This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads). + This will remove the snapshots with the given IDs, but will always skip protected snapshots. Returns: Tuple of updates and requirements to be committed, @@ -1060,6 +1064,28 @@ def _commit(self) -> UpdatesAndRequirements: # Remove any protected snapshot IDs from the set to expire, just in case protected_ids = self._get_protected_snapshot_ids() self._snapshot_ids_to_expire -= protected_ids + + if self._retain_last is not None: + standalone = not self._has_other_selectors + unprotected_snapshots = sorted( + [ + snapshot + for snapshot in self._transaction.table_metadata.snapshots + if snapshot.snapshot_id not in protected_ids + ], + key=lambda snapshot: ( + snapshot.timestamp_ms, + snapshot.sequence_number if snapshot.sequence_number is not None else -1, + ), + reverse=True, + ) + keep_ids = {snapshot.snapshot_id for snapshot in unprotected_snapshots[: self._retain_last]} + surplus_ids = {snapshot.snapshot_id for snapshot in unprotected_snapshots[self._retain_last :]} + + self._snapshot_ids_to_expire -= keep_ids + if standalone: + self._snapshot_ids_to_expire |= surplus_ids + update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire) self._updates += (update,) return self._updates, self._requirements @@ -1068,16 +1094,22 @@ def _get_protected_snapshot_ids(self) -> set[int]: """ Get the IDs of protected snapshots. - These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration. + These are the HEAD snapshots of all branches, all tagged snapshots, and the current snapshot. These ids are to be + excluded from expiration. Returns: Set of protected snapshot IDs to exclude from expiration. """ - return { + table_metadata = self._transaction.table_metadata + protected_ids = { ref.snapshot_id - for ref in self._transaction.table_metadata.refs.values() + for ref in table_metadata.refs.values() if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH] } + if table_metadata.current_snapshot_id is not None: + protected_ids.add(table_metadata.current_snapshot_id) + + return protected_ids def by_id(self, snapshot_id: int) -> ExpireSnapshots: """ @@ -1096,6 +1128,7 @@ def by_id(self, snapshot_id: int) -> ExpireSnapshots: if snapshot_id in self._get_protected_snapshot_ids(): raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.") + self._has_other_selectors = True self._snapshot_ids_to_expire.add(snapshot_id) return self @@ -1111,6 +1144,7 @@ def by_ids(self, snapshot_ids: list[int]) -> ExpireSnapshots: Returns: This for method chaining. """ + self._has_other_selectors = True for snapshot_id in snapshot_ids: self.by_id(snapshot_id) return self @@ -1125,9 +1159,31 @@ def older_than(self, dt: datetime) -> ExpireSnapshots: Returns: This for method chaining. """ + self._has_other_selectors = True protected_ids = self._get_protected_snapshot_ids() expire_from = datetime_to_millis(dt) for snapshot in self._transaction.table_metadata.snapshots: if snapshot.timestamp_ms < expire_from and snapshot.snapshot_id not in protected_ids: self._snapshot_ids_to_expire.add(snapshot.snapshot_id) return self + + def retain_last(self, num_snapshots: int) -> ExpireSnapshots: + """ + Retain the N most-recent unprotected snapshots. + + Used alone, this expires all unprotected snapshots except the newest N. When combined + with older_than/by_id/by_ids, the newest N unprotected snapshots are kept as a floor; + explicitly selected IDs in that newest N are silently kept. Protected snapshots are + always kept and are not counted toward N. + + Args: + num_snapshots (int): Number of newest unprotected snapshots to retain. + + Returns: + This for method chaining. + """ + if num_snapshots < 1: + raise ValueError("Number of snapshots to retain must be at least 1") + + self._retain_last = num_snapshots + return self diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py index 106e5b786c..b687af60cb 100644 --- a/tests/table/test_expire_snapshots.py +++ b/tests/table/test_expire_snapshots.py @@ -22,7 +22,9 @@ import pytest from pyiceberg.table import CommitTableResponse, Table -from pyiceberg.table.update import RemoveSnapshotsUpdate, update_table_metadata +from pyiceberg.table.refs import SnapshotRef, SnapshotRefType +from pyiceberg.table.snapshots import SnapshotLogEntry +from pyiceberg.table.update import RemoveSnapshotsUpdate, TableRequirement, TableUpdate, update_table_metadata from pyiceberg.table.update.snapshot import ExpireSnapshots @@ -75,6 +77,29 @@ def test_cannot_expire_tagged_snapshot(table_v2: Table) -> None: table_v2.catalog.commit_table.assert_not_called() +def test_cannot_expire_current_snapshot_without_ref(table_v2: Table) -> None: + current_snapshot_id = table_v2.metadata.current_snapshot_id + assert current_snapshot_id is not None + non_current_snapshot_id = next( + snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots if snapshot.snapshot_id != current_snapshot_id + ) + + table_v2.catalog = MagicMock() + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + "main": SnapshotRef(snapshot_id=non_current_snapshot_id, snapshot_ref_type=SnapshotRefType.BRANCH), + } + } + ) + assert all(ref.snapshot_id != current_snapshot_id for ref in table_v2.metadata.refs.values()) + + with pytest.raises(ValueError, match=f"Snapshot with ID {current_snapshot_id} is protected and cannot be expired."): + table_v2.maintenance.expire_snapshots().by_id(current_snapshot_id).commit() + + table_v2.catalog.commit_table.assert_not_called() + + def test_expire_unprotected_snapshot(table_v2: Table) -> None: """Test that an unprotected snapshot can be expired.""" EXPIRE_SNAPSHOT = 3051729675574597004 @@ -316,3 +341,241 @@ def test_update_remove_snapshots_with_statistics(table_v2_with_statistics: Table assert not any(stat.snapshot_id == REMOVE_SNAPSHOT for stat in new_metadata.statistics), ( "Statistics for removed snapshot should be gone" ) + + +def _prepare_table_with_snapshots( + table: Table, + snapshot_ids_and_timestamps: list[tuple[int, int]], + refs: dict[str, SnapshotRef] | None = None, + current_snapshot_id: int | None = None, +) -> None: + base_snapshot = table.metadata.snapshots[0] + snapshots = [] + snapshot_log = [] + parent_snapshot_id = None + + for sequence_number, (snapshot_id, timestamp_ms) in enumerate(snapshot_ids_and_timestamps): + snapshots.append( + base_snapshot.model_copy( + update={ + "snapshot_id": snapshot_id, + "parent_snapshot_id": parent_snapshot_id, + "sequence_number": sequence_number, + "timestamp_ms": timestamp_ms, + "manifest_list": f"s3://bucket/test/{snapshot_id}.avro", + } + ) + ) + snapshot_log.append(SnapshotLogEntry(snapshot_id=snapshot_id, timestamp_ms=timestamp_ms)) + parent_snapshot_id = snapshot_id + + table.metadata = table.metadata.model_copy( + update={ + "current_snapshot_id": current_snapshot_id, + "refs": refs or {}, + "snapshots": snapshots, + "snapshot_log": snapshot_log, + } + ) + + +def _configure_commit_to_apply_updates(table: Table) -> None: + def commit_table( + _table: Table, _requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...] + ) -> CommitTableResponse: + return CommitTableResponse( + metadata=update_table_metadata(table.metadata, updates), + metadata_location="mock://metadata/location", + uuid=uuid4(), + ) + + table.catalog = MagicMock() + table.catalog.commit_table.side_effect = commit_table + + +def test_retain_last_two_expires_surplus_unprotected_snapshots(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (101, 1000), + (102, 2000), + (103, 3000), + (104, 4000), + (105, 5000), + ], + ) + _configure_commit_to_apply_updates(table_v2) + + table_v2.maintenance.expire_snapshots().retain_last(2).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {104, 105} + + +def test_retain_last_one_keeps_only_newest_unprotected_snapshot(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (101, 1000), + (102, 2000), + (103, 3000), + ], + ) + _configure_commit_to_apply_updates(table_v2) + + table_v2.maintenance.expire_snapshots().retain_last(1).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {103} + + +def test_retain_last_keeps_current_snapshot_without_counting_it(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (101, 1000), + (102, 2000), + (103, 3000), + (104, 4000), + ], + refs={ + "main": SnapshotRef(snapshot_id=104, snapshot_ref_type=SnapshotRefType.BRANCH), + }, + current_snapshot_id=101, + ) + _configure_commit_to_apply_updates(table_v2) + assert all(ref.snapshot_id != 101 for ref in table_v2.metadata.refs.values()) + + table_v2.maintenance.expire_snapshots().retain_last(1).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {101, 103, 104} + assert table_v2.metadata.current_snapshot_id == 101 + current_snapshot = table_v2.current_snapshot() + assert current_snapshot is not None + assert current_snapshot.snapshot_id == 101 + + +def test_older_than_with_retain_last_keeps_newest_unprotected_floor(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (101, 1000), + (102, 2000), + (103, 3000), + (104, 4000), + (105, 5000), + ], + ) + _configure_commit_to_apply_updates(table_v2) + + table_v2.maintenance.expire_snapshots().older_than(datetime(1970, 1, 1, 0, 0, 10)).retain_last(2).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {104, 105} + + +def test_older_than_with_retain_last_intersection(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (101, 1000), + (102, 2000), + (103, 3000), + (104, 4000), + (105, 5000), + ], + ) + _configure_commit_to_apply_updates(table_v2) + + table_v2.maintenance.expire_snapshots().older_than(datetime(1970, 1, 1, 0, 0, 2, 500000)).retain_last(2).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {103, 104, 105} + + +def test_older_than_matching_nothing_with_retain_last_expires_nothing(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (101, 1000), + (102, 2000), + (103, 3000), + (104, 4000), + (105, 5000), + ], + ) + _configure_commit_to_apply_updates(table_v2) + + # Cutoff before the oldest snapshot: older_than selects nothing, so retain_last must act + # only as a floor and expire nothing (it must NOT fall back to standalone behavior). + table_v2.maintenance.expire_snapshots().older_than(datetime(1970, 1, 1, 0, 0, 0, 500000)).retain_last(2).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {101, 102, 103, 104, 105} + + +def test_empty_by_ids_with_retain_last_expires_nothing(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (101, 1000), + (102, 2000), + (103, 3000), + ], + ) + _configure_commit_to_apply_updates(table_v2) + + # An explicit (if empty) by_ids selector means retain_last acts only as a floor, not standalone. + table_v2.maintenance.expire_snapshots().by_ids([]).retain_last(1).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {101, 102, 103} + + +def test_retain_last_tiebreak_uses_sequence_number(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (102, 1000), + (101, 1000), + ], + ) + _configure_commit_to_apply_updates(table_v2) + + table_v2.maintenance.expire_snapshots().retain_last(1).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {101} + + +def test_retain_last_keeps_protected_snapshots_without_counting_them(table_v2: Table) -> None: + _prepare_table_with_snapshots( + table_v2, + [ + (101, 1000), + (102, 2000), + (103, 3000), + (104, 4000), + ], + refs={ + "old-tag": SnapshotRef(snapshot_id=101, snapshot_ref_type=SnapshotRefType.TAG), + "main": SnapshotRef(snapshot_id=104, snapshot_ref_type=SnapshotRefType.BRANCH), + }, + current_snapshot_id=104, + ) + _configure_commit_to_apply_updates(table_v2) + + table_v2.maintenance.expire_snapshots().retain_last(1).commit() + + remaining_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots} + assert remaining_ids == {101, 103, 104} + + +def test_retain_last_requires_at_least_one_snapshot(table_v2: Table) -> None: + table_v2.catalog = MagicMock() + + with pytest.raises(ValueError, match="Number of snapshots to retain must be at least 1"): + table_v2.maintenance.expire_snapshots().retain_last(0) + + table_v2.catalog.commit_table.assert_not_called() diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 7e64e6e7c0..c3dc47d9d5 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -288,6 +288,60 @@ def test_table_scan_ref(table_v2: Table) -> None: assert scan.use_ref("test").snapshot_id == 3051729675574597004 +def test_table_scan_branch_keyword(table_v2: Table) -> None: + table_v2.metadata = table_v2.metadata.model_copy( + update={ + "refs": { + **table_v2.metadata.refs, + "test-branch": SnapshotRef(snapshot_id=3055729675574597004, snapshot_ref_type=SnapshotRefType.BRANCH), + } + } + ) + + scan = table_v2.scan(branch="test-branch") + + assert scan.snapshot_id == 3055729675574597004 + + +def test_table_scan_tag_keyword(table_v2: Table) -> None: + scan = table_v2.scan(tag="test") + + assert scan.snapshot_id == 3051729675574597004 + + +def test_table_scan_branch_and_tag_raises(table_v2: Table) -> None: + with pytest.raises(ValueError, match="Cannot specify more than one of snapshot_id, branch, and tag"): + table_v2.scan(branch="test-branch", tag="test") + + +def test_table_scan_snapshot_id_and_branch_raises(table_v2: Table) -> None: + with pytest.raises(ValueError, match="Cannot specify more than one of snapshot_id, branch, and tag"): + table_v2.scan(snapshot_id=3051729675574597004, branch="test-branch") + + +def test_table_scan_unknown_branch_raises(table_v2: Table) -> None: + with pytest.raises(ValueError, match="Cannot scan unknown branch=does-not-exist"): + table_v2.scan(branch="does-not-exist") + + +def test_table_scan_branch_with_tag_ref_raises(table_v2: Table) -> None: + with pytest.raises(ValueError, match="Ref test is not a branch"): + table_v2.scan(branch="test") + + +def test_table_scan_tag_with_branch_ref_raises(table_v2: Table) -> None: + # `main` is synthesized as a BRANCH ref from current-snapshot-id; scanning it as a tag must fail. + with pytest.raises(ValueError, match="Ref main is not a tag"): + table_v2.scan(tag="main") + + +def test_table_scan_branch_keyword_main(table_v2: Table) -> None: + # The `main` branch ref is synthesized from current-snapshot-id and must be scannable by name. + scan = table_v2.scan(branch="main") + + assert scan.snapshot_id == table_v2.metadata.current_snapshot_id + + def test_table_scan_ref_does_not_exists(table_v2: Table) -> None: scan = table_v2.scan()