Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 17 additions & 49 deletions python/pyspark/sql/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def enforce_schema(
cls,
batch: "pa.RecordBatch",
arrow_schema: "pa.Schema",
*,
arrow_cast: bool = True,
safecheck: bool = True,
) -> "pa.RecordBatch":
"""
Expand All @@ -124,6 +126,9 @@ def enforce_schema(
arrow_schema : pa.Schema
Target Arrow schema. Callers should pre-compute this once via
to_arrow_schema() to avoid repeated conversion.
arrow_cast : bool, default True
If True, cast mismatched types to the target type.
If False, raise an error on type mismatch instead of casting.
safecheck : bool, default True
If True, use safe casting (fails on overflow/truncation).

Expand All @@ -149,8 +154,19 @@ def enforce_schema(

coerced_arrays = []
for i, field in enumerate(arrow_schema):
arr = batch.column(i) if use_index else batch.column(field.name)
try:
arr = batch.column(i) if use_index else batch.column(field.name)
except KeyError:
raise PySparkTypeError(
f"Result column '{field.name}' does not exist in the output. "
f"Expected schema: {arrow_schema}, got: {batch.schema}."
)
if arr.type != field.type:
if not arrow_cast:
raise PySparkTypeError(
f"Result type of column '{field.name}' does not match "
f"the expected type. Expected: {field.type}, got: {arr.type}."
)
try:
arr = arr.cast(target_type=field.type, safe=safecheck)
except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
Expand Down Expand Up @@ -221,54 +237,6 @@ def to_pandas(
]


# TODO: elevate to ArrowBatchTransformer and operate on full RecordBatch schema
# instead of per-column coercion.
def coerce_arrow_array(
    arr: "pa.Array",
    target_type: "pa.DataType",
    *,
    safecheck: bool = True,
    arrow_cast: bool = True,
) -> "pa.Array":
    """
    Coerce an Arrow Array to a target type, with optional type-mismatch enforcement.

    With ``arrow_cast=True`` (the default) a mismatched array is cast to
    ``target_type``; with ``arrow_cast=False`` a mismatch raises instead.

    Parameters
    ----------
    arr : pa.Array
        Input Arrow array
    target_type : pa.DataType
        Target Arrow type
    safecheck : bool
        Whether to use safe casting (default True)
    arrow_cast : bool
        Whether to allow casting when types don't match (default True)

    Returns
    -------
    pa.Array
    """
    from pyspark.errors import PySparkTypeError

    # Fast path: nothing to do when the type already matches.
    if arr.type == target_type:
        return arr

    if not arrow_cast:
        raise PySparkTypeError(
            "Arrow UDFs require the return type to match the expected Arrow type. "
            f"Expected: {target_type}, but got: {arr.type}."
        )

    # With safe=True the cast fails on overflow or other lossy conversions.
    # RecordBatch.cast(...) isn't used as minimum PyArrow version
    # required for RecordBatch.cast(...) is v16.0
    return arr.cast(target_type=target_type, safe=safecheck)


class PandasToArrowConversion:
"""
Conversion utilities from pandas data to Arrow.
Expand Down
54 changes: 11 additions & 43 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple

import pyspark
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.serializers import (
Serializer,
read_int,
Expand All @@ -37,7 +37,6 @@
ArrowTableToRowsConversion,
ArrowBatchTransformer,
PandasToArrowConversion,
coerce_arrow_array,
)
from pyspark.sql.pandas.types import (
from_arrow_schema,
Expand Down Expand Up @@ -299,40 +298,9 @@ def apply_type_coercion():
assert isinstance(arrow_return_type, pa.StructType), (
f"Expected pa.StructType, got {type(arrow_return_type)}"
)

# Handle empty struct case specially
if batch.num_columns == 0:
coerced_batch = batch # skip type coercion
else:
expected_field_names = [field.name for field in arrow_return_type]
actual_field_names = batch.schema.names

if expected_field_names != actual_field_names:
raise PySparkTypeError(
"Target schema's field names are not matching the record batch's "
"field names. "
f"Expected: {expected_field_names}, but got: {actual_field_names}."
)

coerced_arrays = []
for i, field in enumerate(arrow_return_type):
try:
coerced_arrays.append(
coerce_arrow_array(
batch.column(i),
field.type,
safecheck=True,
)
)
except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
raise PySparkTypeError(
f"Result type of column '{field.name}' does not "
f"match the expected type. Expected: {field.type}, "
f"got: {batch.column(i).type}."
) from e
coerced_batch = pa.RecordBatch.from_arrays(
coerced_arrays, names=expected_field_names
)
coerced_batch = ArrowBatchTransformer.enforce_schema(
batch, pa.schema(arrow_return_type), safecheck=True
)
yield coerced_batch, arrow_return_type

return super().dump_stream(apply_type_coercion(), stream)
Expand Down Expand Up @@ -615,13 +583,13 @@ def dump_stream(self, iterator, stream):
def create_batch(
arr_tuples: List[Tuple["pa.Array", "pa.DataType"]],
) -> "pa.RecordBatch":
arrs = [
coerce_arrow_array(
arr, arrow_type, safecheck=self._safecheck, arrow_cast=self._arrow_cast
)
for arr, arrow_type in arr_tuples
]
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
names = ["_%d" % i for i in range(len(arr_tuples))]
arrs = [arr for arr, _ in arr_tuples]
batch = pa.RecordBatch.from_arrays(arrs, names)
target_schema = pa.schema([pa.field(n, t) for n, (_, t) in zip(names, arr_tuples)])
return ArrowBatchTransformer.enforce_schema(
batch, target_schema, safecheck=self._safecheck, arrow_cast=self._arrow_cast
)

def normalize(packed):
if len(packed) == 2 and isinstance(packed[1], pa.DataType):
Expand Down
8 changes: 5 additions & 3 deletions python/pyspark/sql/tests/arrow/test_arrow_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def eval(self) -> Iterator["pa.Table"]:

with self.assertRaisesRegex(
PythonException,
"Target schema's field names are not matching the record batch's field names",
r"(?s)Result column 'x' does not exist in the output\. "
r"Expected schema: x: int32\ny: string, "
r"got: wrong_col: int32\nanother_wrong_col: double\.",
):
result_df = MismatchedSchemaUDTF()
result_df.collect()
Expand Down Expand Up @@ -373,8 +375,8 @@ def eval(self) -> Iterator["pa.Table"]:
# Should fail with Arrow cast exception since string cannot be cast to int
with self.assertRaisesRegex(
PythonException,
"PySparkTypeError: Result type of column 'id' does not match the expected type. "
"Expected: int32, got: string.",
"Result type of column 'id' does not match "
"the expected type. Expected: int32, got: string.",
):
result_df = StringToIntUDTF()
result_df.collect()
Expand Down
49 changes: 49 additions & 0 deletions python/pyspark/sql/tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,55 @@ def test_wrap_struct_empty_batch(self):
self.assertEqual(wrapped.num_rows, 0)
self.assertEqual(wrapped.num_columns, 1)

def test_enforce_schema_nested_cast(self):
    """Nested struct and list types are cast recursively by Arrow."""
    import pyarrow as pa

    # Source batch: struct of (int32, float32) plus a list<int32> column.
    src_struct = pa.struct([("a", pa.int32()), ("b", pa.float32())])
    struct_col = pa.array([{"a": 1, "b": 2.0}], type=src_struct)
    list_col = pa.array([[1, 2]], type=pa.list_(pa.int32()))
    record_batch = pa.RecordBatch.from_arrays([struct_col, list_col], names=["s", "l"])

    # Target widens every leaf type to 64 bits.
    widened_struct = pa.struct([("a", pa.int64()), ("b", pa.float64())])
    expected_schema = pa.schema([("s", widened_struct), ("l", pa.list_(pa.int64()))])

    coerced = ArrowBatchTransformer.enforce_schema(record_batch, expected_schema)
    self.assertEqual(coerced.schema, expected_schema)

def test_enforce_schema_arrow_cast_false(self):
    """arrow_cast=False raises on type mismatch instead of casting."""
    import pyarrow as pa

    # int32 data against an int64 target: a mismatch that casting would normally fix.
    rb = pa.RecordBatch.from_arrays([pa.array([1], type=pa.int32())], names=["x"])
    wanted = pa.schema([("x", pa.int64())])

    with self.assertRaises(PySparkTypeError):
        ArrowBatchTransformer.enforce_schema(rb, wanted, arrow_cast=False)

def test_enforce_schema_safecheck(self):
    """safecheck=True rejects overflow; safecheck=False allows it."""
    import pyarrow as pa

    # 999 does not fit in int8, so the narrowing cast overflows.
    rb = pa.RecordBatch.from_arrays([pa.array([999], type=pa.int64())], names=["x"])
    narrow = pa.schema([("x", pa.int8())])

    with self.assertRaises(PySparkTypeError):
        ArrowBatchTransformer.enforce_schema(rb, narrow, safecheck=True)

    # Unsafe casting tolerates the overflow and still produces the target schema.
    unsafe_result = ArrowBatchTransformer.enforce_schema(rb, narrow, safecheck=False)
    self.assertEqual(unsafe_result.schema, narrow)

def test_enforce_schema_missing_column(self):
    """Missing column raises PySparkTypeError."""
    import pyarrow as pa

    # Batch only has column "a"; the target asks for "missing".
    rb = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
    wanted = pa.schema([("missing", pa.int64())])

    with self.assertRaises(PySparkTypeError):
        ArrowBatchTransformer.enforce_schema(rb, wanted)


@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@unittest.skipIf(not have_pandas, pandas_requirement_message)
Expand Down
11 changes: 5 additions & 6 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
LocalDataToArrowConversion,
ArrowTableToRowsConversion,
ArrowBatchTransformer,
coerce_arrow_array,
)
from pyspark.sql.functions import SkipRestOfInputTableException
from pyspark.sql.pandas.serializers import (
Expand Down Expand Up @@ -2909,13 +2908,13 @@ def extract_args(batch: pa.RecordBatch):
# Call UDF and verify result type (iterator of pa.Array)
verified_iter = verify_result(pa.Array)(udf_func(args_iter))

# Process results: coerce type and assemble into RecordBatch
# Process results: enforce schema and assemble into RecordBatch
target_schema = pa.schema([pa.field("_0", arrow_return_type)])

def process_results():
for result in verified_iter:
result = coerce_arrow_array(
result, arrow_return_type, safecheck=True, arrow_cast=True
)
yield pa.RecordBatch.from_arrays([result], ["_0"])
batch = pa.RecordBatch.from_arrays([result], ["_0"])
yield ArrowBatchTransformer.enforce_schema(batch, target_schema, safecheck=True)

# Apply row limit check (fail-fast)
# TODO(SPARK-55579): Create Arrow-specific error class (e.g., ARROW_UDF_OUTPUT_EXCEEDS_INPUT_ROWS)
Expand Down