From e6a55c9811622604e378f8d02cff66101be1d62e Mon Sep 17 00:00:00 2001 From: Yicong Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 23 Mar 2026 22:32:23 +0000 Subject: [PATCH 1/4] [SPARK-56166][PYTHON] Use ArrowBatchTransformer.enforce_schema to replace manual type coercion logic ### What changes were proposed in this pull request? Replace manual column-by-column type coercion with `ArrowBatchTransformer.enforce_schema` in three places: 1. `ArrowStreamArrowUDTFSerializer.apply_type_coercion` in serializers.py 2. `ArrowStreamArrowUDFSerializer.create_batch` in serializers.py 3. `process_results` in worker.py (scalar Arrow iter UDF path) Also: - Add a keyword-only `arrow_cast` parameter to `enforce_schema` (default `True`); passing `arrow_cast=False` enables a strict mode that raises on a type mismatch instead of casting - Add `KeyError` handling in `enforce_schema` so that a missing column raises a user-friendly `PySparkTypeError` - Remove the now-unused `coerce_arrow_array` helper from conversion.py, along with its imports in serializers.py and worker.py ### Why are the changes needed? These three places duplicated the same coerce-and-reassemble logic that `enforce_schema` already provides. Consolidating reduces code duplication and ensures consistent error handling. ### Does this PR introduce _any_ user-facing change? Error messages for type/schema mismatches in Arrow UDTFs are slightly changed to be consistent with other Arrow UDF error messages. ### How was this patch tested? Existing tests in `test_arrow_udtf.py` and `test_arrow_udf_scalar.py`. ### Was this patch authored or co-authored using generative AI tooling? Yes. 
--- python/pyspark/sql/conversion.py | 66 +++++-------------- python/pyspark/sql/pandas/serializers.py | 56 ++++------------ .../sql/tests/arrow/test_arrow_udtf.py | 8 ++- python/pyspark/worker.py | 11 ++-- 4 files changed, 41 insertions(+), 100 deletions(-) diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index 602ebec8eb80d..c971b9414d41d 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -112,6 +112,8 @@ def enforce_schema( cls, batch: "pa.RecordBatch", arrow_schema: "pa.Schema", + *, + arrow_cast: bool = True, safecheck: bool = True, ) -> "pa.RecordBatch": """ @@ -124,6 +126,9 @@ def enforce_schema( arrow_schema : pa.Schema Target Arrow schema. Callers should pre-compute this once via to_arrow_schema() to avoid repeated conversion. + arrow_cast : bool, default True + If True, cast mismatched types to the target type. + If False, raise an error on type mismatch instead of casting. safecheck : bool, default True If True, use safe casting (fails on overflow/truncation). @@ -149,8 +154,19 @@ def enforce_schema( coerced_arrays = [] for i, field in enumerate(arrow_schema): - arr = batch.column(i) if use_index else batch.column(field.name) + try: + arr = batch.column(i) if use_index else batch.column(field.name) + except KeyError: + raise PySparkTypeError( + f"Result column '{field.name}' does not exist in the output. " + f"Expected schema: {arrow_schema}, got: {batch.schema}." + ) if arr.type != field.type: + if not arrow_cast: + raise PySparkTypeError( + f"Result type of column '{field.name}' does not match " + f"the expected type. Expected: {field.type}, got: {arr.type}." + ) try: arr = arr.cast(target_type=field.type, safe=safecheck) except (pa.ArrowInvalid, pa.ArrowTypeError) as e: @@ -221,54 +237,6 @@ def to_pandas( ] -# TODO: elevate to ArrowBatchTransformer and operate on full RecordBatch schema -# instead of per-column coercion. 
-def coerce_arrow_array( - arr: "pa.Array", - target_type: "pa.DataType", - *, - safecheck: bool = True, - arrow_cast: bool = True, -) -> "pa.Array": - """ - Coerce an Arrow Array to a target type, with optional type-mismatch enforcement. - - When ``arrow_cast`` is True (default), mismatched types are cast to the - target type. When False, a type mismatch raises an error instead. - - Parameters - ---------- - arr : pa.Array - Input Arrow array - target_type : pa.DataType - Target Arrow type - safecheck : bool - Whether to use safe casting (default True) - arrow_cast : bool - Whether to allow casting when types don't match (default True) - - Returns - ------- - pa.Array - """ - from pyspark.errors import PySparkTypeError - - if arr.type == target_type: - return arr - - if not arrow_cast: - raise PySparkTypeError( - "Arrow UDFs require the return type to match the expected Arrow type. " - f"Expected: {target_type}, but got: {arr.type}." - ) - - # when safe is True, the cast will fail if there's a overflow or other - # unsafe conversion. - # RecordBatch.cast(...) isn't used as minimum PyArrow version - # required for RecordBatch.cast(...) is v16.0 - return arr.cast(target_type=target_type, safe=safecheck) - - class PandasToArrowConversion: """ Conversion utilities from pandas data to Arrow. 
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 01e4a8163489c..d67e130d07deb 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple import pyspark -from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.errors import PySparkRuntimeError, PySparkValueError from pyspark.serializers import ( Serializer, read_int, @@ -37,7 +37,6 @@ ArrowTableToRowsConversion, ArrowBatchTransformer, PandasToArrowConversion, - coerce_arrow_array, ) from pyspark.sql.pandas.types import ( from_arrow_schema, @@ -299,40 +298,9 @@ def apply_type_coercion(): assert isinstance(arrow_return_type, pa.StructType), ( f"Expected pa.StructType, got {type(arrow_return_type)}" ) - - # Handle empty struct case specially - if batch.num_columns == 0: - coerced_batch = batch # skip type coercion - else: - expected_field_names = [field.name for field in arrow_return_type] - actual_field_names = batch.schema.names - - if expected_field_names != actual_field_names: - raise PySparkTypeError( - "Target schema's field names are not matching the record batch's " - "field names. " - f"Expected: {expected_field_names}, but got: {actual_field_names}." - ) - - coerced_arrays = [] - for i, field in enumerate(arrow_return_type): - try: - coerced_arrays.append( - coerce_arrow_array( - batch.column(i), - field.type, - safecheck=True, - ) - ) - except (pa.ArrowInvalid, pa.ArrowTypeError) as e: - raise PySparkTypeError( - f"Result type of column '{field.name}' does not " - f"match the expected type. Expected: {field.type}, " - f"got: {batch.column(i).type}." 
- ) from e - coerced_batch = pa.RecordBatch.from_arrays( - coerced_arrays, names=expected_field_names - ) + coerced_batch = ArrowBatchTransformer.enforce_schema( + batch, pa.schema(arrow_return_type), safecheck=True + ) yield coerced_batch, arrow_return_type return super().dump_stream(apply_type_coercion(), stream) @@ -615,13 +583,15 @@ def dump_stream(self, iterator, stream): def create_batch( arr_tuples: List[Tuple["pa.Array", "pa.DataType"]], ) -> "pa.RecordBatch": - arrs = [ - coerce_arrow_array( - arr, arrow_type, safecheck=self._safecheck, arrow_cast=self._arrow_cast - ) - for arr, arrow_type in arr_tuples - ] - return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) + names = ["_%d" % i for i in range(len(arr_tuples))] + arrs = [arr for arr, _ in arr_tuples] + batch = pa.RecordBatch.from_arrays(arrs, names) + target_schema = pa.schema( + [pa.field(n, t) for n, (_, t) in zip(names, arr_tuples)] + ) + return ArrowBatchTransformer.enforce_schema( + batch, target_schema, safecheck=self._safecheck, arrow_cast=self._arrow_cast + ) def normalize(packed): if len(packed) == 2 and isinstance(packed[1], pa.DataType): diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py index 4c3996e125d55..f41b7613ec42d 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py @@ -211,7 +211,9 @@ def eval(self) -> Iterator["pa.Table"]: with self.assertRaisesRegex( PythonException, - "Target schema's field names are not matching the record batch's field names", + r"(?s)Result column 'x' does not exist in the output\. 
" + r"Expected schema: x: int32\ny: string, " + r"got: wrong_col: int32\nanother_wrong_col: double\.", ): result_df = MismatchedSchemaUDTF() result_df.collect() @@ -373,8 +375,8 @@ def eval(self) -> Iterator["pa.Table"]: # Should fail with Arrow cast exception since string cannot be cast to int with self.assertRaisesRegex( PythonException, - "PySparkTypeError: Result type of column 'id' does not match the expected type. " - "Expected: int32, got: string.", + "Result type of column 'id' does not match " + "the expected type. Expected: int32, got: string.", ): result_df = StringToIntUDTF() result_df.collect() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0ef578d3715b0..2ada28a7c8384 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -51,7 +51,6 @@ LocalDataToArrowConversion, ArrowTableToRowsConversion, ArrowBatchTransformer, - coerce_arrow_array, ) from pyspark.sql.functions import SkipRestOfInputTableException from pyspark.sql.pandas.serializers import ( @@ -2909,13 +2908,15 @@ def extract_args(batch: pa.RecordBatch): # Call UDF and verify result type (iterator of pa.Array) verified_iter = verify_result(pa.Array)(udf_func(args_iter)) - # Process results: coerce type and assemble into RecordBatch + # Process results: enforce schema and assemble into RecordBatch + target_schema = pa.schema([pa.field("_0", arrow_return_type)]) + def process_results(): for result in verified_iter: - result = coerce_arrow_array( - result, arrow_return_type, safecheck=True, arrow_cast=True + batch = pa.RecordBatch.from_arrays([result], ["_0"]) + yield ArrowBatchTransformer.enforce_schema( + batch, target_schema, safecheck=True ) - yield pa.RecordBatch.from_arrays([result], ["_0"]) # Apply row limit check (fail-fast) # TODO(SPARK-55579): Create Arrow-specific error class (e.g., ARROW_UDF_OUTPUT_EXCEEDS_INPUT_ROWS) From 08ee4100e1a53a9b101816f46fb0481708570dac Mon Sep 17 00:00:00 2001 From: Yicong Huang 
<17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:59:00 +0000 Subject: [PATCH 2/4] fix: apply ruff format to serializers.py and worker.py --- python/pyspark/sql/pandas/serializers.py | 4 +--- python/pyspark/worker.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d67e130d07deb..efce633ae2130 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -586,9 +586,7 @@ def create_batch( names = ["_%d" % i for i in range(len(arr_tuples))] arrs = [arr for arr, _ in arr_tuples] batch = pa.RecordBatch.from_arrays(arrs, names) - target_schema = pa.schema( - [pa.field(n, t) for n, (_, t) in zip(names, arr_tuples)] - ) + target_schema = pa.schema([pa.field(n, t) for n, (_, t) in zip(names, arr_tuples)]) return ArrowBatchTransformer.enforce_schema( batch, target_schema, safecheck=self._safecheck, arrow_cast=self._arrow_cast ) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2ada28a7c8384..306ed8787c849 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2914,9 +2914,7 @@ def extract_args(batch: pa.RecordBatch): def process_results(): for result in verified_iter: batch = pa.RecordBatch.from_arrays([result], ["_0"]) - yield ArrowBatchTransformer.enforce_schema( - batch, target_schema, safecheck=True - ) + yield ArrowBatchTransformer.enforce_schema(batch, target_schema, safecheck=True) # Apply row limit check (fail-fast) # TODO(SPARK-55579): Create Arrow-specific error class (e.g., ARROW_UDF_OUTPUT_EXCEEDS_INPUT_ROWS) From 5f12756ec2477fff981f6bf642576a0c8270f850 Mon Sep 17 00:00:00 2001 From: Yicong Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 24 Mar 2026 18:13:20 +0000 Subject: [PATCH 3/4] test: add unit tests for ArrowBatchTransformer.enforce_schema --- python/pyspark/sql/tests/test_conversion.py | 47 +++++++++++++++++++++ 1 file 
changed, 47 insertions(+) diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py index 9ac6bcbd0537c..2e42408355f42 100644 --- a/python/pyspark/sql/tests/test_conversion.py +++ b/python/pyspark/sql/tests/test_conversion.py @@ -158,6 +158,53 @@ def test_wrap_struct_empty_batch(self): self.assertEqual(wrapped.num_rows, 0) self.assertEqual(wrapped.num_columns, 1) + def test_enforce_schema_nested_cast(self): + """Nested struct and list types are cast recursively by Arrow.""" + import pyarrow as pa + + inner = pa.struct([("a", pa.int32()), ("b", pa.float32())]) + batch = pa.RecordBatch.from_arrays( + [ + pa.array([{"a": 1, "b": 2.0}], type=inner), + pa.array([[1, 2]], type=pa.list_(pa.int32())), + ], + names=["s", "l"], + ) + target = pa.schema([ + ("s", pa.struct([("a", pa.int64()), ("b", pa.float64())])), + ("l", pa.list_(pa.int64())), + ]) + result = ArrowBatchTransformer.enforce_schema(batch, target) + self.assertEqual(result.schema, target) + + def test_enforce_schema_arrow_cast_false(self): + """arrow_cast=False raises on type mismatch instead of casting.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([1], type=pa.int32())], names=["x"]) + target = pa.schema([("x", pa.int64())]) + with self.assertRaises(PySparkTypeError): + ArrowBatchTransformer.enforce_schema(batch, target, arrow_cast=False) + + def test_enforce_schema_safecheck(self): + """safecheck=True rejects overflow; safecheck=False allows it.""" + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([999], type=pa.int64())], names=["x"]) + target = pa.schema([("x", pa.int8())]) + with self.assertRaises(PySparkTypeError): + ArrowBatchTransformer.enforce_schema(batch, target, safecheck=True) + result = ArrowBatchTransformer.enforce_schema(batch, target, safecheck=False) + self.assertEqual(result.schema, target) + + def test_enforce_schema_missing_column(self): + """Missing column raises PySparkTypeError.""" + import 
pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"]) + with self.assertRaises(PySparkTypeError): + ArrowBatchTransformer.enforce_schema(batch, pa.schema([("missing", pa.int64())])) + @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) @unittest.skipIf(not have_pandas, pandas_requirement_message) From 35e212559c88ed61121fa1054a85bb317ef8ec4e Mon Sep 17 00:00:00 2001 From: Yicong Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:39:52 +0000 Subject: [PATCH 4/4] fix: apply ruff format to test_conversion.py --- python/pyspark/sql/tests/test_conversion.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py index 2e42408355f42..f37ddf0f909d9 100644 --- a/python/pyspark/sql/tests/test_conversion.py +++ b/python/pyspark/sql/tests/test_conversion.py @@ -170,10 +170,12 @@ def test_enforce_schema_nested_cast(self): ], names=["s", "l"], ) - target = pa.schema([ - ("s", pa.struct([("a", pa.int64()), ("b", pa.float64())])), - ("l", pa.list_(pa.int64())), - ]) + target = pa.schema( + [ + ("s", pa.struct([("a", pa.int64()), ("b", pa.float64())])), + ("l", pa.list_(pa.int64())), + ] + ) result = ArrowBatchTransformer.enforce_schema(batch, target) self.assertEqual(result.schema, target)