Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion be/src/exprs/aggregate/aggregate_function_python_udaf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ void AggregatePythonUDAF::create(AggregateDataPtr __restrict place) const {
"Failed to convert argument type {} to Arrow type: {}", i,
st.to_string());
}
fields.push_back(arrow::field(std::to_string(i), arrow_type));
auto primitive_type = argument_types[i]->get_primitive_type();
fields.push_back(create_arrow_field_with_metadata(std::to_string(i), arrow_type,
true, primitive_type));
}

// Add places column for GROUP BY aggregation (always included, NULL in single-place mode)
Expand Down
5 changes: 4 additions & 1 deletion be/src/format/arrow/arrow_row_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ Status convert_to_arrow_type(const DataTypePtr& origin_type,
}

// Helper function to create an Arrow Field with type metadata if applicable, such as IP types
static std::shared_ptr<arrow::Field> create_arrow_field_with_metadata(
std::shared_ptr<arrow::Field> create_arrow_field_with_metadata(
const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type,
bool is_nullable, PrimitiveType primitive_type) {
if (primitive_type == PrimitiveType::TYPE_IPV4) {
Expand All @@ -181,6 +181,9 @@ static std::shared_ptr<arrow::Field> create_arrow_field_with_metadata(
} else if (primitive_type == PrimitiveType::TYPE_IPV6) {
auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {"IPV6"});
return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable, metadata);
} else if (primitive_type == PrimitiveType::TYPE_LARGEINT) {
auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {"LARGEINT"});
return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable, metadata);
} else {
return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable);
}
Expand Down
12 changes: 10 additions & 2 deletions be/src/format/arrow/arrow_row_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "common/status.h"
#include "core/block/block.h"
#include "core/data_type/define_primitive_type.h"
#include "exprs/vexpr_fwd.h"

// This file will convert Doris RowBatch to/from Arrow's RecordBatch
Expand All @@ -31,6 +32,7 @@
namespace arrow {

class DataType;
class Field;
class RecordBatch;
class Schema;

Expand All @@ -42,8 +44,14 @@ constexpr size_t MAX_ARROW_UTF8 = (1ULL << 21); // 2G

class RowDescriptor;

Status convert_to_arrow_type(const DataTypePtr& type, std::shared_ptr<arrow::DataType>* result,
const std::string& timezone);
// Create an Arrow Field with doris_type metadata for special types (IPV4, IPV6, LARGEINT).
// These types require metadata so the Python UDF server can perform proper type conversion.
std::shared_ptr<arrow::Field> create_arrow_field_with_metadata(
const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type,
bool is_nullable, PrimitiveType primitive_type);

Status convert_to_arrow_type(const vectorized::DataTypePtr& type,
std::shared_ptr<arrow::DataType>* result, const std::string& timezone);

Status get_arrow_schema_from_block(const Block& block, std::shared_ptr<arrow::Schema>* result,
const std::string& timezone);
Expand Down
68 changes: 53 additions & 15 deletions be/src/udf/python/python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,25 +288,40 @@ def convert_arrow_field_to_python(field, column_metadata=None):
)
return value
return None
# Handle Doris LARGEINT type (Arrow utf8 -> Python int)
elif doris_type in (b'LARGEINT', 'LARGEINT'):
if pa.types.is_string(field.type) or pa.types.is_large_string(field.type):
value = field.as_py()
if value is not None:
try:
return int(value)
except (ValueError, TypeError) as e:
logging.warning(
"Failed to convert string '%s' to int (LARGEINT): %s", value, e
)
return value
return None

return field.as_py()


def convert_python_to_arrow_value(value, output_type=None):
def convert_python_to_arrow_value(value, output_type=None, output_metadata=None):
"""
Convert Python value back to Arrow-compatible value.

This function handles the reverse conversion of IP addresses:
This function handles the reverse conversion of special types:
- ipaddress.IPv4Address -> int (with uint32 to int32 conversion)
- ipaddress.IPv6Address -> str (for Arrow utf8)
- Python int -> str (for LARGEINT, which uses Arrow utf8)

Type Safety:
For IPv4/IPv6 return types, MUST return ipaddress objects.
Returning raw integers or strings will raise TypeError.

Args:
value: Python value to convert (can be single value or iterable)
output_type: Optional Arrow DataType with metadata
output_type: Optional Arrow DataType
output_metadata: Optional metadata dict from the output Arrow field

Returns:
Arrow-compatible value
Expand All @@ -316,14 +331,23 @@ def convert_python_to_arrow_value(value, output_type=None):

is_ipv4_output = False
is_ipv6_output = False
is_largeint_output = False

if output_type is not None and hasattr(output_type, 'metadata') and output_type.metadata:
# Check output_metadata (from field metadata, passed explicitly)
metadata = output_metadata
# Fallback: check output_type.metadata (for compound types like struct fields)
if metadata is None and output_type is not None and hasattr(output_type, 'metadata') and output_type.metadata:
metadata = output_type.metadata

if metadata:
# Arrow metadata keys can be either bytes or str depending on how they were created
doris_type = output_type.metadata.get(b'doris_type') or output_type.metadata.get('doris_type')
doris_type = metadata.get(b'doris_type') or metadata.get('doris_type')
if doris_type in (b'IPV4', 'IPV4'):
is_ipv4_output = True
elif doris_type in (b'IPV6', 'IPV6'):
is_ipv6_output = True
elif doris_type in (b'LARGEINT', 'LARGEINT'):
is_largeint_output = True

# Convert IPv4Address back to int
if isinstance(value, ipaddress.IPv4Address):
Expand All @@ -333,6 +357,10 @@ def convert_python_to_arrow_value(value, output_type=None):
if isinstance(value, ipaddress.IPv6Address):
return str(value)

# Convert Python int back to str for LARGEINT (Arrow uses utf8 for LARGEINT)
if is_largeint_output and isinstance(value, int):
return str(value)

# IPv4 output must return IPv4Address objects
if is_ipv4_output and isinstance(value, int):
raise TypeError(
Expand All @@ -352,10 +380,10 @@ def convert_python_to_arrow_value(value, output_type=None):
# For list types, recursively convert elements
if output_type and pa.types.is_list(output_type):
element_type = output_type.value_type
return [convert_python_to_arrow_value(v, element_type) for v in value]
return [convert_python_to_arrow_value(v, element_type, output_metadata) for v in value]
else:
# No type info, just recurse without type
return [convert_python_to_arrow_value(v, None) for v in value]
return [convert_python_to_arrow_value(v, None, output_metadata) for v in value]

# Handle tuple values (could be struct data)
if isinstance(value, tuple):
Expand All @@ -373,7 +401,7 @@ def convert_python_to_arrow_value(value, output_type=None):
else:
# Not a struct type, treat as regular tuple and recurse without type
return tuple(convert_python_to_arrow_value(v, None) for v in value)

if isinstance(value, dict):
# For map types, convert keys and values recursively
if output_type and pa.types.is_map(output_type):
Expand All @@ -393,7 +421,7 @@ def convert_python_to_arrow_value(value, output_type=None):
for k, v in value.items()]

if isinstance(value, pd.Series):
return value.apply(lambda v: convert_python_to_arrow_value(v, output_type))
return value.apply(lambda v: convert_python_to_arrow_value(v, output_type, output_metadata))

return value

Expand Down Expand Up @@ -473,6 +501,7 @@ def __init__(
input_types: pa.Schema,
output_type: pa.DataType,
client_type: int,
output_metadata: Optional[dict] = None,
) -> None:
"""
Initialize Python UDF metadata.
Expand All @@ -488,6 +517,7 @@ def __init__(
input_types: PyArrow schema for input parameters
output_type: PyArrow data type for return value
client_type: 0 for UDF, 1 for UDAF, 2 for UDTF
output_metadata: Optional metadata dict from the output Arrow field
"""
self.name = name
self.symbol = symbol
Expand All @@ -499,6 +529,7 @@ def __init__(
self.input_types = input_types
self.output_type = output_type
self.client_type = ClientType(client_type)
self.output_metadata = output_metadata

def is_udf(self) -> bool:
"""Check if this is a UDF (User-Defined Function)."""
Expand Down Expand Up @@ -627,7 +658,7 @@ def _scalar_call(self, record_batch: pa.RecordBatch) -> pa.Array:
f"please check the always_nullable property in create function statement, "
f"it should be true"
)
result.append(convert_python_to_arrow_value(res, self.python_udf_meta.output_type))
result.append(convert_python_to_arrow_value(res, self.python_udf_meta.output_type, self.python_udf_meta.output_metadata))
except Exception as e:
logging.error(
"Error in scalar UDF execution at row %s: %s\nArgs: %s\nTraceback: %s",
Expand Down Expand Up @@ -697,7 +728,7 @@ def _vectorized_call(self, record_batch: pa.RecordBatch) -> pa.Array:
)
raise RuntimeError(f"Error in vectorized UDF: {e}") from e

result = convert_python_to_arrow_value(result, self.python_udf_meta.output_type)
result = convert_python_to_arrow_value(result, self.python_udf_meta.output_type, self.python_udf_meta.output_metadata)

# Convert result to PyArrow Array
result_array = None
Expand Down Expand Up @@ -1614,6 +1645,7 @@ def parse_python_udf_meta(
return None

output_type = output_schema.field(0).type
output_metadata = output_schema.field(0).metadata

python_udf_meta = PythonUDFMeta(
name=name,
Expand All @@ -1626,6 +1658,7 @@ def parse_python_udf_meta(
input_types=input_schema,
output_type=output_type,
client_type=client_type,
output_metadata=output_metadata,
)

return python_udf_meta
Expand Down Expand Up @@ -1887,13 +1920,14 @@ def _handle_udaf_finalize(
place_id: int,
output_type: pa.DataType,
state_manager: UDAFStateManager,
output_metadata: Optional[dict] = None,
) -> pa.RecordBatch:
"""Handle UDAF FINALIZE operation.

Returns: [result: output_type] (null if failed)
"""
try:
result = convert_python_to_arrow_value(state_manager.finalize(place_id), output_type)
result = convert_python_to_arrow_value(state_manager.finalize(place_id), output_type, output_metadata)
except Exception as e:
logging.error(
"FINALIZE operation failed for place_id=%s: %s",
Expand Down Expand Up @@ -2171,7 +2205,8 @@ def _handle_exchange_udaf(
)
elif operation_type == UDAFOperationType.FINALIZE:
result_batch_finalize = self._handle_udaf_finalize(
place_id, python_udaf_meta.output_type, state_manager
place_id, python_udaf_meta.output_type, state_manager,
python_udaf_meta.output_metadata
)
# Serialize the result to binary (including NULL results)
# NULL is a valid aggregation result, not an error
Expand Down Expand Up @@ -2302,7 +2337,8 @@ def _handle_exchange_udtf(
# Process all input rows and build ListArray
try:
response_batch = self._process_udtf_with_list_array(
udtf_func, input_batch, python_udtf_meta.output_type
udtf_func, input_batch, python_udtf_meta.output_type,
python_udtf_meta.output_metadata
)

# Send the response batch
Expand Down Expand Up @@ -2339,6 +2375,7 @@ def _process_udtf_with_list_array(
udtf_func: Callable,
input_batch: pa.RecordBatch,
expected_output_type: pa.DataType,
output_metadata: Optional[dict] = None,
) -> pa.RecordBatch:
"""
Process UDTF function on all input rows and generate a ListArray.
Expand All @@ -2347,6 +2384,7 @@ def _process_udtf_with_list_array(
udtf_func: The UDTF function to call
input_batch: Input RecordBatch with N rows
expected_output_type: Expected Arrow type for output data
output_metadata: Optional metadata dict from the output Arrow field

Returns:
RecordBatch with a single ListArray column where each element
Expand Down Expand Up @@ -2424,7 +2462,7 @@ def _process_udtf_with_list_array(

all_results.append(row_outputs)

all_results = convert_python_to_arrow_value(all_results, expected_output_type)
all_results = convert_python_to_arrow_value(all_results, expected_output_type, output_metadata)

try:
list_array = pa.array(all_results, type=pa.list_(expected_output_type))
Expand Down
25 changes: 23 additions & 2 deletions be/src/udf/python/python_udf_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "udf/python/python_udf_meta.h"

#include <arrow/util/base64.h>
#include <arrow/util/key_value_metadata.h>
#include <fmt/core.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
Expand All @@ -30,15 +31,35 @@

namespace doris {

// Create an Arrow Field carrying "doris_type" metadata for the special primitive
// types (IPV4, IPV6, LARGEINT); all other types get a plain field with no metadata.
std::shared_ptr<arrow::Field> PythonUDFMeta::create_field_with_doris_metadata(
        const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type,
        bool is_nullable, PrimitiveType primitive_type) {
    // Resolve the metadata tag for the handful of special types; nullptr means
    // the type needs no extra annotation.
    const char* doris_type_tag = nullptr;
    switch (primitive_type) {
    case PrimitiveType::TYPE_IPV4:
        doris_type_tag = "IPV4";
        break;
    case PrimitiveType::TYPE_IPV6:
        doris_type_tag = "IPV6";
        break;
    case PrimitiveType::TYPE_LARGEINT:
        doris_type_tag = "LARGEINT";
        break;
    default:
        break;
    }

    if (doris_type_tag != nullptr) {
        // Attach {"doris_type": <tag>} so the Python UDF server can recover the
        // original Doris type from the Arrow schema.
        auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {doris_type_tag});
        return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable, metadata);
    }
    // Ordinary type: no metadata required.
    return std::make_shared<arrow::Field>(field_name, arrow_type, is_nullable);
}

Status PythonUDFMeta::convert_types_to_schema(const DataTypes& types, const std::string& timezone,
std::shared_ptr<arrow::Schema>* schema) {
assert(!types.empty());
arrow::SchemaBuilder builder;
for (size_t i = 0; i < types.size(); ++i) {
std::shared_ptr<arrow::DataType> arrow_type;
RETURN_IF_ERROR(convert_to_arrow_type(types[i], &arrow_type, timezone));
std::shared_ptr<arrow::Field> field = std::make_shared<arrow::Field>(
"arg" + std::to_string(i), arrow_type, types[i]->is_nullable());

auto field = create_field_with_doris_metadata("arg" + std::to_string(i), arrow_type,
types[i]->is_nullable(),
types[i]->get_primitive_type());
RETURN_DORIS_STATUS_IF_ERROR(builder.AddField(field));
}
RETURN_DORIS_STATUS_IF_RESULT_ERROR(schema, builder.Finish());
Expand Down
5 changes: 5 additions & 0 deletions be/src/udf/python/python_udf_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ struct PythonUDFMeta {
Status check() const;

bool operator==(const PythonUDFMeta& other) const { return id == other.id; }

private:
std::shared_ptr<arrow::Field> create_field_with_doris_metadata(
const std::string& field_name, const std::shared_ptr<arrow::DataType>& arrow_type,
bool is_nullable, PrimitiveType primitive_type);
};

} // namespace doris
Expand Down
22 changes: 16 additions & 6 deletions be/test/core/data_type_serde/data_type_serde_arrow_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,16 @@ std::shared_ptr<Block> create_test_block(std::vector<PrimitiveType> cols, int ro
ColumnWithTypeAndName type_and_name(vec->get_ptr(), data_type, col_name);
block->insert(std::move(type_and_name));
} break;
case TYPE_LARGEINT: {
auto vec = ColumnInt128::create();
auto& data = vec->get_data();
for (int i = 0; i < row_num; ++i) {
data.push_back(__int128_t(i));
}
DataTypePtr data_type(std::make_shared<DataTypeInt128>());
ColumnWithTypeAndName type_and_name(vec->get_ptr(), data_type, col_name);
block->insert(std::move(type_and_name));
} break;
default:
LOG(FATAL) << "error column type";
}
Expand Down Expand Up @@ -425,9 +435,9 @@ void block_converter_test(std::vector<PrimitiveType> cols, int row_num, bool is_

TEST(DataTypeSerDeArrowTest, DataTypeScalaSerDeTest) {
std::vector<PrimitiveType> cols = {
TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN,
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_DATETIME,
TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2,
TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN,
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_LARGEINT,
TYPE_DATETIME, TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2,
};
serialize_and_deserialize_arrow_test(cols, 7, true);
serialize_and_deserialize_arrow_test(cols, 7, false);
Expand Down Expand Up @@ -506,9 +516,9 @@ TEST(DataTypeSerDeArrowTest, BigStringSerDeTest) {

TEST(DataTypeSerDeArrowTest, BlockConverterTest) {
std::vector<PrimitiveType> cols = {
TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN,
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_DATETIME,
TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2,
TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN,
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_LARGEINT,
TYPE_DATETIME, TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2,
};
block_converter_test(cols, 7, true);
block_converter_test(cols, 7, false);
Expand Down
Loading