Skip to content

Commit e6a55c9

Browse files
committed
[SPARK-56166][PYTHON] Use ArrowBatchTransformer.enforce_schema to replace manual type coercion logic
### What changes were proposed in this pull request? Replace manual column-by-column type coercion with `ArrowBatchTransformer.enforce_schema` in three places: 1. `ArrowStreamArrowUDTFSerializer.apply_type_coercion` in serializers.py 2. `ArrowStreamArrowUDFSerializer.create_batch` in serializers.py 3. `process_results` in worker.py (scalar Arrow iter UDF path) Also: - Add `arrow_cast` parameter to `enforce_schema` for strict type matching mode - Add `KeyError` handling in `enforce_schema` for missing columns with user-friendly error - Remove now-unused `coerce_arrow_array` imports from serializers.py and worker.py ### Why are the changes needed? These three places duplicated the same coerce-and-reassemble logic that `enforce_schema` already provides. Consolidating reduces code duplication and ensures consistent error handling. ### Does this PR introduce _any_ user-facing change? Error messages for type/schema mismatches in Arrow UDTFs are slightly changed to be consistent with other Arrow UDF error messages. ### How was this patch tested? Existing tests in `test_arrow_udtf.py` and `test_arrow_udf_scalar.py`. ### Was this patch authored or co-authored using generative AI tooling? Yes.
1 parent 9c40e18 commit e6a55c9

4 files changed

Lines changed: 41 additions & 100 deletions

File tree

python/pyspark/sql/conversion.py

Lines changed: 17 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def enforce_schema(
112112
cls,
113113
batch: "pa.RecordBatch",
114114
arrow_schema: "pa.Schema",
115+
*,
116+
arrow_cast: bool = True,
115117
safecheck: bool = True,
116118
) -> "pa.RecordBatch":
117119
"""
@@ -124,6 +126,9 @@ def enforce_schema(
124126
arrow_schema : pa.Schema
125127
Target Arrow schema. Callers should pre-compute this once via
126128
to_arrow_schema() to avoid repeated conversion.
129+
arrow_cast : bool, default True
130+
If True, cast mismatched types to the target type.
131+
If False, raise an error on type mismatch instead of casting.
127132
safecheck : bool, default True
128133
If True, use safe casting (fails on overflow/truncation).
129134
@@ -149,8 +154,19 @@ def enforce_schema(
149154

150155
coerced_arrays = []
151156
for i, field in enumerate(arrow_schema):
152-
arr = batch.column(i) if use_index else batch.column(field.name)
157+
try:
158+
arr = batch.column(i) if use_index else batch.column(field.name)
159+
except KeyError:
160+
raise PySparkTypeError(
161+
f"Result column '{field.name}' does not exist in the output. "
162+
f"Expected schema: {arrow_schema}, got: {batch.schema}."
163+
)
153164
if arr.type != field.type:
165+
if not arrow_cast:
166+
raise PySparkTypeError(
167+
f"Result type of column '{field.name}' does not match "
168+
f"the expected type. Expected: {field.type}, got: {arr.type}."
169+
)
154170
try:
155171
arr = arr.cast(target_type=field.type, safe=safecheck)
156172
except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
@@ -221,54 +237,6 @@ def to_pandas(
221237
]
222238

223239

224-
# TODO: elevate to ArrowBatchTransformer and operate on full RecordBatch schema
225-
# instead of per-column coercion.
226-
def coerce_arrow_array(
227-
arr: "pa.Array",
228-
target_type: "pa.DataType",
229-
*,
230-
safecheck: bool = True,
231-
arrow_cast: bool = True,
232-
) -> "pa.Array":
233-
"""
234-
Coerce an Arrow Array to a target type, with optional type-mismatch enforcement.
235-
236-
When ``arrow_cast`` is True (default), mismatched types are cast to the
237-
target type. When False, a type mismatch raises an error instead.
238-
239-
Parameters
240-
----------
241-
arr : pa.Array
242-
Input Arrow array
243-
target_type : pa.DataType
244-
Target Arrow type
245-
safecheck : bool
246-
Whether to use safe casting (default True)
247-
arrow_cast : bool
248-
Whether to allow casting when types don't match (default True)
249-
250-
Returns
251-
-------
252-
pa.Array
253-
"""
254-
from pyspark.errors import PySparkTypeError
255-
256-
if arr.type == target_type:
257-
return arr
258-
259-
if not arrow_cast:
260-
raise PySparkTypeError(
261-
"Arrow UDFs require the return type to match the expected Arrow type. "
262-
f"Expected: {target_type}, but got: {arr.type}."
263-
)
264-
265-
# when safe is True, the cast will fail if there's an overflow or other
266-
# unsafe conversion.
267-
# RecordBatch.cast(...) isn't used as minimum PyArrow version
268-
# required for RecordBatch.cast(...) is v16.0
269-
return arr.cast(target_type=target_type, safe=safecheck)
270-
271-
272240
class PandasToArrowConversion:
273241
"""
274242
Conversion utilities from pandas data to Arrow.

python/pyspark/sql/pandas/serializers.py

Lines changed: 13 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
2424

2525
import pyspark
26-
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
26+
from pyspark.errors import PySparkRuntimeError, PySparkValueError
2727
from pyspark.serializers import (
2828
Serializer,
2929
read_int,
@@ -37,7 +37,6 @@
3737
ArrowTableToRowsConversion,
3838
ArrowBatchTransformer,
3939
PandasToArrowConversion,
40-
coerce_arrow_array,
4140
)
4241
from pyspark.sql.pandas.types import (
4342
from_arrow_schema,
@@ -299,40 +298,9 @@ def apply_type_coercion():
299298
assert isinstance(arrow_return_type, pa.StructType), (
300299
f"Expected pa.StructType, got {type(arrow_return_type)}"
301300
)
302-
303-
# Handle empty struct case specially
304-
if batch.num_columns == 0:
305-
coerced_batch = batch # skip type coercion
306-
else:
307-
expected_field_names = [field.name for field in arrow_return_type]
308-
actual_field_names = batch.schema.names
309-
310-
if expected_field_names != actual_field_names:
311-
raise PySparkTypeError(
312-
"Target schema's field names are not matching the record batch's "
313-
"field names. "
314-
f"Expected: {expected_field_names}, but got: {actual_field_names}."
315-
)
316-
317-
coerced_arrays = []
318-
for i, field in enumerate(arrow_return_type):
319-
try:
320-
coerced_arrays.append(
321-
coerce_arrow_array(
322-
batch.column(i),
323-
field.type,
324-
safecheck=True,
325-
)
326-
)
327-
except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
328-
raise PySparkTypeError(
329-
f"Result type of column '{field.name}' does not "
330-
f"match the expected type. Expected: {field.type}, "
331-
f"got: {batch.column(i).type}."
332-
) from e
333-
coerced_batch = pa.RecordBatch.from_arrays(
334-
coerced_arrays, names=expected_field_names
335-
)
301+
coerced_batch = ArrowBatchTransformer.enforce_schema(
302+
batch, pa.schema(arrow_return_type), safecheck=True
303+
)
336304
yield coerced_batch, arrow_return_type
337305

338306
return super().dump_stream(apply_type_coercion(), stream)
@@ -615,13 +583,15 @@ def dump_stream(self, iterator, stream):
615583
def create_batch(
616584
arr_tuples: List[Tuple["pa.Array", "pa.DataType"]],
617585
) -> "pa.RecordBatch":
618-
arrs = [
619-
coerce_arrow_array(
620-
arr, arrow_type, safecheck=self._safecheck, arrow_cast=self._arrow_cast
621-
)
622-
for arr, arrow_type in arr_tuples
623-
]
624-
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
586+
names = ["_%d" % i for i in range(len(arr_tuples))]
587+
arrs = [arr for arr, _ in arr_tuples]
588+
batch = pa.RecordBatch.from_arrays(arrs, names)
589+
target_schema = pa.schema(
590+
[pa.field(n, t) for n, (_, t) in zip(names, arr_tuples)]
591+
)
592+
return ArrowBatchTransformer.enforce_schema(
593+
batch, target_schema, safecheck=self._safecheck, arrow_cast=self._arrow_cast
594+
)
625595

626596
def normalize(packed):
627597
if len(packed) == 2 and isinstance(packed[1], pa.DataType):

python/pyspark/sql/tests/arrow/test_arrow_udtf.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def eval(self) -> Iterator["pa.Table"]:
211211

212212
with self.assertRaisesRegex(
213213
PythonException,
214-
"Target schema's field names are not matching the record batch's field names",
214+
r"(?s)Result column 'x' does not exist in the output\. "
215+
r"Expected schema: x: int32\ny: string, "
216+
r"got: wrong_col: int32\nanother_wrong_col: double\.",
215217
):
216218
result_df = MismatchedSchemaUDTF()
217219
result_df.collect()
@@ -373,8 +375,8 @@ def eval(self) -> Iterator["pa.Table"]:
373375
# Should fail with Arrow cast exception since string cannot be cast to int
374376
with self.assertRaisesRegex(
375377
PythonException,
376-
"PySparkTypeError: Result type of column 'id' does not match the expected type. "
377-
"Expected: int32, got: string.",
378+
"Result type of column 'id' does not match "
379+
"the expected type. Expected: int32, got: string.",
378380
):
379381
result_df = StringToIntUDTF()
380382
result_df.collect()

python/pyspark/worker.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
LocalDataToArrowConversion,
5252
ArrowTableToRowsConversion,
5353
ArrowBatchTransformer,
54-
coerce_arrow_array,
5554
)
5655
from pyspark.sql.functions import SkipRestOfInputTableException
5756
from pyspark.sql.pandas.serializers import (
@@ -2909,13 +2908,15 @@ def extract_args(batch: pa.RecordBatch):
29092908
# Call UDF and verify result type (iterator of pa.Array)
29102909
verified_iter = verify_result(pa.Array)(udf_func(args_iter))
29112910

2912-
# Process results: coerce type and assemble into RecordBatch
2911+
# Process results: enforce schema and assemble into RecordBatch
2912+
target_schema = pa.schema([pa.field("_0", arrow_return_type)])
2913+
29132914
def process_results():
29142915
for result in verified_iter:
2915-
result = coerce_arrow_array(
2916-
result, arrow_return_type, safecheck=True, arrow_cast=True
2916+
batch = pa.RecordBatch.from_arrays([result], ["_0"])
2917+
yield ArrowBatchTransformer.enforce_schema(
2918+
batch, target_schema, safecheck=True
29172919
)
2918-
yield pa.RecordBatch.from_arrays([result], ["_0"])
29192920

29202921
# Apply row limit check (fail-fast)
29212922
# TODO(SPARK-55579): Create Arrow-specific error class (e.g., ARROW_UDF_OUTPUT_EXCEEDS_INPUT_ROWS)

0 commit comments

Comments
 (0)