diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 597f62632f..4046da1924 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2219,7 +2219,6 @@ def to_arrow_batch_reader(self, dictionary_columns: tuple[str, ...] = ()) -> pa. from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow - target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( self.table_metadata, self.io, @@ -2230,6 +2229,23 @@ def to_arrow_batch_reader(self, dictionary_columns: tuple[str, ...] = ()) -> pa. dictionary_columns=dictionary_columns, ).to_record_batches(self.plan_files()) + target_schema = schema_to_pyarrow(self.projection()) + if dictionary_columns: + batches = iter(batches) + try: + first_batch = next(batches) + except StopIteration: + pass + else: + target_schema = pa.schema( + [ + field.with_type(batch_field.type) + for field, batch_field in zip(target_schema, first_batch.schema, strict=True) + ], + metadata=target_schema.metadata, + ) + batches = chain([first_batch], batches) + return pa.RecordBatchReader.from_batches( target_schema, batches, diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 7931edacdd..3419aa23c6 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -17,10 +17,11 @@ from __future__ import annotations import itertools +import logging import uuid from abc import abstractmethod from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Iterable from datetime import datetime from functools import cached_property from typing import TYPE_CHECKING, Generic @@ -79,6 +80,10 @@ if TYPE_CHECKING: from pyiceberg.table import Transaction + from pyiceberg.table.metadata import TableMetadata + + +logger = logging.getLogger(__name__) def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str: @@ -1039,13 +1044,17 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): _updates: tuple[TableUpdate, ...] _requirements: tuple[TableRequirement, ...] + _io: FileIO _snapshot_ids_to_expire: set[int] + _clean_expired_files: bool def __init__(self, transaction: Transaction) -> None: super().__init__(transaction) self._updates = () self._requirements = () + self._io = transaction._table.io self._snapshot_ids_to_expire = set() + self._clean_expired_files = False def _commit(self) -> UpdatesAndRequirements: """ @@ -1064,6 +1073,126 @@ def _commit(self) -> UpdatesAndRequirements: self._updates += (update,) return self._updates, self._requirements + def commit(self) -> None: + if not self._clean_expired_files or not self._snapshot_ids_to_expire: + super().commit() + return + + if not self._transaction._autocommit: + super().commit() + logger.debug("Skipping expired-file cleanup for non-autocommit transaction; cleanup only runs on autocommit") + return + + self._snapshot_ids_to_expire -= self._get_protected_snapshot_ids() + if not self._snapshot_ids_to_expire: + super().commit() + return + + pre_commit_metadata = self._transaction.table_metadata + expired_snapshots = [ + snapshot for snapshot in pre_commit_metadata.snapshots if snapshot.snapshot_id in self._snapshot_ids_to_expire + ] + expired_files = self._reachable_files(expired_snapshots, strict=False) + expired_files.update(self._statistics_paths(pre_commit_metadata, self._snapshot_ids_to_expire, belonging=True)) + + super().commit() + + try: + surviving_files = self._reachable_files(self._transaction.table_metadata.snapshots, strict=True) + except Exception: + logger.warning( + "skipping expired-file cleanup: could not fully resolve surviving files", + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return + surviving_files.update( + self._statistics_paths(self._transaction.table_metadata, self._snapshot_ids_to_expire, belonging=False) + ) + + for path in expired_files - surviving_files: + try: + self._io.delete(path) + except Exception: + logger.warning( + "Failed to delete expired snapshot file %s", + path, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + + def clean_expired_files(self, clean: bool = True) -> ExpireSnapshots: + """Clean up files that are no longer reachable after expiring snapshots.""" + self._clean_expired_files = clean + return self + + def _reachable_files(self, snapshots: Iterable[Snapshot], *, strict: bool) -> set[str]: + reachable_files: set[str] = set() + + for snapshot in snapshots: + reachable_files.add(snapshot.manifest_list) + try: + manifests = snapshot.manifests(self._io) + except Exception: + if strict: + raise + logger.debug( + "Skipping manifest list while collecting reachable files for snapshot %s: %s", + snapshot.snapshot_id, + snapshot.manifest_list, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + continue + + for manifest in manifests: + reachable_files.add(manifest.manifest_path) + try: + entries = manifest.fetch_manifest_entry(io=self._io, discard_deleted=False) + except Exception: + if strict: + raise + logger.debug( + "Skipping manifest while collecting reachable files for snapshot %s: %s", + snapshot.snapshot_id, + manifest.manifest_path, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + continue + + reachable_files.update(entry.data_file.file_path for entry in entries) + + return reachable_files + + @staticmethod + def _statistics_paths(metadata: TableMetadata, snapshot_ids: set[int], *, belonging: bool) -> set[str]: + statistics_paths: set[str] = set() + + try: + statistics = getattr(metadata, "statistics", None) or () + partition_statistics = getattr(metadata, "partition_statistics", None) or () + except Exception: + logger.debug( + "Skipping statistics while collecting reachable files", + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return statistics_paths + + try: + for stat in itertools.chain(statistics, partition_statistics): + try: + if (stat.snapshot_id in snapshot_ids) == belonging and stat.statistics_path: + statistics_paths.add(stat.statistics_path) + except Exception: + logger.debug( + "Skipping statistics file while collecting reachable files", + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + except Exception: + logger.debug( + "Skipping statistics while collecting reachable files", + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + + return statistics_paths + def _get_protected_snapshot_ids(self) -> set[int]: """ Get the IDs of protected snapshots. diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 407ec611fd..faaecd3618 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -5297,3 +5297,76 @@ def test_dictionary_columns_produces_dict_encoded_output(tmpdir: str) -> None: # Values must be identical assert result_plain.column("label").to_pylist() == result_dict.column("label").to_pylist() + + +@pytest.fixture +def sql_catalog(tmp_path: Path) -> Iterator[Any]: + from pyiceberg.catalog.sql import SqlCatalog + + catalog = SqlCatalog( + "test_sql_catalog", + uri="sqlite:///:memory:", + warehouse=f"file://{tmp_path}", + ) + catalog.create_tables() + try: + yield catalog + finally: + catalog.destroy_tables() + catalog.close() + + +def _create_dictionary_batch_reader_table(catalog: Any, identifier: str) -> Any: + arrow_table = pa.table( + { + "id": pa.array([1, 2, 3, 4], type=pa.int64()), + "label": pa.array(["a", "b", "a", "b"], type=pa.string()), + } + ) + catalog.create_namespace_if_not_exists("default") + table = catalog.create_table(identifier, schema=arrow_table.schema) + table.append(arrow_table) + return table + + +def test_to_arrow_batch_reader_preserves_dictionary_columns(sql_catalog: Any) -> None: + """Regression test for issue #3540. + + ``to_arrow(dictionary_columns=...)`` preserves dictionary encoding, but + ``to_arrow_batch_reader(dictionary_columns=...)`` previously decoded it back + to plain strings via a final ``.cast(target_schema)`` where ``target_schema`` + had no dictionary types. Both public paths must now preserve the encoding. + """ + table = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_batch_reader_test") + expected = table.scan().to_arrow(dictionary_columns=("label",)) + result = table.scan().to_arrow_batch_reader(dictionary_columns=("label",)).read_all() + + assert result.schema.field("label").type == expected.schema.field("label").type + assert pa.types.is_dictionary(result.schema.field("label").type) + assert result.column("label").to_pylist() == ["a", "b", "a", "b"] + assert result.to_pydict() == expected.to_pydict() + # A column not in dictionary_columns stays a plain (non-dict) array. + assert result.schema.field("id").type == expected.schema.field("id").type == pa.int64() + assert not pa.types.is_dictionary(result.schema.field("id").type) + + +def test_to_arrow_batch_reader_dictionary_columns_keep_non_string_columns_plain(sql_catalog: Any) -> None: + table = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_int_batch_reader_test") + expected = table.scan().to_arrow(dictionary_columns=("id",)) + result = table.scan().to_arrow_batch_reader(dictionary_columns=("id",)).read_all() + + assert result.schema.field("id").type == expected.schema.field("id").type == pa.int64() + assert not pa.types.is_dictionary(result.schema.field("id").type) + assert result.to_pydict() == expected.to_pydict() + + +def test_to_arrow_batch_reader_dictionary_columns_allow_mixed_types(sql_catalog: Any) -> None: + table = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_mixed_batch_reader_test") + expected = table.scan().to_arrow(dictionary_columns=("label", "id")) + result = table.scan().to_arrow_batch_reader(dictionary_columns=("label", "id")).read_all() + + assert result.schema.field("label").type == expected.schema.field("label").type + assert pa.types.is_dictionary(result.schema.field("label").type) + assert result.schema.field("id").type == expected.schema.field("id").type == pa.int64() + assert not pa.types.is_dictionary(result.schema.field("id").type) + assert result.to_pydict() == expected.to_pydict() diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py index 106e5b786c..8bc685c7f4 100644 --- a/tests/table/test_expire_snapshots.py +++ b/tests/table/test_expire_snapshots.py @@ -16,16 +16,271 @@ # under the License. import threading from datetime import datetime, timedelta -from unittest.mock import MagicMock, Mock +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch +from urllib.parse import unquote, urlparse from uuid import uuid4 +import pyarrow as pa import pytest -from pyiceberg.table import CommitTableResponse, Table -from pyiceberg.table.update import RemoveSnapshotsUpdate, update_table_metadata +from pyiceberg.catalog.sql import SqlCatalog +from pyiceberg.io import FileIO +from pyiceberg.manifest import ManifestFile +from pyiceberg.table import CommitTableResponse, Table, Transaction +from pyiceberg.table.snapshots import Snapshot +from pyiceberg.table.statistics import BlobMetadata, PartitionStatisticsFile, StatisticsFile +from pyiceberg.table.update import RemoveSnapshotsUpdate, SetPartitionStatisticsUpdate, update_table_metadata from pyiceberg.table.update.snapshot import ExpireSnapshots +def _sql_catalog(tmp_path: Path) -> SqlCatalog: + catalog = SqlCatalog( + "test", + uri=f"sqlite:///{tmp_path}/pyiceberg_catalog.db", + warehouse=f"file://{tmp_path}", + ) + catalog.create_namespace_if_not_exists("default") + return catalog + + +def _arrow_table(rows: list[dict[str, object]]) -> pa.Table: + return pa.Table.from_pylist( + rows, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.string())]), + ) + + +def _local_path(location: str) -> Path: + parsed = urlparse(location) + if parsed.scheme == "file": + return Path(unquote(parsed.path)) + return Path(location) + + +def _touch_file(location: str) -> None: + path = _local_path(location) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(b"") + + +def _statistics_file(snapshot_id: int, statistics_path: str) -> StatisticsFile: + return StatisticsFile( + snapshot_id=snapshot_id, + statistics_path=statistics_path, + file_size_in_bytes=0, + file_footer_size_in_bytes=0, + blob_metadata=[ + BlobMetadata( + type="apache-datasketches-theta-v1", + snapshot_id=snapshot_id, + sequence_number=0, + fields=[1], + ) + ], + ) + + +def _partition_statistics_file(snapshot_id: int, statistics_path: str) -> PartitionStatisticsFile: + return PartitionStatisticsFile( + snapshot_id=snapshot_id, + statistics_path=statistics_path, + file_size_in_bytes=0, + ) + + +def _data_file_paths(table: Table, snapshot_id: int) -> set[str]: + snapshot = table.metadata.snapshot_by_id(snapshot_id) + assert snapshot is not None + return { + entry.data_file.file_path + for manifest in snapshot.manifests(table.io) + for entry in manifest.fetch_manifest_entry(io=table.io, discard_deleted=True) + } + + +def _non_current_snapshot_ids(table: Table) -> list[int]: + current_snapshot = table.metadata.current_snapshot() + assert current_snapshot is not None + return [snapshot.snapshot_id for snapshot in table.metadata.snapshots if snapshot.snapshot_id != current_snapshot.snapshot_id] + + +def _create_overwritten_table(catalog: SqlCatalog, identifier: str) -> tuple[Table, int, str]: + table = catalog.create_table(identifier, schema=_arrow_table([{"id": 1, "data": "before"}]).schema) + table.append(_arrow_table([{"id": 1, "data": "before"}])) + first_snapshot = table.metadata.current_snapshot() + assert first_snapshot is not None + first_snapshot_data_files = _data_file_paths(table, first_snapshot.snapshot_id) + assert len(first_snapshot_data_files) == 1 + + table.overwrite(_arrow_table([{"id": 2, "data": "after"}])) + + return table, first_snapshot.snapshot_id, next(iter(first_snapshot_data_files)) + + +def test_expire_snapshots_clean_expired_files_deletes_unreferenced(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + + control_table, _, control_data_file = _create_overwritten_table(catalog, "default.control_table") + control_table.maintenance.expire_snapshots().by_ids(_non_current_snapshot_ids(control_table)).commit() + assert _local_path(control_data_file).exists() + + table, first_snapshot_id, first_data_file = _create_overwritten_table(catalog, "default.clean_table") + snapshot_ids_to_expire = _non_current_snapshot_ids(table) + assert first_snapshot_id in snapshot_ids_to_expire + + table.maintenance.expire_snapshots().by_ids(snapshot_ids_to_expire).clean_expired_files().commit() + + assert not _local_path(first_data_file).exists() + assert table.scan().to_arrow().to_pylist() == [{"id": 2, "data": "after"}] + + +def test_expire_snapshots_clean_expired_files_keeps_shared_files(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + table = catalog.create_table("default.shared_table", schema=_arrow_table([{"id": 1, "data": "one"}]).schema) + + table.append(_arrow_table([{"id": 1, "data": "one"}])) + first_snapshot = table.metadata.current_snapshot() + assert first_snapshot is not None + first_snapshot_data_files = _data_file_paths(table, first_snapshot.snapshot_id) + assert len(first_snapshot_data_files) == 1 + shared_data_file = next(iter(first_snapshot_data_files)) + + table.append(_arrow_table([{"id": 2, "data": "two"}])) + + table.maintenance.expire_snapshots().by_id(first_snapshot.snapshot_id).clean_expired_files().commit() + + assert _local_path(shared_data_file).exists() + assert sorted(table.scan().to_arrow().to_pylist(), key=lambda row: row["id"]) == [ + {"id": 1, "data": "one"}, + {"id": 2, "data": "two"}, + ] + + +def test_expire_snapshots_clean_expired_files_keeps_files_referenced_only_by_tag(tmp_path: Path) -> None: + """A file reachable only from a tagged snapshot (not the current lineage) must survive expiration.""" + catalog = _sql_catalog(tmp_path) + table = catalog.create_table("default.tag_shared_table", schema=_arrow_table([{"id": 1, "data": "one"}]).schema) + + # First snapshot writes the shared file; it is unprotected and will be expired. + table.append(_arrow_table([{"id": 1, "data": "one"}])) + expiring_snapshot = table.metadata.current_snapshot() + assert expiring_snapshot is not None + shared_data_files = _data_file_paths(table, expiring_snapshot.snapshot_id) + assert len(shared_data_files) == 1 + shared_data_file = next(iter(shared_data_files)) + + # Second snapshot still references the shared file; we tag it so it stays alive. + table.append(_arrow_table([{"id": 2, "data": "two"}])) + tagged_snapshot = table.metadata.current_snapshot() + assert tagged_snapshot is not None + assert shared_data_file in _data_file_paths(table, tagged_snapshot.snapshot_id) + table.manage_snapshots().create_tag(tagged_snapshot.snapshot_id, "keep_tag").commit() + + # Overwrite drops the shared file from the current lineage; only the tag still references it. + table.overwrite(_arrow_table([{"id": 99, "data": "overwritten"}])) + current_snapshot = table.metadata.current_snapshot() + assert current_snapshot is not None + assert shared_data_file not in _data_file_paths(table, current_snapshot.snapshot_id) + + # Expiring the first snapshot must NOT delete the file the tag still references. + table.maintenance.expire_snapshots().by_id(expiring_snapshot.snapshot_id).clean_expired_files().commit() + + assert _local_path(shared_data_file).exists() + assert table.metadata.snapshot_by_id(expiring_snapshot.snapshot_id) is None + assert table.metadata.snapshot_by_id(tagged_snapshot.snapshot_id) is not None + # The tagged snapshot remains fully readable from the file that survived. + assert sorted(table.scan(snapshot_id=tagged_snapshot.snapshot_id).to_arrow().to_pylist(), key=lambda row: row["id"]) == [ + {"id": 1, "data": "one"}, + {"id": 2, "data": "two"}, + ] + + +def test_expire_snapshots_clean_expired_files_deletes_statistics_files(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + table, first_snapshot_id, _ = _create_overwritten_table(catalog, "default.clean_stats_table") + current_snapshot = table.metadata.current_snapshot() + assert current_snapshot is not None + + table_location = table.location().rstrip("/") + expired_table_statistics = f"{table_location}/metadata/expired-table-stats.puffin" + surviving_table_statistics = f"{table_location}/metadata/surviving-table-stats.puffin" + expired_partition_statistics = f"{table_location}/metadata/expired-partition-stats.puffin" + surviving_partition_statistics = f"{table_location}/metadata/surviving-partition-stats.puffin" + + for location in ( + expired_table_statistics, + surviving_table_statistics, + expired_partition_statistics, + surviving_partition_statistics, + ): + _touch_file(location) + + with table.update_statistics() as update: + update.set_statistics(_statistics_file(first_snapshot_id, expired_table_statistics)) + update.set_statistics(_statistics_file(current_snapshot.snapshot_id, surviving_table_statistics)) + + table.transaction()._apply( + ( + SetPartitionStatisticsUpdate( + partition_statistics=_partition_statistics_file(first_snapshot_id, expired_partition_statistics) + ), + SetPartitionStatisticsUpdate( + partition_statistics=_partition_statistics_file(current_snapshot.snapshot_id, surviving_partition_statistics) + ), + ) + ).commit_transaction() + + table.maintenance.expire_snapshots().by_id(first_snapshot_id).clean_expired_files().commit() + + assert not _local_path(expired_table_statistics).exists() + assert not _local_path(expired_partition_statistics).exists() + assert _local_path(surviving_table_statistics).exists() + assert _local_path(surviving_partition_statistics).exists() + + +def test_expire_snapshots_clean_skips_deletion_when_surviving_unresolvable(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + table, first_snapshot_id, first_data_file = _create_overwritten_table(catalog, "default.unresolvable_survivor_table") + current_snapshot = table.metadata.current_snapshot() + assert current_snapshot is not None + original_manifests = Snapshot.manifests + + def manifests(snapshot: Snapshot, io: FileIO) -> list[ManifestFile]: + if snapshot.snapshot_id == current_snapshot.snapshot_id: + raise OSError("surviving manifest list is unavailable") + return original_manifests(snapshot, io) + + with ( + patch.object(table.io, "delete", wraps=table.io.delete) as delete, + patch.object(Snapshot, "manifests", autospec=True, side_effect=manifests), + ): + table.maintenance.expire_snapshots().by_id(first_snapshot_id).clean_expired_files().commit() + + delete.assert_not_called() + assert _local_path(first_data_file).exists() + assert table.scan().to_arrow().to_pylist() == [{"id": 2, "data": "after"}] + + +def test_expire_snapshots_clean_noop_on_non_autocommit_transaction(tmp_path: Path) -> None: + catalog = _sql_catalog(tmp_path) + table, first_snapshot_id, first_data_file = _create_overwritten_table(catalog, "default.non_autocommit_table") + transaction = Transaction(table, autocommit=False) + expire_snapshots = ExpireSnapshots(transaction).by_id(first_snapshot_id).clean_expired_files() + + with patch.object(expire_snapshots._io, "delete") as delete: + expire_snapshots.commit() + + delete.assert_not_called() + assert _local_path(first_data_file).exists() + assert table.metadata.snapshot_by_id(first_snapshot_id) is not None + assert transaction.table_metadata.snapshot_by_id(first_snapshot_id) is None + assert any( + isinstance(update, RemoveSnapshotsUpdate) and update.snapshot_ids == [first_snapshot_id] + for update in transaction._updates + ) + + def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None: """Test that a HEAD (branch) snapshot cannot be expired.""" HEAD_SNAPSHOT = 3051729675574597004