Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
Snapshot,
SnapshotSummaryCollector,
Summary,
ancestors_of,
update_snapshot_summaries,
)
from pyiceberg.table.update import (
Expand Down Expand Up @@ -985,6 +986,38 @@ def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | N
self._transaction._stage(update, requirement)
return self

def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
    """Roll the table back to an earlier snapshot of the current lineage.

    The target must be known to the transaction's table metadata and must be
    an ancestor of the current table state; once both checks pass, the move
    itself is delegated to ``set_current_snapshot``.

    Args:
        snapshot_id (int): id of the snapshot to roll back to.

    Returns:
        This for method chaining

    Raises:
        ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
    """
    known_snapshot = self._transaction.table_metadata.snapshot_by_id(snapshot_id)
    if not known_snapshot:
        raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")

    if self._is_current_ancestor(snapshot_id):
        # Both validations passed: reuse the regular "set current snapshot" path.
        return self.set_current_snapshot(snapshot_id=snapshot_id)

    raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")

def _is_current_ancestor(self, snapshot_id: int) -> bool:
    """Return True when ``snapshot_id`` is among the ids reported by ``_current_ancestors``."""
    ancestor_ids = self._current_ancestors()
    return snapshot_id in ancestor_ids

def _current_ancestors(self) -> set[int]:
    """Collect the snapshot ids that ``ancestors_of`` yields for the table's current snapshot.

    Returns:
        The set of ``snapshot_id`` values of every snapshot produced by ``ancestors_of``.
    """
    txn = self._transaction
    # NOTE(review): the walk starts from the table's current snapshot, while the
    # metadata passed alongside it is the transaction's table metadata — confirm
    # this mix is intended when earlier updates are already staged.
    lineage = ancestors_of(txn._table.current_snapshot(), txn.table_metadata)
    return {snapshot.snapshot_id for snapshot in lineage}


class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
"""Expire snapshots by ID.
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_snapshot_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,29 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
tbl = catalog.load_table(identifier)
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
assert tbl.metadata.refs.get(tag_name, None) is None


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_rollback_to_snapshot(catalog: Catalog) -> None:
    """Roll the shared table back by one history entry, verify it, then restore it.

    The restore runs in a ``finally`` block so that a failing rollback or a
    failing assertion does not leave the shared table pointing at the older
    snapshot for whichever test loads it next.
    """
    identifier = "default.test_table_snapshot_operations"
    tbl = catalog.load_table(identifier)
    history = tbl.history()
    assert len(history) > 2

    # get the current snapshot and an ancestor
    current_snapshot_id = history[-1].snapshot_id
    ancestor_snapshot_id = history[-2].snapshot_id
    assert ancestor_snapshot_id != current_snapshot_id

    try:
        # rollback to the ancestor snapshot
        tbl.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()

        tbl = catalog.load_table(identifier)
        updated_snapshot = tbl.current_snapshot()
        assert updated_snapshot and updated_snapshot.snapshot_id == ancestor_snapshot_id
    finally:
        # restore table, even when the checks above failed
        tbl = catalog.load_table(identifier)
        tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()

    tbl = catalog.load_table(identifier)
    restored_snapshot = tbl.current_snapshot()
    assert restored_snapshot and restored_snapshot.snapshot_id == current_snapshot_id
126 changes: 126 additions & 0 deletions tests/table/test_manage_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import pytest

from pyiceberg.io import load_file_io
from pyiceberg.table import CommitTableResponse, Table
from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate

Expand Down Expand Up @@ -177,3 +178,128 @@ def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None:
# The main branch should point to the same snapshot as the tag
main_update = next(u for u in set_ref_updates if u.ref_name == "main")
assert main_update.snapshot_id == snapshot_one


def test_rollback_to_snapshot(table_v2: Table) -> None:
    """A rollback commits once and emits a single ref update moving the ``main`` branch."""
    ancestor_snapshot_id = 3051729675574597004

    # Replace the catalog with a mock so the commit can be inspected afterwards.
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)

    manager = table_v2.manage_snapshots()
    manager.rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()

    table_v2.catalog.commit_table.assert_called_once()

    set_ref_updates = [
        candidate for candidate in _get_updates(table_v2.catalog) if isinstance(candidate, SetSnapshotRefUpdate)
    ]
    assert len(set_ref_updates) == 1

    ref_update = set_ref_updates[0]
    assert ref_update.snapshot_id == ancestor_snapshot_id
    assert ref_update.ref_name == "main"
    assert ref_update.type == "branch"


def test_rollback_to_snapshot_unknown_id(table_v2: Table) -> None:
    """Rolling back to an id the test expects to be unknown raises and never reaches the catalog."""
    table_v2.catalog = MagicMock()
    invalid_snapshot_id = 1234567890000
    expected_error = "Cannot roll back to unknown snapshot id"

    with pytest.raises(ValueError, match=expected_error):
        table_v2.manage_snapshots().rollback_to_snapshot(snapshot_id=invalid_snapshot_id).commit()

    # The failure must happen before anything is sent to the catalog.
    table_v2.catalog.commit_table.assert_not_called()


def test_rollback_to_snapshot_not_ancestor(table_v2: Table) -> None:
    """Rolling back to a snapshot outside the current lineage raises and commits nothing.

    Builds a small table whose history forks: ``snapshot_b`` (current) and
    ``snapshot_c`` both have ``snapshot_a`` as parent, so ``snapshot_c`` exists
    in the metadata but is not an ancestor of the current snapshot.
    """
    from pyiceberg.table.metadata import TableMetadataV2

    # create a table with a branching snapshot history:
    snapshot_a = 1
    snapshot_b = 2  # current
    snapshot_c = 3  # branch from a, not ancestor of b

    metadata_dict = {
        "format-version": 2,
        "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1",
        "location": "s3://bucket/test/location",
        "last-sequence-number": 3,
        "last-updated-ms": 1602638573590,
        "last-column-id": 1,
        "current-schema-id": 0,
        "schemas": [{"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}],
        "default-spec-id": 0,
        "partition-specs": [{"spec-id": 0, "fields": []}],
        "last-partition-id": 999,
        "default-sort-order-id": 0,
        "current-snapshot-id": snapshot_b,
        "snapshots": [
            {
                "snapshot-id": snapshot_a,
                "timestamp-ms": 1000,
                "sequence-number": 1,
                "manifest-list": "s3://a/1.avro",
            },
            {
                "snapshot-id": snapshot_b,
                "parent-snapshot-id": snapshot_a,
                "timestamp-ms": 2000,
                "sequence-number": 2,
                "manifest-list": "s3://a/2.avro",
            },
            {
                "snapshot-id": snapshot_c,
                "parent-snapshot-id": snapshot_a,
                "timestamp-ms": 3000,
                "sequence-number": 3,
                "manifest-list": "s3://a/3.avro",
            },
        ],
    }

    # Keep a handle on the mock so the "no commit" check below does not depend
    # on how the table exposes its catalog.
    mock_catalog = MagicMock()
    branching_table = Table(
        identifier=("db", "table"),
        metadata=TableMetadataV2(**metadata_dict),
        metadata_location="s3://bucket/test/metadata.json",
        io=load_file_io(),
        catalog=mock_catalog,
    )

    # snapshot_c exists but is not an ancestor of snapshot_b (current)
    with pytest.raises(ValueError, match="Cannot roll back to snapshot, not an ancestor of the current state"):
        branching_table.manage_snapshots().rollback_to_snapshot(snapshot_id=snapshot_c).commit()

    # The rejected rollback must not have reached the catalog.
    mock_catalog.commit_table.assert_not_called()


def test_rollback_to_snapshot_chained_with_tag(table_v2: Table) -> None:
    """Chaining ``create_tag`` and ``rollback_to_snapshot`` yields one commit carrying both ref updates."""
    ancestor_snapshot_id = 3051729675574597004

    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)

    # Same chain as a fluent one-liner, spelled out step by step.
    manager = table_v2.manage_snapshots()
    manager = manager.create_tag(snapshot_id=ancestor_snapshot_id, tag_name="before-rollback")
    manager = manager.rollback_to_snapshot(snapshot_id=ancestor_snapshot_id)
    manager.commit()

    table_v2.catalog.commit_table.assert_called_once()

    committed_updates = _get_updates(table_v2.catalog)
    ref_names = [item.ref_name for item in committed_updates if isinstance(item, SetSnapshotRefUpdate)]

    assert len(ref_names) == 2
    assert set(ref_names) == {"before-rollback", "main"}


def test_rollback_to_current_snapshot(table_v2: Table) -> None:
    """Rolling back to the snapshot that is already current is accepted and commits once."""
    head = table_v2.current_snapshot()
    assert head is not None
    head_id = head.snapshot_id

    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _mock_commit_response(table_v2)

    manager = table_v2.manage_snapshots()
    manager.rollback_to_snapshot(snapshot_id=head_id).commit()

    table_v2.catalog.commit_table.assert_called_once()