From 83e4756c13faf9d877f347bbc7743eefc1ac5656 Mon Sep 17 00:00:00 2001 From: Abanoub Doss Date: Tue, 23 Jun 2026 21:27:34 -0500 Subject: [PATCH 1/3] Preserve dictionary_columns in DataScan.to_arrow_batch_reader Fixes #3540. to_arrow_batch_reader(dictionary_columns=...) cast each batch to a target schema built from schema_to_pyarrow(projection()), which has no dictionary types, silently decoding dictionary-encoded columns back to plain arrays (to_arrow does not, because it concatenates with permissive promotion). Derive the reader's target schema from the first scan batch so requested columns that ArrowScan actually dictionary-encodes (strings) stay dictionary typed, while columns it leaves plain (ints, ORC, etc.) stay plain - matching to_arrow. The trailing cast still conforms later batches. Adds regression tests covering string, non-string, and mixed dictionary_columns. --- pyiceberg/table/__init__.py | 18 ++++++++- tests/io/test_pyarrow.py | 73 +++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 597f62632f..4046da1924 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2219,7 +2219,6 @@ def to_arrow_batch_reader(self, dictionary_columns: tuple[str, ...] = ()) -> pa. from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow - target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, @@ -2230,6 +2229,23 @@ def to_arrow_batch_reader(self, dictionary_columns: tuple[str, ...] = ()) -> pa. dictionary_columns=dictionary_columns, ).to_record_batches(self.plan_files()) + target_schema = schema_to_pyarrow(self.projection()) + if dictionary_columns: + batches = iter(batches) + try: + first_batch = next(batches) + except StopIteration: + pass + else: + target_schema = pa.schema( + [ + field.with_type(batch_field.type) + for field, batch_field in zip(target_schema, first_batch.schema, strict=True) + ], + metadata=target_schema.metadata, + ) + batches = chain([first_batch], batches) + return pa.RecordBatchReader.from_batches( target_schema, batches, diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 407ec611fd..faaecd3618 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -5297,3 +5297,76 @@ def test_dictionary_columns_produces_dict_encoded_output(tmpdir: str) -> None: # Values must be identical assert result_plain.column("label").to_pylist() == result_dict.column("label").to_pylist() + + +@pytest.fixture +def sql_catalog(tmp_path: Path) -> Iterator[Any]: + from pyiceberg.catalog.sql import SqlCatalog + + catalog = SqlCatalog( + "test_sql_catalog", + uri="sqlite:///:memory:", + warehouse=f"file://{tmp_path}", + ) + catalog.create_tables() + try: + yield catalog + finally: + catalog.destroy_tables() + catalog.close() + + +def _create_dictionary_batch_reader_table(catalog: Any, identifier: str) -> Any: + arrow_table = pa.table( + { + "id": pa.array([1, 2, 3, 4], type=pa.int64()), + "label": pa.array(["a", "b", "a", "b"], type=pa.string()), + } + ) + catalog.create_namespace_if_not_exists("default") + table = catalog.create_table(identifier, schema=arrow_table.schema) + table.append(arrow_table) + return table + + +def test_to_arrow_batch_reader_preserves_dictionary_columns(sql_catalog: Any) -> None: + """Regression test for issue #3540. + + ``to_arrow(dictionary_columns=...)`` preserves dictionary encoding, but + ``to_arrow_batch_reader(dictionary_columns=...)`` previously decoded it back + to plain strings via a final ``.cast(target_schema)`` where ``target_schema`` + had no dictionary types. Both public paths must now preserve the encoding. + """ + table = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_batch_reader_test") + expected = table.scan().to_arrow(dictionary_columns=("label",)) + result = table.scan().to_arrow_batch_reader(dictionary_columns=("label",)).read_all() + + assert result.schema.field("label").type == expected.schema.field("label").type + assert pa.types.is_dictionary(result.schema.field("label").type) + assert result.column("label").to_pylist() == ["a", "b", "a", "b"] + assert result.to_pydict() == expected.to_pydict() + # A column not in dictionary_columns stays a plain (non-dict) array. + assert result.schema.field("id").type == expected.schema.field("id").type == pa.int64() + assert not pa.types.is_dictionary(result.schema.field("id").type) + + +def test_to_arrow_batch_reader_dictionary_columns_keep_non_string_columns_plain(sql_catalog: Any) -> None: + table = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_int_batch_reader_test") + expected = table.scan().to_arrow(dictionary_columns=("id",)) + result = table.scan().to_arrow_batch_reader(dictionary_columns=("id",)).read_all() + + assert result.schema.field("id").type == expected.schema.field("id").type == pa.int64() + assert not pa.types.is_dictionary(result.schema.field("id").type) + assert result.to_pydict() == expected.to_pydict() + + +def test_to_arrow_batch_reader_dictionary_columns_allow_mixed_types(sql_catalog: Any) -> None: + table = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_mixed_batch_reader_test") + expected = table.scan().to_arrow(dictionary_columns=("label", "id")) + result = table.scan().to_arrow_batch_reader(dictionary_columns=("label", "id")).read_all() + + assert result.schema.field("label").type == expected.schema.field("label").type + assert pa.types.is_dictionary(result.schema.field("label").type) + assert result.schema.field("id").type == expected.schema.field("id").type == pa.int64() + assert not pa.types.is_dictionary(result.schema.field("id").type) + assert result.to_pydict() == expected.to_pydict() From ee27bdaca94b8af65f46d8efadbadd91e9e57fd4 Mon Sep 17 00:00:00 2001 From: Abanoub Doss Date: Tue, 23 Jun 2026 21:27:34 -0500 Subject: [PATCH 2/3] Add opt-in physical file cleanup to expire_snapshots Fixes #2604. ExpireSnapshots was metadata-only and leaked the data, delete, manifest, manifest-list, and statistics files of expired snapshots forever. Add opt-in ExpireSnapshots.clean_expired_files(). On the autocommit path, collect files reachable from the expiring snapshots before the metadata commit, then after the commit collect files still reachable from every surviving snapshot (which covers all branches and tags), and best-effort delete the difference. Surviving-file resolution runs strict: if it cannot be fully resolved the cleanup aborts rather than risk deleting a live file. Statistics and partition-statistics puffin files are cleaned the same way. Cleanup is off by default and skipped for non-autocommit transactions. --- pyiceberg/table/update/snapshot.py | 131 +++++++++++++- tests/table/test_expire_snapshots.py | 259 ++++++++++++++++++++++++++- 2 files changed, 386 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 7931edacdd..3419aa23c6 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -17,10 +17,11 @@ from __future__ import annotations import itertools +import logging import uuid from abc import abstractmethod from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Iterable from datetime import datetime from functools import cached_property from typing import TYPE_CHECKING, Generic @@ -79,6 +80,10 @@ if TYPE_CHECKING: from pyiceberg.table import Transaction + from pyiceberg.table.metadata import TableMetadata + + +logger = logging.getLogger(__name__) def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str: @@ -1039,13 +1044,17 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): _updates: tuple[TableUpdate, ...] _requirements: tuple[TableRequirement, ...] + _io: FileIO _snapshot_ids_to_expire: set[int] + _clean_expired_files: bool def __init__(self, transaction: Transaction) -> None: super().__init__(transaction) self._updates = () self._requirements = () + self._io = transaction._table.io self._snapshot_ids_to_expire = set() + self._clean_expired_files = False def _commit(self) -> UpdatesAndRequirements: """ @@ -1064,6 +1073,126 @@ def _commit(self) -> UpdatesAndRequirements: self._updates += (update,) return self._updates, self._requirements + def commit(self) -> None: + if not self._clean_expired_files or not self._snapshot_ids_to_expire: + super().commit() + return + + if not self._transaction._autocommit: + super().commit() + logger.debug("Skipping expired-file cleanup for non-autocommit transaction; cleanup only runs on autocommit") + return + + self._snapshot_ids_to_expire -= self._get_protected_snapshot_ids() + if not self._snapshot_ids_to_expire: + super().commit() + return + + pre_commit_metadata = self._transaction.table_metadata + expired_snapshots = [ + snapshot for snapshot in pre_commit_metadata.snapshots if snapshot.snapshot_id in self._snapshot_ids_to_expire + ] + expired_files = self._reachable_files(expired_snapshots, strict=False) + expired_files.update(self._statistics_paths(pre_commit_metadata, self._snapshot_ids_to_expire, belonging=True)) + + super().commit() + + try: + surviving_files = self._reachable_files(self._transaction.table_metadata.snapshots, strict=True) + except Exception: + logger.warning( + "skipping expired-file cleanup: could not fully resolve surviving files", + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return + surviving_files.update( + self._statistics_paths(self._transaction.table_metadata, self._snapshot_ids_to_expire, belonging=False) + ) + + for path in expired_files - surviving_files: + try: + self._io.delete(path) + except Exception: + logger.warning( + "Failed to delete expired snapshot file %s", + path, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + + def clean_expired_files(self, clean: bool = True) -> ExpireSnapshots: + """Clean up files that are no longer reachable after expiring snapshots.""" + self._clean_expired_files = clean + return self + + def _reachable_files(self, snapshots: Iterable[Snapshot], *, strict: bool) -> set[str]: + reachable_files: set[str] = set() + + for snapshot in snapshots: + reachable_files.add(snapshot.manifest_list) + try: + manifests = snapshot.manifests(self._io) + except Exception: + if strict: + raise + logger.debug( + "Skipping manifest list while collecting reachable files for snapshot %s: %s", + snapshot.snapshot_id, + snapshot.manifest_list, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + continue + + for manifest in manifests: + reachable_files.add(manifest.manifest_path) + try: + entries = manifest.fetch_manifest_entry(io=self._io, discard_deleted=False) + except Exception: + if strict: + raise + logger.debug( + "Skipping manifest while collecting reachable files for snapshot %s: %s", + snapshot.snapshot_id, + manifest.manifest_path, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + continue + + reachable_files.update(entry.data_file.file_path for entry in entries) + + return reachable_files + + @staticmethod + def _statistics_paths(metadata: TableMetadata, snapshot_ids: set[int], *, belonging: bool) -> set[str]: + statistics_paths: set[str] = set() + + try: + statistics = getattr(metadata, "statistics", None) or () + partition_statistics = getattr(metadata, "partition_statistics", None) or () + except Exception: + logger.debug( + "Skipping statistics while collecting reachable files", + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return statistics_paths + + try: + for stat in itertools.chain(statistics, partition_statistics): + try: + if (stat.snapshot_id in snapshot_ids) == belonging and stat.statistics_path: + statistics_paths.add(stat.statistics_path) + except Exception: + logger.debug( + "Skipping statistics file while collecting reachable files", + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + except Exception: + logger.debug( + "Skipping statistics while collecting reachable files", + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + + return statistics_paths + def _get_protected_snapshot_ids(self) -> set[int]: """ Get the IDs of protected snapshots. diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py index 106e5b786c..178f44b486 100644 --- a/tests/table/test_expire_snapshots.py +++ b/tests/table/test_expire_snapshots.py @@ -16,16 +16,269 @@ # under the License. import threading from datetime import datetime, timedelta -from unittest.mock import MagicMock, Mock +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch +from urllib.parse import unquote, urlparse from uuid import uuid4 +import pyarrow as pa import pytest -from pyiceberg.table import CommitTableResponse, Table -from pyiceberg.table.update import RemoveSnapshotsUpdate, update_table_metadata +from pyiceberg.catalog.sql import SqlCatalog +from pyiceberg.table import CommitTableResponse, Table, Transaction +from pyiceberg.table.snapshots import Snapshot +from pyiceberg.table.statistics import BlobMetadata, PartitionStatisticsFile, StatisticsFile +from pyiceberg.table.update import RemoveSnapshotsUpdate, SetPartitionStatisticsUpdate, update_table_metadata from pyiceberg.table.update.snapshot import ExpireSnapshots +def _sql_catalog(tmp_path: Path) -> SqlCatalog: + catalog = SqlCatalog( + "test", + uri=f"sqlite:///{tmp_path}/pyiceberg_catalog.db", + warehouse=f"file://{tmp_path}", + ) + catalog.create_namespace_if_not_exists("default") + return catalog + + +def _arrow_table(rows: list[dict[str, object]]) -> pa.Table: + return pa.Table.from_pylist( + rows, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.string())]), + ) + + +def _local_path(location: str) -> Path: + parsed = urlparse(location) + if parsed.scheme == "file": + return Path(unquote(parsed.path)) + return Path(location) + + +def _touch_file(location: str) -> None: + path = _local_path(location) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(b"") + + +def _statistics_file(snapshot_id: int, statistics_path: str) -> StatisticsFile: + return StatisticsFile( + snapshot_id=snapshot_id, + statistics_path=statistics_path, + file_size_in_bytes=0, + file_footer_size_in_bytes=0, + blob_metadata=[ + BlobMetadata( + type="apache-datasketches-theta-v1", + snapshot_id=snapshot_id, + sequence_number=0, + fields=[1], + ) + ], + ) + + +def _partition_statistics_file(snapshot_id: int, statistics_path: str) -> PartitionStatisticsFile: + return PartitionStatisticsFile( + snapshot_id=snapshot_id, + statistics_path=statistics_path, + file_size_in_bytes=0, + ) + + +def _data_file_paths(table: Table, snapshot_id: int) -> set[str]: + snapshot = table.metadata.snapshot_by_id(snapshot_id) + assert snapshot is not None + return { + entry.data_file.file_path + for manifest in snapshot.manifests(table.io) + for entry in manifest.fetch_manifest_entry(io=table.io, discard_deleted=True) + } + + +def _non_current_snapshot_ids(table: Table) -> list[int]: + current_snapshot = table.metadata.current_snapshot() + assert current_snapshot is not None + return [snapshot.snapshot_id for snapshot in table.metadata.snapshots if snapshot.snapshot_id != current_snapshot.snapshot_id] + + +def _create_overwritten_table(catalog: SqlCatalog, identifier: str) -> tuple[Table, int, str]: + table = catalog.create_table(identifier, schema=_arrow_table([{"id": 1, "data": "before"}]).schema) + table.append(_arrow_table([{"id": 1, "data": "before"}])) + first_snapshot = table.metadata.current_snapshot() + assert first_snapshot is not None + first_snapshot_data_files = _data_file_paths(table, first_snapshot.snapshot_id) + assert len(first_snapshot_data_files) == 1 + + table.overwrite(_arrow_table([{"id": 2, "data": "after"}])) + + return table, first_snapshot.snapshot_id, next(iter(first_snapshot_data_files)) + + +def test_expire_snapshots_clean_expired_files_deletes_unreferenced(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + + control_table, _, control_data_file = _create_overwritten_table(catalog, "default.control_table") + control_table.maintenance.expire_snapshots().by_ids(_non_current_snapshot_ids(control_table)).commit() + assert _local_path(control_data_file).exists() + + table, first_snapshot_id, first_data_file = _create_overwritten_table(catalog, "default.clean_table") + snapshot_ids_to_expire = _non_current_snapshot_ids(table) + assert first_snapshot_id in snapshot_ids_to_expire + + table.maintenance.expire_snapshots().by_ids(snapshot_ids_to_expire).clean_expired_files().commit() + + assert not _local_path(first_data_file).exists() + assert table.scan().to_arrow().to_pylist() == [{"id": 2, "data": "after"}] + + +def test_expire_snapshots_clean_expired_files_keeps_shared_files(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + table = catalog.create_table("default.shared_table", schema=_arrow_table([{"id": 1, "data": "one"}]).schema) + + table.append(_arrow_table([{"id": 1, "data": "one"}])) + first_snapshot = table.metadata.current_snapshot() + assert first_snapshot is not None + first_snapshot_data_files = _data_file_paths(table, first_snapshot.snapshot_id) + assert len(first_snapshot_data_files) == 1 + shared_data_file = next(iter(first_snapshot_data_files)) + + table.append(_arrow_table([{"id": 2, "data": "two"}])) + + table.maintenance.expire_snapshots().by_id(first_snapshot.snapshot_id).clean_expired_files().commit() + + assert _local_path(shared_data_file).exists() + assert sorted(table.scan().to_arrow().to_pylist(), key=lambda row: row["id"]) == [ + {"id": 1, "data": "one"}, + {"id": 2, "data": "two"}, + ] + + +def test_expire_snapshots_clean_expired_files_keeps_files_referenced_only_by_tag(tmp_path: Path) -> None: + """A file reachable only from a tagged snapshot (not the current lineage) must survive expiration.""" + catalog = _sql_catalog(tmp_path) + table = catalog.create_table("default.tag_shared_table", schema=_arrow_table([{"id": 1, "data": "one"}]).schema) + + # First snapshot writes the shared file; it is unprotected and will be expired. + table.append(_arrow_table([{"id": 1, "data": "one"}])) + expiring_snapshot = table.metadata.current_snapshot() + assert expiring_snapshot is not None + shared_data_files = _data_file_paths(table, expiring_snapshot.snapshot_id) + assert len(shared_data_files) == 1 + shared_data_file = next(iter(shared_data_files)) + + # Second snapshot still references the shared file; we tag it so it stays alive. + table.append(_arrow_table([{"id": 2, "data": "two"}])) + tagged_snapshot = table.metadata.current_snapshot() + assert tagged_snapshot is not None + assert shared_data_file in _data_file_paths(table, tagged_snapshot.snapshot_id) + table.manage_snapshots().create_tag(tagged_snapshot.snapshot_id, "keep_tag").commit() + + # Overwrite drops the shared file from the current lineage; only the tag still references it. + table.overwrite(_arrow_table([{"id": 99, "data": "overwritten"}])) + current_snapshot = table.metadata.current_snapshot() + assert current_snapshot is not None + assert shared_data_file not in _data_file_paths(table, current_snapshot.snapshot_id) + + # Expiring the first snapshot must NOT delete the file the tag still references. + table.maintenance.expire_snapshots().by_id(expiring_snapshot.snapshot_id).clean_expired_files().commit() + + assert _local_path(shared_data_file).exists() + assert table.metadata.snapshot_by_id(expiring_snapshot.snapshot_id) is None + assert table.metadata.snapshot_by_id(tagged_snapshot.snapshot_id) is not None + # The tagged snapshot remains fully readable from the file that survived. + assert sorted(table.scan(snapshot_id=tagged_snapshot.snapshot_id).to_arrow().to_pylist(), key=lambda row: row["id"]) == [ + {"id": 1, "data": "one"}, + {"id": 2, "data": "two"}, + ] + + +def test_expire_snapshots_clean_expired_files_deletes_statistics_files(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + table, first_snapshot_id, _ = _create_overwritten_table(catalog, "default.clean_stats_table") + current_snapshot = table.metadata.current_snapshot() + assert current_snapshot is not None + + table_location = table.location().rstrip("/") + expired_table_statistics = f"{table_location}/metadata/expired-table-stats.puffin" + surviving_table_statistics = f"{table_location}/metadata/surviving-table-stats.puffin" + expired_partition_statistics = f"{table_location}/metadata/expired-partition-stats.puffin" + surviving_partition_statistics = f"{table_location}/metadata/surviving-partition-stats.puffin" + + for location in ( + expired_table_statistics, + surviving_table_statistics, + expired_partition_statistics, + surviving_partition_statistics, + ): + _touch_file(location) + + with table.update_statistics() as update: + update.set_statistics(_statistics_file(first_snapshot_id, expired_table_statistics)) + update.set_statistics(_statistics_file(current_snapshot.snapshot_id, surviving_table_statistics)) + + table.transaction()._apply( + ( + SetPartitionStatisticsUpdate( + partition_statistics=_partition_statistics_file(first_snapshot_id, expired_partition_statistics) + ), + SetPartitionStatisticsUpdate( + partition_statistics=_partition_statistics_file(current_snapshot.snapshot_id, surviving_partition_statistics) + ), + ) + ).commit_transaction() + + table.maintenance.expire_snapshots().by_id(first_snapshot_id).clean_expired_files().commit() + + assert not _local_path(expired_table_statistics).exists() + assert not _local_path(expired_partition_statistics).exists() + assert _local_path(surviving_table_statistics).exists() + assert _local_path(surviving_partition_statistics).exists() + + +def test_expire_snapshots_clean_skips_deletion_when_surviving_unresolvable(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + table, first_snapshot_id, first_data_file = _create_overwritten_table(catalog, "default.unresolvable_survivor_table") + current_snapshot = table.metadata.current_snapshot() + assert current_snapshot is not None + original_manifests = Snapshot.manifests + + def manifests(snapshot: Snapshot, *args: object, **kwargs: object) -> object: + if snapshot.snapshot_id == current_snapshot.snapshot_id: + raise OSError("surviving manifest list is unavailable") + return original_manifests(snapshot, *args, **kwargs) + + with ( + patch.object(table.io, "delete", wraps=table.io.delete) as delete, + patch.object(Snapshot, "manifests", autospec=True, side_effect=manifests), + ): + table.maintenance.expire_snapshots().by_id(first_snapshot_id).clean_expired_files().commit() + + delete.assert_not_called() + assert _local_path(first_data_file).exists() + assert table.scan().to_arrow().to_pylist() == [{"id": 2, "data": "after"}] + + +def test_expire_snapshots_clean_noop_on_non_autocommit_transaction(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + table, first_snapshot_id, first_data_file = _create_overwritten_table(catalog, "default.non_autocommit_table") + transaction = Transaction(table, autocommit=False) + expire_snapshots = ExpireSnapshots(transaction).by_id(first_snapshot_id).clean_expired_files() + + with patch.object(expire_snapshots._io, "delete") as delete: + expire_snapshots.commit() + + delete.assert_not_called() + assert _local_path(first_data_file).exists() + assert table.metadata.snapshot_by_id(first_snapshot_id) is not None + assert transaction.table_metadata.snapshot_by_id(first_snapshot_id) is None + assert any( + isinstance(update, RemoveSnapshotsUpdate) and update.snapshot_ids == [first_snapshot_id] + for update in transaction._updates + ) + + def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None: """Test that a HEAD (branch) snapshot cannot be expired.""" HEAD_SNAPSHOT = 3051729675574597004 From a60fdfcf4b90d06110754b7f3c66d6e013f09e88 Mon Sep 17 00:00:00 2001 From: Abanoub Doss Date: Wed, 24 Jun 2026 08:27:26 -0500 Subject: [PATCH 3/3] chore: match the patched Snapshot.manifests signature in expire test --- tests/table/test_expire_snapshots.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py index 178f44b486..8bc685c7f4 100644 --- a/tests/table/test_expire_snapshots.py +++ b/tests/table/test_expire_snapshots.py @@ -25,6 +25,8 @@ import pytest from pyiceberg.catalog.sql import SqlCatalog +from pyiceberg.io import FileIO +from pyiceberg.manifest import ManifestFile from pyiceberg.table import CommitTableResponse, Table, Transaction from pyiceberg.table.snapshots import Snapshot from pyiceberg.table.statistics import BlobMetadata, PartitionStatisticsFile, StatisticsFile @@ -244,10 +246,10 @@ def test_expire_snapshots_clean_skips_deletion_when_surviving_unresolvable(tmp_p assert current_snapshot is not None original_manifests = Snapshot.manifests - def manifests(snapshot: Snapshot, *args: object, **kwargs: object) -> object: + def manifests(snapshot: Snapshot, io: FileIO) -> list[ManifestFile]: if snapshot.snapshot_id == current_snapshot.snapshot_id: raise OSError("surviving manifest list is unavailable") - return original_manifests(snapshot, *args, **kwargs) + return original_manifests(snapshot, io) with ( patch.object(table.io, "delete", wraps=table.io.delete) as delete,