Skip to content

Commit a8c32ad

Browse files
committed
[Fix](pyudf) Fix error type conversion
1 parent 7642b00 commit a8c32ad

13 files changed

Lines changed: 256 additions & 30 deletions

File tree

be/src/exprs/aggregate/aggregate_function_python_udaf.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,9 @@ void AggregatePythonUDAF::create(AggregateDataPtr __restrict place) const {
252252
"Failed to convert argument type {} to Arrow type: {}", i,
253253
st.to_string());
254254
}
255-
fields.push_back(arrow::field(std::to_string(i), arrow_type));
255+
auto primitive_type = argument_types[i]->get_primitive_type();
256+
fields.push_back(create_arrow_field_with_metadata(std::to_string(i), arrow_type,
257+
true, primitive_type));
256258
}
257259

258260
// Add places column for GROUP BY aggregation (always included, NULL in single-place mode)

be/src/format/arrow/arrow_row_batch.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ Status convert_to_arrow_type(const DataTypePtr& origin_type,
172172
}
173173

174174
// Helper function to create an Arrow Field with type metadata if applicable, such as IP types
175-
static std::shared_ptr<arrow::Field> create_arrow_field_with_metadata(
175+
std::shared_ptr<arrow::Field> create_arrow_field_with_metadata(
176176
const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type,
177177
bool is_nullable, PrimitiveType primitive_type) {
178178
if (primitive_type == PrimitiveType::TYPE_IPV4) {
@@ -181,6 +181,9 @@ static std::shared_ptr<arrow::Field> create_arrow_field_with_metadata(
181181
} else if (primitive_type == PrimitiveType::TYPE_IPV6) {
182182
auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {"IPV6"});
183183
return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable, metadata);
184+
} else if (primitive_type == PrimitiveType::TYPE_LARGEINT) {
185+
auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {"LARGEINT"});
186+
return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable, metadata);
184187
} else {
185188
return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable);
186189
}

be/src/format/arrow/arrow_row_batch.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "common/status.h"
2424
#include "core/block/block.h"
25+
#include "core/data_type/define_primitive_type.h"
2526
#include "exprs/vexpr_fwd.h"
2627

2728
// This file will convert Doris RowBatch to/from Arrow's RecordBatch
@@ -31,6 +32,7 @@
3132
namespace arrow {
3233

3334
class DataType;
35+
class Field;
3436
class RecordBatch;
3537
class Schema;
3638

@@ -42,8 +44,14 @@ constexpr size_t MAX_ARROW_UTF8 = (1ULL << 21); // 2G
4244

4345
class RowDescriptor;
4446

45-
Status convert_to_arrow_type(const DataTypePtr& type, std::shared_ptr<arrow::DataType>* result,
46-
const std::string& timezone);
47+
// Create an Arrow Field with doris_type metadata for special types (IPV4, IPV6, LARGEINT).
48+
// These types require metadata so the Python UDF server can perform proper type conversion.
49+
std::shared_ptr<arrow::Field> create_arrow_field_with_metadata(
50+
const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type,
51+
bool is_nullable, PrimitiveType primitive_type);
52+
53+
Status convert_to_arrow_type(const vectorized::DataTypePtr& type,
54+
std::shared_ptr<arrow::DataType>* result, const std::string& timezone);
4755

4856
Status get_arrow_schema_from_block(const Block& block, std::shared_ptr<arrow::Schema>* result,
4957
const std::string& timezone);

be/src/udf/python/python_server.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -288,25 +288,40 @@ def convert_arrow_field_to_python(field, column_metadata=None):
288288
)
289289
return value
290290
return None
291+
# Handle Doris LARGEINT type (Arrow utf8 -> Python int)
292+
elif doris_type in (b'LARGEINT', 'LARGEINT'):
293+
if pa.types.is_string(field.type) or pa.types.is_large_string(field.type):
294+
value = field.as_py()
295+
if value is not None:
296+
try:
297+
return int(value)
298+
except (ValueError, TypeError) as e:
299+
logging.warning(
300+
"Failed to convert string '%s' to int (LARGEINT): %s", value, e
301+
)
302+
return value
303+
return None
291304

292305
return field.as_py()
293306

294307

295-
def convert_python_to_arrow_value(value, output_type=None):
308+
def convert_python_to_arrow_value(value, output_type=None, output_metadata=None):
296309
"""
297310
Convert Python value back to Arrow-compatible value.
298311
299-
This function handles the reverse conversion of IP addresses:
312+
This function handles the reverse conversion of special types:
300313
- ipaddress.IPv4Address -> int (with uint32 to int32 conversion)
301314
- ipaddress.IPv6Address -> str (for Arrow utf8)
315+
- Python int -> str (for LARGEINT, which uses Arrow utf8)
302316
303317
Type Safety:
304318
For IPv4/IPv6 return types, MUST return ipaddress objects.
305319
Returning raw integers or strings will raise TypeError.
306320
307321
Args:
308322
value: Python value to convert (can be single value or iterable)
309-
output_type: Optional Arrow DataType with metadata
323+
output_type: Optional Arrow DataType
324+
output_metadata: Optional metadata dict from the output Arrow field
310325
311326
Returns:
312327
Arrow-compatible value
@@ -316,14 +331,23 @@ def convert_python_to_arrow_value(value, output_type=None):
316331

317332
is_ipv4_output = False
318333
is_ipv6_output = False
334+
is_largeint_output = False
319335

320-
if output_type is not None and hasattr(output_type, 'metadata') and output_type.metadata:
336+
# Check output_metadata (from field metadata, passed explicitly)
337+
metadata = output_metadata
338+
# Fallback: check output_type.metadata (for compound types like struct fields)
339+
if metadata is None and output_type is not None and hasattr(output_type, 'metadata') and output_type.metadata:
340+
metadata = output_type.metadata
341+
342+
if metadata:
321343
# Arrow metadata keys can be either bytes or str depending on how they were created
322-
doris_type = output_type.metadata.get(b'doris_type') or output_type.metadata.get('doris_type')
344+
doris_type = metadata.get(b'doris_type') or metadata.get('doris_type')
323345
if doris_type in (b'IPV4', 'IPV4'):
324346
is_ipv4_output = True
325347
elif doris_type in (b'IPV6', 'IPV6'):
326348
is_ipv6_output = True
349+
elif doris_type in (b'LARGEINT', 'LARGEINT'):
350+
is_largeint_output = True
327351

328352
# Convert IPv4Address back to int
329353
if isinstance(value, ipaddress.IPv4Address):
@@ -333,6 +357,10 @@ def convert_python_to_arrow_value(value, output_type=None):
333357
if isinstance(value, ipaddress.IPv6Address):
334358
return str(value)
335359

360+
# Convert Python int back to str for LARGEINT (Arrow uses utf8 for LARGEINT)
361+
if is_largeint_output and isinstance(value, int):
362+
return str(value)
363+
336364
# IPv4 output must return IPv4Address objects
337365
if is_ipv4_output and isinstance(value, int):
338366
raise TypeError(
@@ -352,10 +380,10 @@ def convert_python_to_arrow_value(value, output_type=None):
352380
# For list types, recursively convert elements
353381
if output_type and pa.types.is_list(output_type):
354382
element_type = output_type.value_type
355-
return [convert_python_to_arrow_value(v, element_type) for v in value]
383+
return [convert_python_to_arrow_value(v, element_type, output_metadata) for v in value]
356384
else:
357385
# No type info, just recurse without type
358-
return [convert_python_to_arrow_value(v, None) for v in value]
386+
return [convert_python_to_arrow_value(v, None, output_metadata) for v in value]
359387

360388
# Handle tuple values (could be struct data)
361389
if isinstance(value, tuple):
@@ -373,7 +401,7 @@ def convert_python_to_arrow_value(value, output_type=None):
373401
else:
374402
# Not a struct type, treat as regular tuple and recurse without type
375403
return tuple(convert_python_to_arrow_value(v, None) for v in value)
376-
404+
377405
if isinstance(value, dict):
378406
# For map types, convert keys and values recursively
379407
if output_type and pa.types.is_map(output_type):
@@ -393,7 +421,7 @@ def convert_python_to_arrow_value(value, output_type=None):
393421
for k, v in value.items()]
394422

395423
if isinstance(value, pd.Series):
396-
return value.apply(lambda v: convert_python_to_arrow_value(v, output_type))
424+
return value.apply(lambda v: convert_python_to_arrow_value(v, output_type, output_metadata))
397425

398426
return value
399427

@@ -473,6 +501,7 @@ def __init__(
473501
input_types: pa.Schema,
474502
output_type: pa.DataType,
475503
client_type: int,
504+
output_metadata: Optional[dict] = None,
476505
) -> None:
477506
"""
478507
Initialize Python UDF metadata.
@@ -488,6 +517,7 @@ def __init__(
488517
input_types: PyArrow schema for input parameters
489518
output_type: PyArrow data type for return value
490519
client_type: 0 for UDF, 1 for UDAF, 2 for UDTF
520+
output_metadata: Optional metadata dict from the output Arrow field
491521
"""
492522
self.name = name
493523
self.symbol = symbol
@@ -499,6 +529,7 @@ def __init__(
499529
self.input_types = input_types
500530
self.output_type = output_type
501531
self.client_type = ClientType(client_type)
532+
self.output_metadata = output_metadata
502533

503534
def is_udf(self) -> bool:
504535
"""Check if this is a UDF (User-Defined Function)."""
@@ -627,7 +658,7 @@ def _scalar_call(self, record_batch: pa.RecordBatch) -> pa.Array:
627658
f"please check the always_nullable property in create function statement, "
628659
f"it should be true"
629660
)
630-
result.append(convert_python_to_arrow_value(res, self.python_udf_meta.output_type))
661+
result.append(convert_python_to_arrow_value(res, self.python_udf_meta.output_type, self.python_udf_meta.output_metadata))
631662
except Exception as e:
632663
logging.error(
633664
"Error in scalar UDF execution at row %s: %s\nArgs: %s\nTraceback: %s",
@@ -697,7 +728,7 @@ def _vectorized_call(self, record_batch: pa.RecordBatch) -> pa.Array:
697728
)
698729
raise RuntimeError(f"Error in vectorized UDF: {e}") from e
699730

700-
result = convert_python_to_arrow_value(result, self.python_udf_meta.output_type)
731+
result = convert_python_to_arrow_value(result, self.python_udf_meta.output_type, self.python_udf_meta.output_metadata)
701732

702733
# Convert result to PyArrow Array
703734
result_array = None
@@ -1614,6 +1645,7 @@ def parse_python_udf_meta(
16141645
return None
16151646

16161647
output_type = output_schema.field(0).type
1648+
output_metadata = output_schema.field(0).metadata
16171649

16181650
python_udf_meta = PythonUDFMeta(
16191651
name=name,
@@ -1626,6 +1658,7 @@ def parse_python_udf_meta(
16261658
input_types=input_schema,
16271659
output_type=output_type,
16281660
client_type=client_type,
1661+
output_metadata=output_metadata,
16291662
)
16301663

16311664
return python_udf_meta
@@ -1887,13 +1920,14 @@ def _handle_udaf_finalize(
18871920
place_id: int,
18881921
output_type: pa.DataType,
18891922
state_manager: UDAFStateManager,
1923+
output_metadata: Optional[dict] = None,
18901924
) -> pa.RecordBatch:
18911925
"""Handle UDAF FINALIZE operation.
18921926
18931927
Returns: [result: output_type] (null if failed)
18941928
"""
18951929
try:
1896-
result = convert_python_to_arrow_value(state_manager.finalize(place_id), output_type)
1930+
result = convert_python_to_arrow_value(state_manager.finalize(place_id), output_type, output_metadata)
18971931
except Exception as e:
18981932
logging.error(
18991933
"FINALIZE operation failed for place_id=%s: %s",
@@ -2171,7 +2205,8 @@ def _handle_exchange_udaf(
21712205
)
21722206
elif operation_type == UDAFOperationType.FINALIZE:
21732207
result_batch_finalize = self._handle_udaf_finalize(
2174-
place_id, python_udaf_meta.output_type, state_manager
2208+
place_id, python_udaf_meta.output_type, state_manager,
2209+
python_udaf_meta.output_metadata
21752210
)
21762211
# Serialize the result to binary (including NULL results)
21772212
# NULL is a valid aggregation result, not an error
@@ -2302,7 +2337,8 @@ def _handle_exchange_udtf(
23022337
# Process all input rows and build ListArray
23032338
try:
23042339
response_batch = self._process_udtf_with_list_array(
2305-
udtf_func, input_batch, python_udtf_meta.output_type
2340+
udtf_func, input_batch, python_udtf_meta.output_type,
2341+
python_udtf_meta.output_metadata
23062342
)
23072343

23082344
# Send the response batch
@@ -2339,6 +2375,7 @@ def _process_udtf_with_list_array(
23392375
udtf_func: Callable,
23402376
input_batch: pa.RecordBatch,
23412377
expected_output_type: pa.DataType,
2378+
output_metadata: Optional[dict] = None,
23422379
) -> pa.RecordBatch:
23432380
"""
23442381
Process UDTF function on all input rows and generate a ListArray.
@@ -2347,6 +2384,7 @@ def _process_udtf_with_list_array(
23472384
udtf_func: The UDTF function to call
23482385
input_batch: Input RecordBatch with N rows
23492386
expected_output_type: Expected Arrow type for output data
2387+
output_metadata: Optional metadata dict from the output Arrow field
23502388
23512389
Returns:
23522390
RecordBatch with a single ListArray column where each element
@@ -2424,7 +2462,7 @@ def _process_udtf_with_list_array(
24242462

24252463
all_results.append(row_outputs)
24262464

2427-
all_results = convert_python_to_arrow_value(all_results, expected_output_type)
2465+
all_results = convert_python_to_arrow_value(all_results, expected_output_type, output_metadata)
24282466

24292467
try:
24302468
list_array = pa.array(all_results, type=pa.list_(expected_output_type))

be/src/udf/python/python_udf_meta.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "udf/python/python_udf_meta.h"
1919

2020
#include <arrow/util/base64.h>
21+
#include <arrow/util/key_value_metadata.h>
2122
#include <fmt/core.h>
2223
#include <rapidjson/stringbuffer.h>
2324
#include <rapidjson/writer.h>
@@ -30,15 +31,36 @@
3031

3132
namespace doris {
3233

33-
Status PythonUDFMeta::convert_types_to_schema(const DataTypes& types, const std::string& timezone,
34+
// Create an Arrow Field with doris_type metadata for special types (e.g. IP, LARGEINT)
35+
std::shared_ptr<arrow::Field> PythonUDFMeta::create_field_with_doris_metadata(
36+
const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type,
37+
bool is_nullable, PrimitiveType primitive_type) {
38+
static const std::unordered_map<PrimitiveType, std::string> doris_type_metadata = {
39+
{PrimitiveType::TYPE_IPV4, "IPV4"},
40+
{PrimitiveType::TYPE_IPV6, "IPV6"},
41+
{PrimitiveType::TYPE_LARGEINT, "LARGEINT"},
42+
};
43+
44+
auto it = doris_type_metadata.find(primitive_type);
45+
if (it != doris_type_metadata.end()) {
46+
auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {it->second});
47+
return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable, metadata);
48+
}
49+
return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable);
50+
}
51+
52+
Status PythonUDFMeta::convert_types_to_schema(const vectorized::DataTypes& types,
53+
const std::string& timezone,
3454
std::shared_ptr<arrow::Schema>* schema) {
3555
assert(!types.empty());
3656
arrow::SchemaBuilder builder;
3757
for (size_t i = 0; i < types.size(); ++i) {
3858
std::shared_ptr<arrow::DataType> arrow_type;
3959
RETURN_IF_ERROR(convert_to_arrow_type(types[i], &arrow_type, timezone));
40-
std::shared_ptr<arrow::Field> field = std::make_shared<arrow::Field>(
41-
"arg" + std::to_string(i), arrow_type, types[i]->is_nullable());
60+
61+
auto field = create_field_with_doris_metadata("arg" + std::to_string(i), arrow_type,
62+
types[i]->is_nullable(),
63+
types[i]->get_primitive_type());
4264
RETURN_DORIS_STATUS_IF_ERROR(builder.AddField(field));
4365
}
4466
RETURN_DORIS_STATUS_IF_RESULT_ERROR(schema, builder.Finish());

be/src/udf/python/python_udf_meta.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ struct PythonUDFMeta {
5959
Status check() const;
6060

6161
bool operator==(const PythonUDFMeta& other) const { return id == other.id; }
62+
63+
private:
64+
std::shared_ptr<arrow::Field> create_field_with_doris_metadata(
65+
const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type,
66+
bool is_nullable, PrimitiveType primitive_type);
6267
};
6368

6469
} // namespace doris

be/test/core/data_type_serde/data_type_serde_arrow_test.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,16 @@ std::shared_ptr<Block> create_test_block(std::vector<PrimitiveType> cols, int ro
387387
ColumnWithTypeAndName type_and_name(vec->get_ptr(), data_type, col_name);
388388
block->insert(std::move(type_and_name));
389389
} break;
390+
case TYPE_LARGEINT: {
391+
auto vec = vectorized::ColumnInt128::create();
392+
auto& data = vec->get_data();
393+
for (int i = 0; i < row_num; ++i) {
394+
data.push_back(__int128_t(i));
395+
}
396+
vectorized::DataTypePtr data_type(std::make_shared<vectorized::DataTypeInt128>());
397+
vectorized::ColumnWithTypeAndName type_and_name(vec->get_ptr(), data_type, col_name);
398+
block->insert(std::move(type_and_name));
399+
} break;
390400
default:
391401
LOG(FATAL) << "error column type";
392402
}
@@ -425,9 +435,9 @@ void block_converter_test(std::vector<PrimitiveType> cols, int row_num, bool is_
425435

426436
TEST(DataTypeSerDeArrowTest, DataTypeScalaSerDeTest) {
427437
std::vector<PrimitiveType> cols = {
428-
TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN,
429-
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_DATETIME,
430-
TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2,
438+
TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN,
439+
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_LARGEINT,
440+
TYPE_DATETIME, TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2,
431441
};
432442
serialize_and_deserialize_arrow_test(cols, 7, true);
433443
serialize_and_deserialize_arrow_test(cols, 7, false);
@@ -506,9 +516,9 @@ TEST(DataTypeSerDeArrowTest, BigStringSerDeTest) {
506516

507517
TEST(DataTypeSerDeArrowTest, BlockConverterTest) {
508518
std::vector<PrimitiveType> cols = {
509-
TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN,
510-
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_DATETIME,
511-
TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2,
519+
TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN,
520+
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_LARGEINT,
521+
TYPE_DATETIME, TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2,
512522
};
513523
block_converter_test(cols, 7, true);
514524
block_converter_test(cols, 7, false);

0 commit comments

Comments
 (0)