Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 149 additions & 13 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@
)
from pyiceberg.partitioning import PartitionFieldValue, PartitionKey, PartitionSpec
from pyiceberg.schema import (
LAST_UPDATED_SEQUENCE_NUMBER_FIELD_ID,
RESERVED_METADATA_FIELD_IDS,
ROW_ID_FIELD_ID,
PartnerAccessor,
PreOrderSchemaVisitor,
Schema,
Expand Down Expand Up @@ -1612,6 +1615,128 @@ def _get_column_projection_values(
return projected_missing_fields


def _projected_data_schema(projected_schema: Schema) -> Schema:
    """Return ``projected_schema`` without its reserved metadata fields.

    Top-level fields whose ID is listed in ``RESERVED_METADATA_FIELD_IDS`` are
    left out. The schema ID and the identifier field IDs are carried over as-is.

    Args:
        projected_schema: The schema requested by the reader.

    Returns:
        A new schema holding only the remaining top-level fields.
    """
    data_fields = [
        candidate for candidate in projected_schema.fields if candidate.field_id not in RESERVED_METADATA_FIELD_IDS
    ]
    return Schema(
        *data_fields,
        schema_id=projected_schema.schema_id,
        identifier_field_ids=projected_schema.identifier_field_ids,
    )


def _projected_metadata_fields(projected_schema: Schema) -> tuple[NestedField, ...]:
    """Return the top-level fields of ``projected_schema`` that have a reserved metadata field ID.

    Args:
        projected_schema: The schema requested by the reader.

    Returns:
        The matching fields, in the order they appear in the schema. Empty when
        no field ID is in ``RESERVED_METADATA_FIELD_IDS``.
    """
    reserved_fields = [
        candidate for candidate in projected_schema.fields if candidate.field_id in RESERVED_METADATA_FIELD_IDS
    ]
    return tuple(reserved_fields)


def _column_by_field_id(file_schema: Schema, batch: pa.RecordBatch, field_id: int) -> pa.Array | None:
    """Look up the column of ``batch`` that belongs to the given field ID.

    The lookup is positional: the index of the first top-level field in
    ``file_schema`` with a matching ID is used as the column index into
    ``batch``, so both must list their columns in the same order.

    Args:
        file_schema: Schema describing the columns of ``batch``.
        batch: The record batch to read the column from.
        field_id: The field ID to search for.

    Returns:
        The matching column, or ``None`` when no top-level field has that ID.
    """
    matching_indexes = (index for index, candidate in enumerate(file_schema.fields) if candidate.field_id == field_id)
    column_index = next(matching_indexes, None)
    return None if column_index is None else batch.column(column_index)


def _as_int64_array(array: pa.Array | pa.ChunkedArray) -> pa.Array:
    """Return ``array`` as a single contiguous int64 array.

    A chunked array is first combined into one array; the result is cast to
    int64 only when it does not already have that type.
    """
    contiguous = array.combine_chunks() if isinstance(array, pa.ChunkedArray) else array
    if contiguous.type == pa.int64():
        return contiguous
    return contiguous.cast(pa.int64())


def _int64_range(start: int, stop: int) -> pa.Array:
    """Build an int64 array with the integers from ``start`` (inclusive) to ``stop`` (exclusive)."""
    row_numbers = range(start, stop)
    return pa.array(row_numbers, type=pa.int64())


def _filter_batch_and_positions(
    batch: pa.RecordBatch, pyarrow_filter: ds.Expression, positions: pa.Array | None
) -> tuple[pa.RecordBatch, pa.Array | None]:
    """Filter ``batch`` and keep ``positions`` aligned with the surviving rows.

    When ``positions`` is given it is attached to the batch as a temporary
    column before filtering, read back afterwards and removed again, so the
    returned positions line up row-for-row with the returned batch.

    Args:
        batch: The record batch to filter.
        pyarrow_filter: The expression applied to the batch.
        positions: Row positions matching ``batch`` row-for-row, or ``None``
            when positions are not tracked.

    Returns:
        The filtered batch and the filtered positions (``None`` when no
        positions were passed in). When no row matches, a zero-length slice of
        ``batch`` and an empty int64 array (or ``None``) are returned.
    """
    track_positions = positions is not None

    # Pick a temporary column name that does not clash with a column already in the batch.
    helper_column = "__iceberg_row_position"
    existing_names = batch.schema.names
    while helper_column in existing_names:
        helper_column = f"{helper_column}_"

    # Filtering goes through a Table: the call site notes that RecordBatch.filter raises
    # IndexError on PyArrow <21 when the result is empty, while Table.filter does not.
    unfiltered = pa.Table.from_batches([batch])
    if track_positions:
        unfiltered = unfiltered.append_column(helper_column, positions)

    filtered = unfiltered.filter(pyarrow_filter)
    if filtered.num_rows == 0:
        empty_positions = pa.array([], type=pa.int64()) if track_positions else None
        return batch.slice(0, 0), empty_positions

    if not track_positions:
        return filtered.combine_chunks().to_batches()[0], None

    surviving_positions = _as_int64_array(filtered.column(helper_column))
    data_only = filtered.drop([helper_column])
    return data_only.combine_chunks().to_batches()[0], surviving_positions


def _row_id_array(task: FileScanTask, file_schema: Schema, batch: pa.RecordBatch, positions: pa.Array | None) -> pa.Array:
    """Materialize the ``_row_id`` values for the rows of ``batch``.

    Each row ID is the file's ``first_row_id`` plus the row position. When the
    batch carries a column with ``ROW_ID_FIELD_ID``, its non-null values take
    precedence and only its nulls are filled with the computed value.

    Args:
        task: The scan task; ``task.file.first_row_id`` is the base row ID.
        file_schema: Schema describing the columns of ``batch``.
        batch: The record batch the row IDs are produced for.
        positions: Row positions within the file, aligned with ``batch``.

    Returns:
        An int64 array of row IDs, all null when the file has no ``first_row_id``.

    Raises:
        ValueError: If the file has a ``first_row_id`` but ``positions`` is ``None``.
    """
    first_row_id = task.file.first_row_id
    if first_row_id is None:
        # Snapshots written before row lineage was enabled (e.g. pre-upgrade snapshots of an
        # upgraded table) have a null first_row_id, so _row_id reads as null for all rows.
        return pa.nulls(batch.num_rows, type=pa.int64())
    if positions is None:
        raise ValueError("Cannot read _row_id: row positions were not materialized")

    base_row_id = pa.scalar(first_row_id, type=pa.int64())
    inherited_row_ids = pc.add(positions, base_row_id)

    stored_row_ids = _column_by_field_id(file_schema, batch, ROW_ID_FIELD_ID)
    if stored_row_ids is not None:
        # Values stored in the file win; nulls fall back to first_row_id + position.
        return pc.coalesce(_as_int64_array(stored_row_ids), inherited_row_ids)
    return inherited_row_ids


def _last_updated_sequence_number_array(task: FileScanTask, file_schema: Schema, batch: pa.RecordBatch) -> pa.Array:
    """Materialize the ``_last_updated_sequence_number`` values for the rows of ``batch``.

    Values stored in a column with ``LAST_UPDATED_SEQUENCE_NUMBER_FIELD_ID`` are
    used where present; nulls and a missing column fall back to the task's data
    sequence number. Without a data sequence number, the stored column must
    exist and must not contain nulls.

    Args:
        task: The scan task; ``task.data_sequence_number`` is the fallback value.
        file_schema: Schema describing the columns of ``batch``.
        batch: The record batch the sequence numbers are produced for.

    Returns:
        An int64 array with one sequence number per row of ``batch``.

    Raises:
        ValueError: If the task has no data sequence number and the stored
            column is missing or contains nulls.
    """
    stored = _column_by_field_id(file_schema, batch, LAST_UPDATED_SEQUENCE_NUMBER_FIELD_ID)
    data_sequence_number = task.data_sequence_number

    if data_sequence_number is not None:
        inherited = pa.repeat(pa.scalar(data_sequence_number, type=pa.int64()), batch.num_rows)
        if stored is None:
            return inherited
        # Values stored in the file win; nulls inherit the data sequence number.
        return pc.coalesce(_as_int64_array(stored), inherited)

    # No data sequence number: only fully populated stored values can be returned.
    if stored is None:
        raise ValueError(
            "Cannot read _last_updated_sequence_number: the file scan task has no data sequence number. "
            "Server-side/REST scan planning does not yet supply the data sequence number required to "
            "materialize this column."
        )

    stored = _as_int64_array(stored)
    if stored.null_count > 0:
        raise ValueError(
            "Cannot read _last_updated_sequence_number: data sequence number is required for null physical values"
        )
    return stored


def _append_row_lineage_metadata_columns(
    projected_schema: Schema,
    file_schema: Schema,
    file_batch: pa.RecordBatch,
    projected_batch: pa.RecordBatch,
    task: FileScanTask,
    positions: pa.Array | None,
) -> pa.RecordBatch:
    """Append the requested row-lineage metadata columns to ``projected_batch``.

    For every reserved metadata field in ``projected_schema`` that is ``_row_id``
    or ``_last_updated_sequence_number`` (matched by field ID), the values are
    materialized from ``file_batch`` and ``task`` and added as a nullable int64
    column named after the field. Other reserved fields are skipped.

    The metadata columns are always placed after the existing columns of
    ``projected_batch``, whatever their position in ``projected_schema``.

    Args:
        projected_schema: The schema requested by the reader, including metadata fields.
        file_schema: Schema describing the columns of ``file_batch``.
        file_batch: The batch as read from the file, used to find stored metadata values.
        projected_batch: The already projected data columns.
        task: The scan task that supplies the inherited metadata values.
        positions: Row positions within the file, aligned with ``file_batch``;
            used when ``_row_id`` is requested.

    Returns:
        ``projected_batch`` itself when no metadata field is requested, otherwise
        a new batch with the metadata columns appended.

    Raises:
        ValueError: Propagated from the helpers when a requested column cannot
            be materialized.
    """
    metadata_fields = _projected_metadata_fields(projected_schema)
    if not metadata_fields:
        return projected_batch

    arrays = list(projected_batch.columns)
    fields = list(projected_batch.schema)

    for field in metadata_fields:
        if field.field_id == ROW_ID_FIELD_ID:
            metadata_array = _row_id_array(task, file_schema, file_batch, positions)
        elif field.field_id == LAST_UPDATED_SEQUENCE_NUMBER_FIELD_ID:
            metadata_array = _last_updated_sequence_number_array(task, file_schema, file_batch)
        else:
            continue

        arrays.append(metadata_array)
        fields.append(pa.field(field.name, pa.int64(), nullable=True))

    # Carry the schema-level metadata over: rebuilding the schema from the fields alone
    # would drop it, making the result differ from the pass-through case above.
    output_schema = pa.schema(fields, metadata=projected_batch.schema.metadata)
    return pa.RecordBatch.from_arrays(arrays, schema=output_schema)


def _task_to_record_batches(
io: FileIO,
task: FileScanTask,
Expand Down Expand Up @@ -1644,6 +1769,9 @@ def _task_to_record_batches(
file_schema = pyarrow_to_schema(
physical_schema, name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, format_version=format_version
)
data_projected_schema = _projected_data_schema(projected_schema)
metadata_fields = _projected_metadata_fields(projected_schema)
row_id_requested = any(field.field_id == ROW_ID_FIELD_ID for field in metadata_fields)

# Apply column projection rules: https://iceberg.apache.org/spec/#column-projection
projected_missing_fields = _get_column_projection_values(
Expand All @@ -1659,13 +1787,14 @@ def _task_to_record_batches(
pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema)

file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
apply_filter_after_scan = bool(positional_deletes) or row_id_requested

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
schema=physical_schema,
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them first
filter=pyarrow_filter if not positional_deletes else None,
filter=pyarrow_filter if not apply_filter_after_scan else None,
columns=[col.name for col in file_project_schema.columns],
)

Expand All @@ -1675,34 +1804,41 @@ def _task_to_record_batches(
next_index = next_index + len(batch)
current_index = next_index - len(batch)
current_batch = batch
positions = _int64_range(current_index, current_index + len(batch)) if row_id_requested else None

if positional_deletes:
# Create the mask of indices that we're interested in
indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
current_batch = current_batch.take(indices)
if pyarrow_filter is not None:
# Temporary fix until PyArrow 21 is the minimum supported version
# (https://github.com/apache/arrow/pull/46057): RecordBatch.filter raises
# IndexError on PyArrow <21 when the result is empty; Table.filter does not.
table = pa.Table.from_batches([current_batch])
table = table.filter(pyarrow_filter)
if table.num_rows == 0:
current_batch = current_batch.slice(0, 0)
else:
current_batch = table.combine_chunks().to_batches()[0]
if positions is not None:
positions = positions.take(indices)

if apply_filter_after_scan and pyarrow_filter is not None:
# Temporary fix until PyArrow 21 is the minimum supported version
# (https://github.com/apache/arrow/pull/46057): RecordBatch.filter raises
# IndexError on PyArrow <21 when the result is empty; Table.filter does not.
current_batch, positions = _filter_batch_and_positions(current_batch, pyarrow_filter, positions)

# skip empty batches
if current_batch.num_rows == 0:
continue

yield _to_requested_schema(
projected_schema,
projected_batch = _to_requested_schema(
data_projected_schema,
file_project_schema,
current_batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
projected_missing_fields=projected_missing_fields,
allow_timestamp_tz_mismatch=True,
)
yield _append_row_lineage_metadata_columns(
projected_schema,
file_project_schema,
current_batch,
projected_batch,
task,
positions,
)


def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]:
Expand Down
Loading