Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 237 additions & 9 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,15 @@ def _read_deletes(io: FileIO, data_file: DataFile) -> dict[str, pa.ChunkedArray]
raise ValueError(f"Delete file format not supported: {data_file.file_format}")


def _read_equality_delete(io: FileIO, data_file: DataFile) -> pa.Table:
if data_file.file_format in {FileFormat.PARQUET, FileFormat.ORC}:
with io.new_input(data_file.file_path).open() as fi:
delete_fragment = _get_file_format(data_file.file_format, pre_buffer=True, buffer_size=ONE_MEGABYTE).make_fragment(fi)
return ds.Scanner.from_fragment(fragment=delete_fragment).to_table().combine_chunks()
else:
raise ValueError(f"Equality delete file format not supported: {data_file.file_format}")


def _combine_positional_deletes(positional_deletes: list[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
if len(positional_deletes) == 1:
all_chunks = positional_deletes[0]
Expand All @@ -1164,6 +1173,178 @@ def _combine_positional_deletes(positional_deletes: list[pa.ChunkedArray], start
return pc.subtract(result, pa.scalar(start_index))


def _equality_delete_key_names(
equality_ids: list[int],
data_schema: Schema,
delete_table: pa.Table,
table_schema: Schema,
downcast_ns_timestamp_to_us: bool,
format_version: TableVersion,
) -> tuple[list[str], list[str]]:
delete_schema = pyarrow_to_schema(
delete_table.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, format_version=format_version
)
data_names: list[str] = []
for field_id in equality_ids:
try:
data_name = data_schema.find_field(field_id).name
except ValueError:
data_name = table_schema.find_field(field_id).name
data_names.append(data_name)

delete_names: list[str] = []
for field_id, data_name in zip(equality_ids, data_names, strict=True):
try:
delete_name = delete_schema.find_field(field_id).name
except ValueError:
delete_name = table_schema.find_field(field_id).name
if delete_name not in delete_table.column_names and data_name in delete_table.column_names:
delete_name = data_name
delete_names.append(delete_name)

return data_names, delete_names


def _integer_type_bounds(arrow_type: pa.DataType) -> tuple[int, int]:
bit_width = arrow_type.bit_width
if pa.types.is_unsigned_integer(arrow_type):
return 0, (2**bit_width) - 1
return -(2 ** (bit_width - 1)), (2 ** (bit_width - 1)) - 1


def _cast_equality_delete_key_column(column: pa.ChunkedArray, data_type: pa.DataType) -> pa.ChunkedArray:
if column.type == data_type:
return column

try:
return column.cast(data_type)
except pa.lib.ArrowInvalid:
if not (pa.types.is_integer(column.type) and pa.types.is_integer(data_type)):
raise

source_min, source_max = _integer_type_bounds(column.type)
target_min, target_max = _integer_type_bounds(data_type)
if target_min > source_max or target_max < source_min:
return pa.chunked_array([pa.nulls(len(chunk), type=data_type) for chunk in column.chunks])

in_range = pc.is_valid(column)
if target_min > source_min:
in_range = pc.and_(in_range, pc.greater_equal(column, pa.scalar(target_min, type=column.type)))
if target_max < source_max:
in_range = pc.and_(in_range, pc.less_equal(column, pa.scalar(target_max, type=column.type)))

column = pc.if_else(in_range, column, pa.scalar(None, type=column.type))
return column.cast(data_type)


def _materialize_missing_equality_delete_columns(
data_table: pa.Table, equality_ids: list[int], data_names: list[str], data_schema: Schema, table_schema: Schema
) -> pa.Table:
for field_id, data_name in zip(equality_ids, data_names, strict=True):
try:
data_schema.find_field(field_id)
except ValueError:
if data_name not in data_table.column_names:
table_field = table_schema.find_field(field_id)
arrow_type = schema_to_pyarrow(table_field.field_type)
data_table = data_table.append_column(
pa.field(data_name, arrow_type), pa.nulls(data_table.num_rows, type=arrow_type)
)

return data_table


def _equality_delete_key_table(
delete_table: pa.Table, data_table: pa.Table, data_names: list[str], delete_names: list[str]
) -> pa.Table:
arrays: list[pa.ChunkedArray] = []
for data_name, delete_name in zip(data_names, delete_names, strict=True):
column = delete_table.column(delete_name)
data_type = data_table.schema.field(data_name).type
column = _cast_equality_delete_key_column(column, data_type)
arrays.append(column)

return pa.Table.from_arrays(arrays, names=data_names)


def _has_null_equality_key(key_table: pa.Table, key_names: list[str]) -> pa.ChunkedArray:
null_key_mask = pc.is_null(key_table.column(key_names[0]))
for key_name in key_names[1:]:
null_key_mask = pc.or_(null_key_mask, pc.is_null(key_table.column(key_name)))
return null_key_mask


def _apply_null_equality_delete_rows(data_table: pa.Table, delete_key_table: pa.Table, key_names: list[str]) -> pa.Table:
if data_table.num_rows == 0 or delete_key_table.num_rows == 0:
return data_table

data_table = data_table.combine_chunks()
delete_key_table = delete_key_table.combine_chunks()
match = pa.array([False] * data_table.num_rows)

for row_idx in range(delete_key_table.num_rows):
row_match = pa.array([True] * data_table.num_rows)
for key_name in key_names:
delete_value = delete_key_table.column(key_name)[row_idx]
data_column = data_table.column(key_name)
column_match = (
pc.fill_null(pc.equal(data_column, delete_value), False) if delete_value.is_valid else pc.is_null(data_column)
)
row_match = pc.and_(row_match, column_match)
match = pc.or_(match, row_match)

return data_table.filter(pc.invert(match))


def _apply_equality_deletes(
batch: pa.RecordBatch,
file_project_schema: Schema,
table_schema: Schema,
equality_deletes: list[tuple[list[int], pa.Table]] | None,
downcast_ns_timestamp_to_us: bool,
format_version: TableVersion,
) -> pa.RecordBatch:
if not equality_deletes or batch.num_rows == 0:
return batch

data_table = pa.Table.from_batches([batch])
for equality_ids, delete_table in equality_deletes:
if not equality_ids:
raise ValueError("Equality delete file is missing required equality_ids")
if delete_table.num_rows == 0:
continue

data_names, delete_names = _equality_delete_key_names(
equality_ids,
file_project_schema,
delete_table,
table_schema,
downcast_ns_timestamp_to_us,
format_version,
)
data_table = _materialize_missing_equality_delete_columns(
data_table, equality_ids, data_names, file_project_schema, table_schema
)
delete_key_table = _equality_delete_key_table(delete_table, data_table, data_names, delete_names)
null_key_mask = _has_null_equality_key(delete_key_table, data_names)
non_null_delete_keys = delete_key_table.filter(pc.invert(null_key_mask))
null_delete_keys = delete_key_table.filter(null_key_mask)

if non_null_delete_keys.num_rows > 0:
data_table = data_table.join(non_null_delete_keys, keys=data_names, join_type="left anti")

# PyArrow's anti-join uses SQL null semantics. Iceberg equality deletes use
# IS NOT DISTINCT FROM, so null-key delete rows need an explicit null-aware pass.
# A fully vectorized null-aware anti-join is a production follow-up.
if null_delete_keys.num_rows > 0:
data_table = _apply_null_equality_delete_rows(data_table, null_delete_keys, data_names)

if data_table.num_rows == 0:
return batch.slice(0, 0)

return data_table.combine_chunks().to_batches()[0]


def pyarrow_to_schema(
schema: pa.Schema,
name_mapping: NameMapping | None = None,
Expand Down Expand Up @@ -1621,6 +1802,7 @@ def _task_to_record_batches(
projected_field_ids: set[int],
positional_deletes: list[ChunkedArray] | None,
case_sensitive: bool,
equality_deletes: list[tuple[list[int], pa.Table]] | None = None,
name_mapping: NameMapping | None = None,
partition_spec: PartitionSpec | None = None,
format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION,
Expand Down Expand Up @@ -1691,6 +1873,15 @@ def _task_to_record_batches(
else:
current_batch = table.combine_chunks().to_batches()[0]

current_batch = _apply_equality_deletes(
current_batch,
file_project_schema,
table_schema,
equality_deletes,
downcast_ns_timestamp_to_us,
format_version,
)

# skip empty batches
if current_batch.num_rows == 0:
continue
Expand All @@ -1705,14 +1896,23 @@ def _task_to_record_batches(
)


def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]:
def _read_all_delete_files(
io: FileIO, tasks: Iterable[FileScanTask]
) -> tuple[dict[str, list[ChunkedArray]], dict[str, pa.Table]]:
tasks = list(tasks)
deletes_per_file: dict[str, list[ChunkedArray]] = {}
unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
if len(unique_deletes) > 0:
equality_deletes_per_file: dict[str, pa.Table] = {}
unique_positional_deletes = {
delete_file
for task in tasks
for delete_file in task.delete_files
if delete_file.content == DataFileContent.POSITION_DELETES
}
if len(unique_positional_deletes) > 0:
executor = ExecutorFactory.get_or_create()
deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map(
lambda args: _read_deletes(*args),
[(io, delete_file) for delete_file in unique_deletes],
[(io, delete_file) for delete_file in unique_positional_deletes],
)
for delete in deletes_per_files:
for file, arr in delete.items():
Expand All @@ -1721,7 +1921,21 @@ def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[st
else:
deletes_per_file[file] = [arr]

return deletes_per_file
unique_equality_deletes = {
delete_file
for task in tasks
for delete_file in task.delete_files
if delete_file.content == DataFileContent.EQUALITY_DELETES
}
if len(unique_equality_deletes) > 0:
executor = ExecutorFactory.get_or_create()
equality_delete_files: Iterator[tuple[str, pa.Table]] = executor.map(
lambda args: (args[1].file_path, _read_equality_delete(*args)),
[(io, delete_file) for delete_file in unique_equality_deletes],
)
equality_deletes_per_file.update(equality_delete_files)

return deletes_per_file, equality_deletes_per_file


class ArrowScan:
Expand Down Expand Up @@ -1826,7 +2040,8 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record
ResolveError: When a required field cannot be found in the file
ValueError: When a field type in the file cannot be projected to the schema type
"""
deletes_per_file = _read_all_delete_files(self._io, tasks)
tasks = list(tasks)
deletes_per_file, equality_deletes_per_file = _read_all_delete_files(self._io, tasks)

total_row_count = 0
executor = ExecutorFactory.get_or_create()
Expand All @@ -1835,7 +2050,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
# Materialize the iterator here to ensure execution happens within the executor.
# Otherwise, the iterator would be lazily consumed later (in the main thread),
# defeating the purpose of using executor.map.
return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file, equality_deletes_per_file))

limit_reached = False
for batches in executor.map(batches_for_task, tasks):
Expand All @@ -1855,21 +2070,34 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
break

def _record_batches_from_scan_tasks_and_deletes(
self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]]
self,
tasks: Iterable[FileScanTask],
deletes_per_file: dict[str, list[ChunkedArray]],
equality_deletes_per_file: dict[str, pa.Table],
) -> Iterator[pa.RecordBatch]:
total_row_count = 0
for task in tasks:
if self._limit is not None and total_row_count >= self._limit:
break
equality_deletes = [
(delete_file.equality_ids or [], equality_deletes_per_file[delete_file.file_path])
for delete_file in task.delete_files
if delete_file.content == DataFileContent.EQUALITY_DELETES
]
projected_field_ids = set(self._projected_field_ids)
for equality_ids, _ in equality_deletes:
projected_field_ids.update(equality_ids)

batches = _task_to_record_batches(
self._io,
task,
self._bound_row_filter,
self._projected_schema,
self._table_metadata.schema(),
self._projected_field_ids,
projected_field_ids,
deletes_per_file.get(task.file.file_path),
self._case_sensitive,
equality_deletes,
self._table_metadata.name_mapping(),
self._table_metadata.specs().get(task.file.spec_id),
self._table_metadata.format_version,
Expand Down
19 changes: 7 additions & 12 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2009,21 +2009,13 @@ def from_rest_response(

Returns:
A FileScanTask with the converted data and delete files.

Raises:
NotImplementedError: If equality delete files are encountered.
"""
from pyiceberg.catalog.rest.scan_planning import RESTEqualityDeleteFile

data_file = _rest_file_to_data_file(rest_task.data_file)

resolved_deletes: set[DataFile] = set()
if rest_task.delete_file_references:
for idx in rest_task.delete_file_references:
delete_file = delete_files[idx]
if isinstance(delete_file, RESTEqualityDeleteFile):
raise NotImplementedError(f"PyIceberg does not yet support equality deletes: {delete_file.file_path}")
resolved_deletes.add(_rest_file_to_data_file(delete_file))
resolved_deletes.add(_rest_file_to_data_file(delete_files[idx]))

return FileScanTask(
data_file=data_file,
Expand All @@ -2034,7 +2026,7 @@ def from_rest_response(

def _rest_file_to_data_file(rest_file: RESTContentFile) -> DataFile:
"""Convert a REST content file to a manifest DataFile."""
from pyiceberg.catalog.rest.scan_planning import RESTDataFile
from pyiceberg.catalog.rest.scan_planning import RESTDataFile, RESTEqualityDeleteFile

if isinstance(rest_file, RESTDataFile):
column_sizes = rest_file.column_sizes.to_dict() if rest_file.column_sizes else None
Expand All @@ -2047,6 +2039,8 @@ def _rest_file_to_data_file(rest_file: RESTContentFile) -> DataFile:
null_value_counts = None
nan_value_counts = None

equality_ids = rest_file.equality_ids if isinstance(rest_file, RESTEqualityDeleteFile) else None

data_file = DataFile.from_args(
content=DataFileContent.from_rest_type(rest_file.content),
file_path=rest_file.file_path,
Expand All @@ -2058,6 +2052,7 @@ def _rest_file_to_data_file(rest_file: RESTContentFile) -> DataFile:
value_counts=value_counts,
null_value_counts=null_value_counts,
nan_value_counts=nan_value_counts,
equality_ids=equality_ids,
split_offsets=rest_file.split_offsets,
sort_order_id=rest_file.sort_order_id,
)
Expand Down Expand Up @@ -2335,7 +2330,7 @@ def plan_files(self, manifests: Iterable[ManifestFile]) -> Iterable[FileScanTask
List of FileScanTasks that contain both data and delete files.
"""
data_entries: list[ManifestEntry] = []
delete_index = DeleteFileIndex()
delete_index = DeleteFileIndex(self.table_metadata.schema())

residual_evaluators: dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator)

Expand All @@ -2346,7 +2341,7 @@ def plan_files(self, manifests: Iterable[ManifestFile]) -> Iterable[FileScanTask
elif data_file.content == DataFileContent.POSITION_DELETES:
delete_index.add_delete_file(manifest_entry, partition_key=data_file.partition)
elif data_file.content == DataFileContent.EQUALITY_DELETES:
raise ValueError("PyIceberg does not yet support equality deletes: https://github.com/apache/iceberg/issues/6568")
delete_index.add_delete_file(manifest_entry, partition_key=data_file.partition)
else:
raise ValueError(f"Unknown DataFileContent ({data_file.content}): {manifest_entry}")

Expand Down
Loading