From 65fc183e3902321d6ef4e5ae0cc374f2659ed936 Mon Sep 17 00:00:00 2001
From: geruh
Date: Mon, 12 Jan 2026 17:24:23 -0800
Subject: [PATCH] feat: Add rollback_to_snapshot to ManageSnapshots API

Add ManageSnapshots.rollback_to_snapshot, which points the main branch
back at a snapshot that is an ancestor of the current table state.

The ancestry check walks from the current snapshot of the transaction's
staged metadata, the same metadata used for the snapshot lookup, so both
validations agree when updates are chained in a single transaction.

---
 pyiceberg/table/update/snapshot.py            |  35 +++++
 tests/integration/test_snapshot_operations.py |  26 ++++
 tests/table/test_manage_snapshots.py          | 124 ++++++++++++++++++
 3 files changed, 185 insertions(+)

diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index bc05aab966..d54f861892 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -64,6 +64,7 @@
     Snapshot,
     SnapshotSummaryCollector,
     Summary,
+    ancestors_of,
     update_snapshot_summaries,
 )
 from pyiceberg.table.update import (
@@ -985,6 +986,40 @@ def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | N
         self._transaction._stage(update, requirement)
         return self
 
+    def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
+        """Rollback the table to the given snapshot id.
+
+        The snapshot needs to be an ancestor of the current table state.
+
+        Args:
+            snapshot_id (int): rollback to this snapshot_id that used to be current.
+        Returns:
+            This for method chaining
+        Raises:
+            ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
+        """
+        if not self._transaction.table_metadata.snapshot_by_id(snapshot_id):
+            raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
+
+        if not self._is_current_ancestor(snapshot_id):
+            raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
+
+        return self.set_current_snapshot(snapshot_id=snapshot_id)
+
+    def _is_current_ancestor(self, snapshot_id: int) -> bool:
+        """Return whether snapshot_id is in the ancestry of the current snapshot."""
+        return snapshot_id in self._current_ancestors()
+
+    def _current_ancestors(self) -> set[int]:
+        """Return the ids of the current snapshot and all of its ancestors.
+
+        Both the starting snapshot and the parent lookups use the transaction's
+        staged metadata, so the check agrees with the snapshot_by_id lookup in
+        rollback_to_snapshot and sees updates staged earlier in the same transaction.
+        """
+        metadata = self._transaction.table_metadata
+        return {ancestor.snapshot_id for ancestor in ancestors_of(metadata.current_snapshot(), metadata)}
+
 
 class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
     """Expire snapshots by ID.
diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py
index 2f0447ec52..68cec645ac 100644
--- a/tests/integration/test_snapshot_operations.py
+++ b/tests/integration/test_snapshot_operations.py
@@ -160,3 +160,29 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
     tbl = catalog.load_table(identifier)
     tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
     assert tbl.metadata.refs.get(tag_name, None) is None
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_rollback_to_snapshot(catalog: Catalog) -> None:
+    identifier = "default.test_table_snapshot_operations"
+    tbl = catalog.load_table(identifier)
+    assert len(tbl.history()) > 2
+
+    # get the current snapshot and an ancestor
+    current_snapshot_id = tbl.history()[-1].snapshot_id
+    ancestor_snapshot_id = tbl.history()[-2].snapshot_id
+    assert ancestor_snapshot_id != current_snapshot_id
+
+    # rollback to the ancestor snapshot
+    tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()
+
+    tbl = catalog.load_table(identifier)
+    updated_snapshot = tbl.current_snapshot()
+    assert updated_snapshot and updated_snapshot.snapshot_id == ancestor_snapshot_id
+
+    # restore table
+    tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
+    tbl = catalog.load_table(identifier)
+    restored_snapshot = tbl.current_snapshot()
+    assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id
diff --git a/tests/table/test_manage_snapshots.py b/tests/table/test_manage_snapshots.py
index 93301a01c7..20c23fc91c 100644
--- a/tests/table/test_manage_snapshots.py
+++ b/tests/table/test_manage_snapshots.py
@@ -19,6 +19,7 @@
 
 import pytest
 
+from pyiceberg.io import load_file_io
 from pyiceberg.table import CommitTableResponse, Table
 from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate
 
@@ -177,3 +178,126 @@ def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None:
     # The main branch should point to the same snapshot as the tag
     main_update = next(u for u in set_ref_updates if u.ref_name == "main")
     assert main_update.snapshot_id == snapshot_one
+
+
+def test_rollback_to_snapshot(table_v2: Table) -> None:
+    ancestor_snapshot_id = 3051729675574597004
+
+    table_v2.catalog = MagicMock()
+    table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)
+
+    table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()
+
+    table_v2.catalog.commit_table.assert_called_once()
+
+    updates = _get_updates(table_v2.catalog)
+    set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)]
+
+    assert len(set_ref_updates) == 1
+    update = set_ref_updates[0]
+    assert update.snapshot_id == ancestor_snapshot_id
+    assert update.ref_name == "main"
+    assert update.type == "branch"
+
+
+def test_rollback_to_snapshot_unknown_id(table_v2: Table) -> None:
+    invalid_snapshot_id = 1234567890000
+    table_v2.catalog = MagicMock()
+
+    with pytest.raises(ValueError, match="Cannot roll back to unknown snapshot id"):
+        table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=invalid_snapshot_id).commit()
+
+    table_v2.catalog.commit_table.assert_not_called()
+
+
+def test_rollback_to_snapshot_not_ancestor(table_v2: Table) -> None:
+    from pyiceberg.table.metadata import TableMetadataV2
+
+    # create a table with a branching snapshot history:
+    snapshot_a = 1
+    snapshot_b = 2  # current
+    snapshot_c = 3  # branch from a, not ancestor of b
+
+    metadata_dict = {
+        "format-version": 2,
+        "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
+        "location": "s3://bucket/test/location",
+        "last-sequence-number": 3,
+        "last-updated-ms": 1602638573590,
+        "last-column-id": 1,
+        "current-schema-id": 0,
+        "schemas": [{"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}],
+        "default-spec-id": 0,
+        "partition-specs": [{"spec-id": 0, "fields": []}],
+        "last-partition-id": 999,
+        "default-sort-order-id": 0,
+        "current-snapshot-id": snapshot_b,
+        "snapshots": [
+            {
+                "snapshot-id": snapshot_a,
+                "timestamp-ms": 1000,
+                "sequence-number": 1,
+                "manifest-list": "s3://a/1.avro",
+            },
+            {
+                "snapshot-id": snapshot_b,
+                "parent-snapshot-id": snapshot_a,
+                "timestamp-ms": 2000,
+                "sequence-number": 2,
+                "manifest-list": "s3://a/2.avro",
+            },
+            {
+                "snapshot-id": snapshot_c,
+                "parent-snapshot-id": snapshot_a,
+                "timestamp-ms": 3000,
+                "sequence-number": 3,
+                "manifest-list": "s3://a/3.avro",
+            },
+        ],
+    }
+
+    branching_table = Table(
+        identifier=("db", "table"),
+        metadata=TableMetadataV2(**metadata_dict),
+        metadata_location="s3://bucket/test/metadata.json",
+        io=load_file_io(),
+        catalog=MagicMock(),
+    )
+
+    # snapshot_c exists but is not an ancestor of snapshot_b (current)
+    with pytest.raises(ValueError, match="Cannot roll back to snapshot, not an ancestor of the current state"):
+        branching_table.manage_snapshots().rollback_to_snapshot(snapshot_id=snapshot_c).commit()
+
+
+def test_rollback_to_snapshot_chained_with_tag(table_v2: Table) -> None:
+    ancestor_snapshot_id = 3051729675574597004
+
+    table_v2.catalog = MagicMock()
+    table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)
+
+    (
+        table_v2.manage_snapshots()
+        .create_tag(snapshot_id=ancestor_snapshot_id, tag_name="before-rollback")
+        .rollback_to_snapshot(snapshot_id=ancestor_snapshot_id)
+        .commit()
+    )
+
+    table_v2.catalog.commit_table.assert_called_once()
+
+    updates = _get_updates(table_v2.catalog)
+    set_ref_updates = [u for u in updates if isinstance(u, SetSnapshotRefUpdate)]
+
+    assert len(set_ref_updates) == 2
+    ref_names = {u.ref_name for u in set_ref_updates}
+    assert ref_names == {"before-rollback", "main"}
+
+
+def test_rollback_to_current_snapshot(table_v2: Table) -> None:
+    current_snapshot = table_v2.current_snapshot()
+    assert current_snapshot is not None
+
+    table_v2.catalog = MagicMock()
+    table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)
+
+    table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=current_snapshot.snapshot_id).commit()
+    table_v2.catalog.commit_table.assert_called_once()