
Commit c537147

[SPARK-56166][PYTHON] Use ArrowBatchTransformer.enforce_schema to replace manual type coercion logic
### What changes were proposed in this pull request?

Replace manual column-by-column type coercion with `ArrowBatchTransformer.enforce_schema` in three places:

1. `ArrowStreamArrowUDTFSerializer.apply_type_coercion` in serializers.py
2. `ArrowStreamArrowUDFSerializer.create_batch` in serializers.py
3. `process_results` in worker.py (scalar Arrow iter UDF path)

Also:

- Add `arrow_cast` parameter to `enforce_schema` for a strict type-matching mode
- Add `KeyError` handling in `enforce_schema` for missing columns, with a user-friendly error
- Remove now-unused `coerce_arrow_array` imports from serializers.py and worker.py

### Why are the changes needed?

These three places duplicated the same coerce-and-reassemble logic that `enforce_schema` already provides. Consolidating them reduces code duplication and ensures consistent error handling.

### Does this PR introduce _any_ user-facing change?

Error messages for type/schema mismatches in Arrow UDTFs change slightly, to be consistent with other Arrow UDF error messages.

### How was this patch tested?

Existing tests in `test_arrow_udtf.py` and `test_arrow_udf_scalar.py`.

### Was this patch authored or co-authored using generative AI tooling?

Yes.
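For orientation, here is a minimal standalone sketch of the coerce-and-reassemble pattern this commit centralizes. It uses plain pyarrow only; `enforce_schema_sketch`, its error type, and the sample data are illustrative stand-ins, not the actual pyspark internals.

```python
# A minimal sketch of the coerce-and-reassemble pattern that
# ArrowBatchTransformer.enforce_schema centralizes. Plain pyarrow;
# enforce_schema_sketch is an illustrative stand-in, not pyspark code.
import pyarrow as pa


def enforce_schema_sketch(
    batch: pa.RecordBatch,
    schema: pa.Schema,
    safecheck: bool = True,
    arrow_cast: bool = True,
) -> pa.RecordBatch:
    coerced = []
    for field in schema:
        arr = batch.column(field.name)  # raises KeyError if the column is missing
        if arr.type != field.type:
            if not arrow_cast:
                # Strict mode: fail instead of casting, mirroring arrow_cast=False.
                raise TypeError(
                    f"Column '{field.name}': expected {field.type}, got {arr.type}"
                )
            # safe=True makes overflow/truncation raise pa.ArrowInvalid.
            arr = arr.cast(field.type, safe=safecheck)
        coerced.append(arr)
    return pa.RecordBatch.from_arrays(coerced, schema=schema)


batch = pa.RecordBatch.from_arrays([pa.array([1, 2], type=pa.int64())], ["id"])
target = pa.schema([pa.field("id", pa.int32())])
print(enforce_schema_sketch(batch, target).schema)  # id: int32
```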
1 parent: 9c40e18

4 files changed: 40 additions, 52 deletions


python/pyspark/sql/conversion.py (16 additions, 1 deletion)
```diff
@@ -113,6 +113,7 @@ def enforce_schema(
         batch: "pa.RecordBatch",
         arrow_schema: "pa.Schema",
         safecheck: bool = True,
+        arrow_cast: bool = True,
     ) -> "pa.RecordBatch":
         """
         Enforce target schema on a RecordBatch by reordering columns and coercing types.
@@ -126,6 +127,9 @@ def enforce_schema(
             to_arrow_schema() to avoid repeated conversion.
         safecheck : bool, default True
             If True, use safe casting (fails on overflow/truncation).
+        arrow_cast : bool, default True
+            If True, cast mismatched types to the target type.
+            If False, raise an error on type mismatch instead of casting.
 
         Returns
         -------
@@ -149,8 +153,19 @@ def enforce_schema(
 
         coerced_arrays = []
         for i, field in enumerate(arrow_schema):
-            arr = batch.column(i) if use_index else batch.column(field.name)
+            try:
+                arr = batch.column(i) if use_index else batch.column(field.name)
+            except KeyError:
+                raise PySparkTypeError(
+                    f"Result column '{field.name}' does not exist in the output. "
+                    f"Expected schema: {arrow_schema}, got: {batch.schema}."
+                )
             if arr.type != field.type:
+                if not arrow_cast:
+                    raise PySparkTypeError(
+                        f"Result type of column '{field.name}' does not match "
+                        f"the expected type. Expected: {field.type}, got: {arr.type}."
+                    )
                 try:
                     arr = arr.cast(target_type=field.type, safe=safecheck)
                 except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
```
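A hedged usage sketch of the two failure modes added here. `ArrowBatchTransformer` is an internal helper; importing it from `pyspark.sql.conversion` follows this diff's file path but is an assumption, and whether a column is looked up by name or by index depends on the `use_index` logic not shown in this hunk.

```python
# Illustrative only: exercising the two failure modes shown in this hunk.
# Assumption: ArrowBatchTransformer is importable from pyspark.sql.conversion
# (the file this diff touches); it is an internal helper, not public API.
import pyarrow as pa
from pyspark.errors import PySparkTypeError
from pyspark.sql.conversion import ArrowBatchTransformer

target = pa.schema([pa.field("x", pa.int32())])

# 1. Missing column: the new KeyError handler produces a friendly message.
bad_names = pa.RecordBatch.from_arrays([pa.array([1], type=pa.int32())], ["wrong_col"])
try:
    ArrowBatchTransformer.enforce_schema(bad_names, target)
except PySparkTypeError as e:
    print(e)  # Result column 'x' does not exist in the output. ...

# 2. Type mismatch with arrow_cast=False: fail fast instead of casting.
bad_types = pa.RecordBatch.from_arrays([pa.array(["1"])], ["x"])
try:
    ArrowBatchTransformer.enforce_schema(bad_types, target, arrow_cast=False)
except PySparkTypeError as e:
    print(e)  # Result type of column 'x' does not match the expected type. ...
```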

python/pyspark/sql/pandas/serializers.py (13 additions, 43 deletions)
```diff
@@ -23,7 +23,7 @@
 from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
 
 import pyspark
-from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
 from pyspark.serializers import (
     Serializer,
     read_int,
@@ -37,7 +37,6 @@
     ArrowTableToRowsConversion,
     ArrowBatchTransformer,
     PandasToArrowConversion,
-    coerce_arrow_array,
 )
 from pyspark.sql.pandas.types import (
     from_arrow_schema,
@@ -299,40 +298,9 @@ def apply_type_coercion():
                 assert isinstance(arrow_return_type, pa.StructType), (
                     f"Expected pa.StructType, got {type(arrow_return_type)}"
                 )
-
-                # Handle empty struct case specially
-                if batch.num_columns == 0:
-                    coerced_batch = batch  # skip type coercion
-                else:
-                    expected_field_names = [field.name for field in arrow_return_type]
-                    actual_field_names = batch.schema.names
-
-                    if expected_field_names != actual_field_names:
-                        raise PySparkTypeError(
-                            "Target schema's field names are not matching the record batch's "
-                            "field names. "
-                            f"Expected: {expected_field_names}, but got: {actual_field_names}."
-                        )
-
-                    coerced_arrays = []
-                    for i, field in enumerate(arrow_return_type):
-                        try:
-                            coerced_arrays.append(
-                                coerce_arrow_array(
-                                    batch.column(i),
-                                    field.type,
-                                    safecheck=True,
-                                )
-                            )
-                        except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
-                            raise PySparkTypeError(
-                                f"Result type of column '{field.name}' does not "
-                                f"match the expected type. Expected: {field.type}, "
-                                f"got: {batch.column(i).type}."
-                            ) from e
-                    coerced_batch = pa.RecordBatch.from_arrays(
-                        coerced_arrays, names=expected_field_names
-                    )
+                coerced_batch = ArrowBatchTransformer.enforce_schema(
+                    batch, pa.schema(arrow_return_type), safecheck=True
+                )
                 yield coerced_batch, arrow_return_type
 
         return super().dump_stream(apply_type_coercion(), stream)
@@ -615,13 +583,15 @@ def dump_stream(self, iterator, stream):
         def create_batch(
             arr_tuples: List[Tuple["pa.Array", "pa.DataType"]],
         ) -> "pa.RecordBatch":
-            arrs = [
-                coerce_arrow_array(
-                    arr, arrow_type, safecheck=self._safecheck, arrow_cast=self._arrow_cast
-                )
-                for arr, arrow_type in arr_tuples
-            ]
-            return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
+            names = ["_%d" % i for i in range(len(arr_tuples))]
+            arrs = [arr for arr, _ in arr_tuples]
+            batch = pa.RecordBatch.from_arrays(arrs, names)
+            target_schema = pa.schema(
+                [pa.field(n, t) for n, (_, t) in zip(names, arr_tuples)]
+            )
+            return ArrowBatchTransformer.enforce_schema(
+                batch, target_schema, safecheck=self._safecheck, arrow_cast=self._arrow_cast
+            )
 
         def normalize(packed):
             if len(packed) == 2 and isinstance(packed[1], pa.DataType):
```
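The `create_batch` rewrite deserves a closer look: it names columns positionally (`_0`, `_1`, ...) and builds the target schema from the `(array, type)` tuples before delegating the coercion. A standalone sketch of that flow, with a plain pyarrow cast standing in for `enforce_schema` and the serializer's `self._safecheck`/`self._arrow_cast` omitted:

```python
# Standalone sketch of the new create_batch flow: positional column names,
# a target schema built from the (array, type) tuples, then coercion.
# Plain pyarrow; Array.cast stands in for ArrowBatchTransformer.enforce_schema.
from typing import List, Tuple

import pyarrow as pa


def create_batch_sketch(arr_tuples: List[Tuple[pa.Array, pa.DataType]]) -> pa.RecordBatch:
    names = ["_%d" % i for i in range(len(arr_tuples))]
    arrs = [arr for arr, _ in arr_tuples]
    batch = pa.RecordBatch.from_arrays(arrs, names)
    target_schema = pa.schema([pa.field(n, t) for n, (_, t) in zip(names, arr_tuples)])
    # Real code: ArrowBatchTransformer.enforce_schema(batch, target_schema, ...)
    coerced = [col.cast(field.type) for col, field in zip(batch.columns, target_schema)]
    return pa.RecordBatch.from_arrays(coerced, schema=target_schema)


out = create_batch_sketch([(pa.array([1, 2]), pa.int32()), (pa.array(["a", "b"]), pa.string())])
print(out.schema)  # _0: int32, _1: string
```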

python/pyspark/sql/tests/arrow/test_arrow_udtf.py (5 additions, 3 deletions)
```diff
@@ -211,7 +211,9 @@ def eval(self) -> Iterator["pa.Table"]:
 
         with self.assertRaisesRegex(
             PythonException,
-            "Target schema's field names are not matching the record batch's field names",
+            r"(?s)Result column 'x' does not exist in the output\. "
+            r"Expected schema: x: int32\ny: string, "
+            r"got: wrong_col: int32\nanother_wrong_col: double\.",
         ):
             result_df = MismatchedSchemaUDTF()
             result_df.collect()
@@ -373,8 +375,8 @@ def eval(self) -> Iterator["pa.Table"]:
         # Should fail with Arrow cast exception since string cannot be cast to int
         with self.assertRaisesRegex(
             PythonException,
-            "PySparkTypeError: Result type of column 'id' does not match the expected type. "
-            "Expected: int32, got: string.",
+            "Result type of column 'id' does not match "
+            "the expected type. Expected: int32, got: string.",
         ):
             result_df = StringToIntUDTF()
             result_df.collect()
```

python/pyspark/worker.py (6 additions, 5 deletions)
```diff
@@ -51,7 +51,6 @@
     LocalDataToArrowConversion,
     ArrowTableToRowsConversion,
     ArrowBatchTransformer,
-    coerce_arrow_array,
 )
 from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
@@ -2909,13 +2908,15 @@ def extract_args(batch: pa.RecordBatch):
     # Call UDF and verify result type (iterator of pa.Array)
     verified_iter = verify_result(pa.Array)(udf_func(args_iter))
 
-    # Process results: coerce type and assemble into RecordBatch
+    # Process results: enforce schema and assemble into RecordBatch
+    target_schema = pa.schema([pa.field("_0", arrow_return_type)])
+
     def process_results():
         for result in verified_iter:
-            result = coerce_arrow_array(
-                result, arrow_return_type, safecheck=True, arrow_cast=True
+            batch = pa.RecordBatch.from_arrays([result], ["_0"])
+            yield ArrowBatchTransformer.enforce_schema(
+                batch, target_schema, safecheck=True
             )
-            yield pa.RecordBatch.from_arrays([result], ["_0"])
 
     # Apply row limit check (fail-fast)
     # TODO(SPARK-55579): Create Arrow-specific error class (e.g., ARROW_UDF_OUTPUT_EXCEEDS_INPUT_ROWS)
```
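A sketch of the scalar Arrow iter UDF result path after this change: each `pa.Array` the UDF yields is wrapped into a one-column batch named `_0` and checked against the declared return type. Plain pyarrow; the cast stands in for `enforce_schema`, and `arrow_return_type` here is a stand-in value.

```python
# Sketch of the reworked result path: wrap each yielded pa.Array into a
# one-column RecordBatch named "_0", then enforce the declared return type.
# Plain pyarrow; Array.cast stands in for ArrowBatchTransformer.enforce_schema.
import pyarrow as pa

arrow_return_type = pa.int32()  # stand-in for the UDF's declared Arrow type
target_schema = pa.schema([pa.field("_0", arrow_return_type)])


def process_results_sketch(results):
    for result in results:  # each result is a pa.Array from the UDF
        # Real code: batch = pa.RecordBatch.from_arrays([result], ["_0"])
        #            yield ArrowBatchTransformer.enforce_schema(batch, target_schema, safecheck=True)
        yield pa.RecordBatch.from_arrays(
            [result.cast(arrow_return_type, safe=True)], schema=target_schema
        )


for b in process_results_sketch([pa.array([1, 2], type=pa.int64())]):
    print(b.schema)  # _0: int32
```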
