Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from pyiceberg.table.maintenance import MaintenanceTable
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata
from pyiceberg.table.name_mapping import NameMapping
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.table.update import (
Expand Down Expand Up @@ -1219,6 +1219,9 @@ def scan(
snapshot_id: int | None = None,
options: Properties = EMPTY_DICT,
limit: int | None = None,
*,
branch: str | None = None,
tag: str | None = None,
) -> DataScan:
"""Fetch a DataScan based on the table's current metadata.

Expand All @@ -1245,10 +1248,24 @@ def scan(
An integer representing the number of rows to
return in the scan result. If None, fetches all
matching rows.
branch:
Optional branch name to scan. If provided, the branch
is resolved to its referenced snapshot ID.
tag:
Optional tag name to scan. If provided, the tag is
resolved to its referenced snapshot ID.

Returns:
A DataScan based on the table's current metadata.
"""
if sum(ref is not None for ref in (snapshot_id, branch, tag)) > 1:
raise ValueError("Cannot specify more than one of snapshot_id, branch, and tag")

if branch is not None:
snapshot_id = self._scan_ref_snapshot_id(branch, SnapshotRefType.BRANCH)
elif tag is not None:
snapshot_id = self._scan_ref_snapshot_id(tag, SnapshotRefType.TAG)

return DataScan(
table_metadata=self.metadata,
io=self.io,
Expand All @@ -1262,6 +1279,18 @@ def scan(
table_identifier=self._identifier,
)

def _scan_ref_snapshot_id(self, name: str, snapshot_ref_type: SnapshotRefType) -> int:
    """Resolve a named branch or tag to the ID of the snapshot found for it.

    Args:
        name: Name of the ref to look up in the table metadata.
        snapshot_ref_type: Kind of ref that ``name`` is required to be.

    Returns:
        The ``snapshot_id`` of the snapshot returned by ``snapshot_by_name(name)``.

    Raises:
        ValueError: If the metadata has no ref called ``name``, if that ref is of a
            different type than ``snapshot_ref_type``, or if no snapshot is found
            for the name.
    """
    kind = snapshot_ref_type.value
    # A missing ref and a ref without a resolvable snapshot report the same message.
    unknown_ref_message = f"Cannot scan unknown {kind}={name}"

    matching_ref = self.metadata.refs.get(name)
    if matching_ref is None:
        raise ValueError(unknown_ref_message)
    if matching_ref.snapshot_ref_type != snapshot_ref_type:
        raise ValueError(f"Ref {name} is not a {kind}")

    resolved = self.metadata.snapshot_by_name(name)
    if not resolved:
        raise ValueError(unknown_ref_message)
    return resolved.snapshot_id

@property
def format_version(self) -> TableVersion:
return self.metadata.format_version
Expand Down Expand Up @@ -1775,6 +1804,9 @@ def scan(
snapshot_id: int | None = None,
options: Properties = EMPTY_DICT,
limit: int | None = None,
*,
branch: str | None = None,
tag: str | None = None,
) -> DataScan:
raise ValueError("Cannot scan a staged table")

Expand Down
64 changes: 60 additions & 4 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,18 +1040,22 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
_updates: tuple[TableUpdate, ...]
_requirements: tuple[TableRequirement, ...]
_snapshot_ids_to_expire: set[int]
_retain_last: int | None
_has_other_selectors: bool

def __init__(self, transaction: Transaction) -> None:
    """Create an expiration builder with nothing selected yet.

    Args:
        transaction: Passed through unchanged to the parent initializer.
    """
    super().__init__(transaction)
    # Nothing staged for commit yet.
    self._updates, self._requirements = (), ()
    # No snapshot IDs chosen, no other selector used, and no retain_last count set.
    self._snapshot_ids_to_expire = set()
    self._has_other_selectors = False
    self._retain_last = None

def _commit(self) -> UpdatesAndRequirements:
"""
Commit the staged updates and requirements.

This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads).
This will remove the snapshots with the given IDs, but will always skip protected snapshots.

Returns:
Tuple of updates and requirements to be committed,
Expand All @@ -1060,6 +1064,28 @@ def _commit(self) -> UpdatesAndRequirements:
# Remove any protected snapshot IDs from the set to expire, just in case
protected_ids = self._get_protected_snapshot_ids()
self._snapshot_ids_to_expire -= protected_ids

if self._retain_last is not None:
standalone = not self._has_other_selectors
unprotected_snapshots = sorted(
[
snapshot
for snapshot in self._transaction.table_metadata.snapshots
if snapshot.snapshot_id not in protected_ids
],
key=lambda snapshot: (
snapshot.timestamp_ms,
snapshot.sequence_number if snapshot.sequence_number is not None else -1,
),
reverse=True,
)
keep_ids = {snapshot.snapshot_id for snapshot in unprotected_snapshots[: self._retain_last]}
surplus_ids = {snapshot.snapshot_id for snapshot in unprotected_snapshots[self._retain_last :]}

self._snapshot_ids_to_expire -= keep_ids
if standalone:
self._snapshot_ids_to_expire |= surplus_ids

update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire)
self._updates += (update,)
return self._updates, self._requirements
Expand All @@ -1068,16 +1094,22 @@ def _get_protected_snapshot_ids(self) -> set[int]:
"""
Get the IDs of protected snapshots.

These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration.
These are the HEAD snapshots of all branches, all tagged snapshots, and the current snapshot. These ids are to be
excluded from expiration.

Returns:
Set of protected snapshot IDs to exclude from expiration.
"""
return {
table_metadata = self._transaction.table_metadata
protected_ids = {
ref.snapshot_id
for ref in self._transaction.table_metadata.refs.values()
for ref in table_metadata.refs.values()
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]
}
if table_metadata.current_snapshot_id is not None:
protected_ids.add(table_metadata.current_snapshot_id)

return protected_ids

def by_id(self, snapshot_id: int) -> ExpireSnapshots:
"""
Expand All @@ -1096,6 +1128,7 @@ def by_id(self, snapshot_id: int) -> ExpireSnapshots:
if snapshot_id in self._get_protected_snapshot_ids():
raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")

self._has_other_selectors = True
self._snapshot_ids_to_expire.add(snapshot_id)

return self
Expand All @@ -1111,6 +1144,7 @@ def by_ids(self, snapshot_ids: list[int]) -> ExpireSnapshots:
Returns:
This for method chaining.
"""
self._has_other_selectors = True
for snapshot_id in snapshot_ids:
self.by_id(snapshot_id)
return self
Expand All @@ -1125,9 +1159,31 @@ def older_than(self, dt: datetime) -> ExpireSnapshots:
Returns:
This for method chaining.
"""
self._has_other_selectors = True
protected_ids = self._get_protected_snapshot_ids()
expire_from = datetime_to_millis(dt)
for snapshot in self._transaction.table_metadata.snapshots:
if snapshot.timestamp_ms < expire_from and snapshot.snapshot_id not in protected_ids:
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
return self

def retain_last(self, num_snapshots: int) -> ExpireSnapshots:
    """
    Record how many of the newest unprotected snapshots must be retained.

    The count is only stored here; it is applied when the expiration is committed.
    On its own it causes every unprotected snapshot beyond the newest N to be expired.
    Together with older_than/by_id/by_ids it acts as a floor: the newest N unprotected
    snapshots are removed from the selection, so an explicitly selected ID among them
    is silently kept. Protected snapshots are always kept and do not count toward N.

    Args:
        num_snapshots (int): How many of the newest unprotected snapshots to keep.

    Returns:
        This for method chaining.

    Raises:
        ValueError: If num_snapshots is less than 1.
    """
    below_minimum = num_snapshots < 1
    if below_minimum:
        raise ValueError("Number of snapshots to retain must be at least 1")

    self._retain_last = num_snapshots
    return self
Loading