Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,6 @@ def to_arrow_batch_reader(self, dictionary_columns: tuple[str, ...] = ()) -> pa.

from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow

target_schema = schema_to_pyarrow(self.projection())
batches = ArrowScan(
self.table_metadata,
self.io,
Expand All @@ -2230,6 +2229,23 @@ def to_arrow_batch_reader(self, dictionary_columns: tuple[str, ...] = ()) -> pa.
dictionary_columns=dictionary_columns,
).to_record_batches(self.plan_files())

target_schema = schema_to_pyarrow(self.projection())
if dictionary_columns:
batches = iter(batches)
try:
first_batch = next(batches)
except StopIteration:
pass
else:
target_schema = pa.schema(
[
field.with_type(batch_field.type)
for field, batch_field in zip(target_schema, first_batch.schema, strict=True)
],
metadata=target_schema.metadata,
)
batches = chain([first_batch], batches)

return pa.RecordBatchReader.from_batches(
target_schema,
batches,
Expand Down
131 changes: 130 additions & 1 deletion pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from __future__ import annotations

import itertools
import logging
import uuid
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Callable, Iterable
from datetime import datetime
from functools import cached_property
from typing import TYPE_CHECKING, Generic
Expand Down Expand Up @@ -79,6 +80,10 @@

if TYPE_CHECKING:
from pyiceberg.table import Transaction
from pyiceberg.table.metadata import TableMetadata


logger = logging.getLogger(__name__)


def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str:
Expand Down Expand Up @@ -1039,13 +1044,17 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):

_updates: tuple[TableUpdate, ...]
_requirements: tuple[TableRequirement, ...]
_io: FileIO
_snapshot_ids_to_expire: set[int]
_clean_expired_files: bool

def __init__(self, transaction: Transaction) -> None:
    """Create an expiration update attached to ``transaction``.

    The update starts with no pending table updates or requirements, no
    snapshots selected for expiration, and expired-file cleanup disabled.

    Args:
        transaction: The transaction this update is applied through; its
            table's ``io`` is kept for reading manifests and deleting files.
    """
    super().__init__(transaction)
    self._updates, self._requirements = (), ()
    # File IO comes from the table the transaction operates on.
    self._io = transaction._table.io
    self._snapshot_ids_to_expire = set()
    # Cleanup is opt-in; see clean_expired_files().
    self._clean_expired_files = False

def _commit(self) -> UpdatesAndRequirements:
"""
Expand All @@ -1064,6 +1073,126 @@ def _commit(self) -> UpdatesAndRequirements:
self._updates += (update,)
return self._updates, self._requirements

def commit(self) -> None:
    """Commit the expiration and, when enabled, delete files that became unreachable.

    Cleanup is skipped (and only the plain commit performed) when it was not
    requested, when no snapshots are selected, when the transaction is not
    an autocommit transaction, or when every selected snapshot turns out to
    be protected.

    Otherwise the files reachable from the snapshots about to be expired are
    collected *before* the commit, the commit is performed, and the files
    still reachable from the remaining snapshots are collected *after* it.
    Only paths in the first set and not in the second are deleted. Deletion
    is best-effort: individual failures are logged and do not abort the
    loop, and if the surviving set cannot be fully resolved nothing is
    deleted at all.
    """
    wants_cleanup = self._clean_expired_files and self._snapshot_ids_to_expire
    if not wants_cleanup:
        super().commit()
        return

    if not self._transaction._autocommit:
        # Commit first, then explain why no cleanup happened.
        super().commit()
        logger.debug("Skipping expired-file cleanup for non-autocommit transaction; cleanup only runs on autocommit")
        return

    # Protected snapshots are never expired, so their files must not be candidates.
    self._snapshot_ids_to_expire -= self._get_protected_snapshot_ids()
    if not self._snapshot_ids_to_expire:
        super().commit()
        return

    # --- Before the commit: everything the soon-to-be-expired snapshots reference.
    metadata_before = self._transaction.table_metadata
    snapshots_to_expire = [
        candidate
        for candidate in metadata_before.snapshots
        if candidate.snapshot_id in self._snapshot_ids_to_expire
    ]
    # Non-strict: an unreadable manifest only shrinks the candidate set.
    expired_files = self._reachable_files(snapshots_to_expire, strict=False)
    expired_files.update(
        self._statistics_paths(metadata_before, self._snapshot_ids_to_expire, belonging=True)
    )

    super().commit()

    # --- After the commit: everything the remaining snapshots still reference.
    try:
        # Strict: a partial view of surviving files could lead to deleting live data.
        surviving_files = self._reachable_files(self._transaction.table_metadata.snapshots, strict=True)
    except Exception:
        logger.warning(
            "skipping expired-file cleanup: could not fully resolve surviving files",
            exc_info=logger.isEnabledFor(logging.DEBUG),
        )
        return
    surviving_files.update(
        self._statistics_paths(self._transaction.table_metadata, self._snapshot_ids_to_expire, belonging=False)
    )

    orphaned_paths = expired_files - surviving_files
    for orphan in orphaned_paths:
        try:
            self._io.delete(orphan)
        except Exception:
            # Keep going: one failed delete should not block the rest.
            logger.warning(
                "Failed to delete expired snapshot file %s",
                orphan,
                exc_info=logger.isEnabledFor(logging.DEBUG),
            )

def clean_expired_files(self, clean: bool = True) -> ExpireSnapshots:
    """Enable or disable deletion of files left unreachable by the expiration.

    Args:
        clean: ``True`` (the default) to request cleanup on commit,
            ``False`` to turn it off again.

    Returns:
        This instance, so the call can be chained.
    """
    self._clean_expired_files = clean
    return self

def _reachable_files(self, snapshots: Iterable[Snapshot], *, strict: bool) -> set[str]:
    """Collect every file path referenced by the given snapshots.

    For each snapshot this gathers its manifest list path, the path of each
    manifest in that list, and the ``file_path`` of every entry in each
    manifest (entries are fetched with ``discard_deleted=False``).

    Args:
        snapshots: Snapshots whose referenced files should be collected.
        strict: When ``True``, a failure to read a manifest list or a
            manifest is re-raised. When ``False``, the failure is logged at
            debug level and that manifest list / manifest is skipped; its
            own path has already been recorded at that point.

    Returns:
        The set of collected paths.
    """
    collected: set[str] = set()

    for snap in snapshots:
        # Record the manifest list before trying to open it.
        collected.add(snap.manifest_list)
        try:
            manifest_files = snap.manifests(self._io)
        except Exception:
            if not strict:
                logger.debug(
                    "Skipping manifest list while collecting reachable files for snapshot %s: %s",
                    snap.snapshot_id,
                    snap.manifest_list,
                    exc_info=logger.isEnabledFor(logging.DEBUG),
                )
                continue
            raise

        for manifest_file in manifest_files:
            # Same pattern one level down: record the manifest, then read it.
            collected.add(manifest_file.manifest_path)
            try:
                manifest_entries = manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=False)
            except Exception:
                if not strict:
                    logger.debug(
                        "Skipping manifest while collecting reachable files for snapshot %s: %s",
                        snap.snapshot_id,
                        manifest_file.manifest_path,
                        exc_info=logger.isEnabledFor(logging.DEBUG),
                    )
                    continue
                raise

            for manifest_entry in manifest_entries:
                collected.add(manifest_entry.data_file.file_path)

    return collected

@staticmethod
def _statistics_paths(metadata: TableMetadata, snapshot_ids: set[int], *, belonging: bool) -> set[str]:
    """Collect statistics file paths, selected by snapshot membership.

    Reads ``metadata.statistics`` and ``metadata.partition_statistics``
    (either may be missing or empty) and keeps the non-empty
    ``statistics_path`` of each entry whose ``snapshot_id`` membership in
    ``snapshot_ids`` equals ``belonging``.

    Args:
        metadata: Table metadata to read statistics entries from.
        snapshot_ids: Snapshot IDs to test each entry against.
        belonging: ``True`` to keep entries whose snapshot is in
            ``snapshot_ids``; ``False`` to keep entries whose snapshot is not.

    Returns:
        The selected paths. Any error while reading or iterating the
        statistics is logged at debug level and never raised; the paths
        gathered up to that point are returned.
    """
    selected: set[str] = set()

    try:
        table_stats = getattr(metadata, "statistics", None) or ()
        partition_stats = getattr(metadata, "partition_statistics", None) or ()
    except Exception:
        logger.debug(
            "Skipping statistics while collecting reachable files",
            exc_info=logger.isEnabledFor(logging.DEBUG),
        )
        return selected

    try:
        for stat_file in itertools.chain(table_stats, partition_stats):
            try:
                is_member = stat_file.snapshot_id in snapshot_ids
                if is_member == belonging and stat_file.statistics_path:
                    selected.add(stat_file.statistics_path)
            except Exception:
                # A single malformed entry is skipped; the rest are still examined.
                logger.debug(
                    "Skipping statistics file while collecting reachable files",
                    exc_info=logger.isEnabledFor(logging.DEBUG),
                )
    except Exception:
        # Iteration itself failed; keep whatever was gathered so far.
        logger.debug(
            "Skipping statistics while collecting reachable files",
            exc_info=logger.isEnabledFor(logging.DEBUG),
        )

    return selected

def _get_protected_snapshot_ids(self) -> set[int]:
"""
Get the IDs of protected snapshots.
Expand Down
73 changes: 73 additions & 0 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5297,3 +5297,76 @@ def test_dictionary_columns_produces_dict_encoded_output(tmpdir: str) -> None:

# Values must be identical
assert result_plain.column("label").to_pylist() == result_dict.column("label").to_pylist()


@pytest.fixture
def sql_catalog(tmp_path: Path) -> Iterator[Any]:
    """Yield a ``SqlCatalog`` backed by in-memory SQLite with a warehouse under ``tmp_path``.

    The catalog tables are created before the test runs and destroyed
    afterwards. The catalog is always closed, even when table creation or
    table teardown raises, so a failing setup/teardown does not leak the
    underlying connection.
    """
    from pyiceberg.catalog.sql import SqlCatalog

    catalog = SqlCatalog(
        "test_sql_catalog",
        uri="sqlite:///:memory:",
        warehouse=f"file://{tmp_path}",
    )
    try:
        catalog.create_tables()
        try:
            yield catalog
        finally:
            # Only tear tables down once they were successfully created.
            catalog.destroy_tables()
    finally:
        # Runs regardless of create/destroy failures.
        catalog.close()


def _create_dictionary_batch_reader_table(catalog: Any, identifier: str) -> Any:
    """Create ``identifier`` in ``catalog`` and append a four-row id/label table.

    The data has an int64 ``id`` column (1..4) and a string ``label`` column
    with the repeating values ``a``/``b``. The ``default`` namespace is
    created if it does not exist yet.

    Returns:
        The table object returned by ``catalog.create_table`` after the append.
    """
    id_column = pa.array([1, 2, 3, 4], type=pa.int64())
    label_column = pa.array(["a", "b", "a", "b"], type=pa.string())
    source_data = pa.table({"id": id_column, "label": label_column})

    catalog.create_namespace_if_not_exists("default")
    created = catalog.create_table(identifier, schema=source_data.schema)
    created.append(source_data)
    return created


def test_to_arrow_batch_reader_preserves_dictionary_columns(sql_catalog: Any) -> None:
    """Regression test for issue #3540.

    Scanning with ``dictionary_columns=("label",)`` must give the same
    dictionary-encoded ``label`` column through ``to_arrow_batch_reader`` as
    through ``to_arrow``; the reader path used to cast the column back to
    plain strings because its target schema carried no dictionary types.
    """
    requested = ("label",)
    tbl = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_batch_reader_test")
    via_table = tbl.scan().to_arrow(dictionary_columns=requested)
    via_reader = tbl.scan().to_arrow_batch_reader(dictionary_columns=requested).read_all()

    # The requested column is dictionary-encoded and matches the to_arrow path.
    assert via_reader.schema.field("label").type == via_table.schema.field("label").type
    assert pa.types.is_dictionary(via_reader.schema.field("label").type)
    assert via_reader.column("label").to_pylist() == ["a", "b", "a", "b"]
    assert via_reader.to_pydict() == via_table.to_pydict()
    # A column that was not requested stays a plain (non-dictionary) array.
    assert via_reader.schema.field("id").type == via_table.schema.field("id").type == pa.int64()
    assert not pa.types.is_dictionary(via_reader.schema.field("id").type)


def test_to_arrow_batch_reader_dictionary_columns_keep_non_string_columns_plain(sql_catalog: Any) -> None:
    """Requesting the int64 ``id`` column as a dictionary column leaves it plain int64.

    The batch-reader result must agree with ``to_arrow`` on both type and values.
    """
    requested = ("id",)
    tbl = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_int_batch_reader_test")
    via_table = tbl.scan().to_arrow(dictionary_columns=requested)
    via_reader = tbl.scan().to_arrow_batch_reader(dictionary_columns=requested).read_all()

    assert via_reader.schema.field("id").type == via_table.schema.field("id").type == pa.int64()
    assert not pa.types.is_dictionary(via_reader.schema.field("id").type)
    assert via_reader.to_pydict() == via_table.to_pydict()


def test_to_arrow_batch_reader_dictionary_columns_allow_mixed_types(sql_catalog: Any) -> None:
    """Requesting both a string and an int64 column works through the batch reader.

    ``label`` comes back dictionary-encoded while ``id`` stays plain int64,
    and both type and values agree with the ``to_arrow`` result.
    """
    requested = ("label", "id")
    tbl = _create_dictionary_batch_reader_table(sql_catalog, "default.dict_mixed_batch_reader_test")
    via_table = tbl.scan().to_arrow(dictionary_columns=requested)
    via_reader = tbl.scan().to_arrow_batch_reader(dictionary_columns=requested).read_all()

    # String column: dictionary-encoded, same type as the to_arrow path.
    assert via_reader.schema.field("label").type == via_table.schema.field("label").type
    assert pa.types.is_dictionary(via_reader.schema.field("label").type)
    # Integer column: untouched.
    assert via_reader.schema.field("id").type == via_table.schema.field("id").type == pa.int64()
    assert not pa.types.is_dictionary(via_reader.schema.field("id").type)
    assert via_reader.to_pydict() == via_table.to_pydict()
Loading