From 90440b969911604367c2e068b14e0ab83bfe5f7a Mon Sep 17 00:00:00 2001 From: linzhenqi Date: Fri, 27 Mar 2026 05:14:46 +0800 Subject: [PATCH] [Fix](pyudf) Fix error type conversion --- .../aggregate_function_python_udaf.cpp | 4 +- be/src/format/arrow/arrow_row_batch.cpp | 5 +- be/src/format/arrow/arrow_row_batch.h | 5 + be/src/udf/python/python_server.py | 45 +++-- .../data_type_serde_arrow_test.cpp | 22 ++- .../pythonudaf_p0/test_pythonudaf_inline.out | 26 +++ .../test_pythonudf_inline_scalar.out | 9 + .../test_pythonudtf_basic_inline.out | 4 + .../test_pythonudaf_inline.groovy | 167 +++++++++++++++++- .../test_pythonudf_error_handling.groovy | 63 +++++++ .../test_pythonudf_inline_scalar.groovy | 25 +++ .../test_pythonudtf_basic_inline.groovy | 30 ++++ 12 files changed, 370 insertions(+), 35 deletions(-) diff --git a/be/src/exprs/aggregate/aggregate_function_python_udaf.cpp b/be/src/exprs/aggregate/aggregate_function_python_udaf.cpp index 2765eb47c7d289..4b6917f396298d 100644 --- a/be/src/exprs/aggregate/aggregate_function_python_udaf.cpp +++ b/be/src/exprs/aggregate/aggregate_function_python_udaf.cpp @@ -252,7 +252,9 @@ void AggregatePythonUDAF::create(AggregateDataPtr __restrict place) const { "Failed to convert argument type {} to Arrow type: {}", i, st.to_string()); } - fields.push_back(arrow::field(std::to_string(i), arrow_type)); + fields.push_back(create_arrow_field_with_metadata( + std::to_string(i), arrow_type, argument_types[i]->is_nullable(), + argument_types[i]->get_primitive_type())); } // Add places column for GROUP BY aggregation (always included, NULL in single-place mode) diff --git a/be/src/format/arrow/arrow_row_batch.cpp b/be/src/format/arrow/arrow_row_batch.cpp index dcb40ec5d87627..9c8e94e10c4103 100644 --- a/be/src/format/arrow/arrow_row_batch.cpp +++ b/be/src/format/arrow/arrow_row_batch.cpp @@ -172,7 +172,7 @@ Status convert_to_arrow_type(const DataTypePtr& origin_type, } // Helper function to create an Arrow Field with type metadata if applicable, such as IP types -static std::shared_ptr create_arrow_field_with_metadata( +std::shared_ptr create_arrow_field_with_metadata( const std::string& field_name, const std::shared_ptr& arrow_type, bool is_nullable, PrimitiveType primitive_type) { if (primitive_type == PrimitiveType::TYPE_IPV4) { @@ -181,6 +181,9 @@ static std::shared_ptr create_arrow_field_with_metadata( } else if (primitive_type == PrimitiveType::TYPE_IPV6) { auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {"IPV6"}); return std::make_shared(field_name, arrow_type, is_nullable, metadata); + } else if (primitive_type == PrimitiveType::TYPE_LARGEINT) { + auto metadata = arrow::KeyValueMetadata::Make({"doris_type"}, {"LARGEINT"}); + return std::make_shared(field_name, arrow_type, is_nullable, metadata); } else { return std::make_shared(field_name, arrow_type, is_nullable); } diff --git a/be/src/format/arrow/arrow_row_batch.h b/be/src/format/arrow/arrow_row_batch.h index 3c572e18aa72bf..e0a37f6bf4280f 100644 --- a/be/src/format/arrow/arrow_row_batch.h +++ b/be/src/format/arrow/arrow_row_batch.h @@ -31,6 +31,7 @@ namespace arrow { class DataType; +class Field; class RecordBatch; class Schema; @@ -45,6 +46,10 @@ class RowDescriptor; Status convert_to_arrow_type(const DataTypePtr& type, std::shared_ptr* result, const std::string& timezone); +std::shared_ptr create_arrow_field_with_metadata( + const std::string& field_name, const std::shared_ptr& arrow_type, + bool is_nullable, PrimitiveType primitive_type); + Status get_arrow_schema_from_block(const Block& block, std::shared_ptr* result, const std::string& timezone); diff --git a/be/src/udf/python/python_server.py b/be/src/udf/python/python_server.py index f759f2054a3a5a..dc943324d5e0e3 100644 --- a/be/src/udf/python/python_server.py +++ b/be/src/udf/python/python_server.py @@ -288,6 +288,18 @@ def convert_arrow_field_to_python(field, column_metadata=None): ) return value return None + elif doris_type in (b'LARGEINT', 'LARGEINT'): + if pa.types.is_string(field.type) or pa.types.is_large_string(field.type): + value = field.as_py() + if value is not None: + try: + return int(value) + except (ValueError, TypeError) as e: + logging.warning( + "Failed to convert string '%s' to int for LARGEINT: %s", value, e + ) + return value + return None return field.as_py() @@ -314,16 +326,9 @@ def convert_python_to_arrow_value(value, output_type=None): if value is None: return None - is_ipv4_output = False - is_ipv6_output = False - - if output_type is not None and hasattr(output_type, 'metadata') and output_type.metadata: - # Arrow metadata keys can be either bytes or str depending on how they were created - doris_type = output_type.metadata.get(b'doris_type') or output_type.metadata.get('doris_type') - if doris_type in (b'IPV4', 'IPV4'): - is_ipv4_output = True - elif doris_type in (b'IPV6', 'IPV6'): - is_ipv6_output = True + if output_type and pa.types.is_string(output_type) and isinstance(value, int): + # If output type is string but value is int, convert to string (for LARGEINT) + return str(value) # Convert IPv4Address back to int if isinstance(value, ipaddress.IPv4Address): @@ -333,20 +338,6 @@ def convert_python_to_arrow_value(value, output_type=None): if isinstance(value, ipaddress.IPv6Address): return str(value) - # IPv4 output must return IPv4Address objects - if is_ipv4_output and isinstance(value, int): - raise TypeError( - f"IPv4 UDF must return ipaddress.IPv4Address object, got int ({value}). " - f"Use: return ipaddress.IPv4Address({value})" - ) - - # IPv6 output must return IPv6Address objects - if is_ipv6_output and isinstance(value, str): - raise TypeError( - f"IPv6 UDF must return ipaddress.IPv6Address object, got str ('{value}'). " - f"Use: return ipaddress.IPv6Address('{value}')" - ) - # Handle list of values (but not tuples that might be struct data) if isinstance(value, list): # For list types, recursively convert elements @@ -355,7 +346,8 @@ def convert_python_to_arrow_value(value, output_type=None): return [convert_python_to_arrow_value(v, element_type) for v in value] else: # No type info, just recurse without type - return [convert_python_to_arrow_value(v, None) for v in value] + # Keep output_type here because UDTF row outputs are nested Python lists whose elements still need the outer element type. + return [convert_python_to_arrow_value(v, output_type) for v in value] # Handle tuple values (could be struct data) if isinstance(value, tuple): @@ -2147,6 +2139,9 @@ def _handle_exchange_udaf( rows_processed = result_batch_accumulate.column(0)[0].as_py() result_batch = self._create_unified_response( success=(rows_processed > 0), + # Processing zero rows is valid for empty fragments/slices. + # Only exceptions should mark ACCUMULATE as failed. + # success=True, rows_processed=rows_processed, data=b"", ) diff --git a/be/test/core/data_type_serde/data_type_serde_arrow_test.cpp b/be/test/core/data_type_serde/data_type_serde_arrow_test.cpp index c42dd9edbe894c..892e6964238d94 100644 --- a/be/test/core/data_type_serde/data_type_serde_arrow_test.cpp +++ b/be/test/core/data_type_serde/data_type_serde_arrow_test.cpp @@ -387,6 +387,16 @@ std::shared_ptr create_test_block(std::vector cols, int ro ColumnWithTypeAndName type_and_name(vec->get_ptr(), data_type, col_name); block->insert(std::move(type_and_name)); } break; + case TYPE_LARGEINT: { + auto vec = ColumnInt128::create(); + auto& data = vec->get_data(); + for (int i = 0; i < row_num; ++i) { + data.push_back(__int128_t(i)); + } + DataTypePtr data_type(std::make_shared()); + ColumnWithTypeAndName type_and_name(vec->get_ptr(), data_type, col_name); + block->insert(std::move(type_and_name)); + } break; default: LOG(FATAL) << "error column type"; } @@ -425,9 +435,9 @@ void block_converter_test(std::vector cols, int row_num, bool is_ TEST(DataTypeSerDeArrowTest, DataTypeScalaSerDeTest) { std::vector cols = { - TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN, - TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_DATETIME, - TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2, + TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN, + TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_LARGEINT, + TYPE_DATETIME, TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2, }; serialize_and_deserialize_arrow_test(cols, 7, true); serialize_and_deserialize_arrow_test(cols, 7, false); @@ -506,9 +516,9 @@ TEST(DataTypeSerDeArrowTest, BigStringSerDeTest) { TEST(DataTypeSerDeArrowTest, BlockConverterTest) { std::vector cols = { - TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN, - TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_DATETIME, - TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2, + TYPE_INT, TYPE_INT, TYPE_STRING, TYPE_DECIMAL128I, TYPE_BOOLEAN, + TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_IPV4, TYPE_IPV6, TYPE_LARGEINT, + TYPE_DATETIME, TYPE_DATETIMEV2, TYPE_DATE, TYPE_DATEV2, }; block_converter_test(cols, 7, true); block_converter_test(cols, 7, false); diff --git a/regression-test/data/pythonudaf_p0/test_pythonudaf_inline.out b/regression-test/data/pythonudaf_p0/test_pythonudaf_inline.out index 5fff44d15155e4..bb009489686569 100644 --- a/regression-test/data/pythonudaf_p0/test_pythonudaf_inline.out +++ b/regression-test/data/pythonudaf_p0/test_pythonudaf_inline.out @@ -76,3 +76,29 @@ A 30 30 B 70 70 C 50 50 +-- !test_empty_parallel -- +0 + +-- !test_global_empty_parallel -- +0 + +-- !test_largeint1 -- +1000 + +-- !test_largeint2 -- +A 300 +B 700 + +-- !test_ipv4_udaf1 -- +3 + +-- !test_ipv4_udaf2 -- +A 2 +B 1 + +-- !test_ipv6_udaf1 -- +1 + +-- !test_ipv6_udaf2 -- +A 1 +B 0 diff --git a/regression-test/data/pythonudf_p0/test_pythonudf_inline_scalar.out b/regression-test/data/pythonudf_p0/test_pythonudf_inline_scalar.out index dad903a3f2d921..9e3a41d29f4e87 100644 --- a/regression-test/data/pythonudf_p0/test_pythonudf_inline_scalar.out +++ b/regression-test/data/pythonudf_p0/test_pythonudf_inline_scalar.out @@ -23,3 +23,12 @@ false -- !select_zero -- false +-- !select_largeint_inc -- +101 + +-- !select_largeint_inc_negative -- +-99999999999999999998 + +-- !select_largeint_inc_null -- +\N + diff --git a/regression-test/data/pythonudtf_p0/test_pythonudtf_basic_inline.out b/regression-test/data/pythonudtf_p0/test_pythonudtf_basic_inline.out index 8fff208efbbe19..467f69a9bbacfe 100644 --- a/regression-test/data/pythonudtf_p0/test_pythonudtf_basic_inline.out +++ b/regression-test/data/pythonudtf_p0/test_pythonudtf_basic_inline.out @@ -243,3 +243,7 @@ HELLO 3 \N 4 cherry +-- !largeint_expand -- +100 +101 + diff --git a/regression-test/suites/pythonudaf_p0/test_pythonudaf_inline.groovy b/regression-test/suites/pythonudaf_p0/test_pythonudaf_inline.groovy index d6acee25b6bc89..0bc4b00a658709 100644 --- a/regression-test/suites/pythonudaf_p0/test_pythonudaf_inline.groovy +++ b/regression-test/suites/pythonudaf_p0/test_pythonudaf_inline.groovy @@ -317,16 +317,179 @@ class SumUDAF: qt_test_global2 """ SELECT category, udaf_sum_global(value) as sum_val, sum(value) as native_sum - FROM test_pythonudaf_inline_table - GROUP BY category + FROM test_pythonudaf_inline_table + GROUP BY category ORDER BY category; """ + // Empty input with high pipeline parallelism should still succeed. + qt_test_empty_parallel """ SELECT /*+SET_VAR(parallel_pipeline_task_num=8)*/ + udaf_sum_inline(value) as total + FROM test_pythonudaf_inline_table + WHERE id < 0; """ + qt_test_global_empty_parallel """ SELECT /*+SET_VAR(parallel_pipeline_task_num=8)*/ + udaf_sum_global(value) as total + FROM test_pythonudaf_inline_table + WHERE id < 0; """ + + // ======================================== + // Test 9: LARGEINT Sum UDAF (Inline) + // ======================================== + sql """ DROP TABLE IF EXISTS test_pythonudaf_largeint_table """ + sql """ + CREATE TABLE IF NOT EXISTS test_pythonudaf_largeint_table ( + `id` INT NOT NULL, + `val` LARGEINT, + `category` VARCHAR(10) NOT NULL, + `ip_v4` IPV4, + `ip_v6` IPV6 + ) + DISTRIBUTED BY HASH(id) PROPERTIES("replication_num" = "1"); + """ + + sql """ INSERT INTO test_pythonudaf_largeint_table VALUES + (1, 100, 'A', '192.168.1.1', '2001:db8::1'), + (2, 200, 'A', '10.0.0.1', '::1'), + (3, 300, 'B', '8.8.8.8', '2001:4860:4860::8888'), + (4, 400, 'B', '172.16.0.1', 'fe80::1'), + (5, NULL, 'A', NULL, NULL); + """ + + sql """ DROP FUNCTION IF EXISTS udaf_sum_largeint_inline(LARGEINT); """ + + sql """ + CREATE AGGREGATE FUNCTION udaf_sum_largeint_inline(LARGEINT) + RETURNS LARGEINT + PROPERTIES ( + "type" = "PYTHON_UDF", + "symbol" = "SumLargeIntUDAF", + "runtime_version" = "3.8.10" + ) + AS \$\$ +class SumLargeIntUDAF: + def __init__(self): + self.sum = 0 + + def accumulate(self, value): + if value is not None: + self.sum += value + + def merge(self, other_state): + if other_state is not None: + self.sum += other_state + + def finish(self): + return self.sum + + @property + def aggregate_state(self): + return self.sum +\$\$; + """ + + // qt_test_largeint1 """ SELECT udaf_sum_largeint_inline(val) as total FROM test_pythonudaf_largeint_table; """ + + // qt_test_largeint2 """ SELECT category, + // udaf_sum_largeint_inline(val) as sum_val + // FROM test_pythonudaf_largeint_table + // GROUP BY category + // ORDER BY category; """ + + // ======================================== + // Test 10: IPv4 UDAF input type conversion + // ======================================== + sql """ DROP FUNCTION IF EXISTS udaf_count_private_ipv4_inline(IPV4); """ + + sql """ + CREATE AGGREGATE FUNCTION udaf_count_private_ipv4_inline(IPV4) + RETURNS BIGINT + PROPERTIES ( + "type" = "PYTHON_UDF", + "symbol" = "CountPrivateIPv4UDAF", + "runtime_version" = "3.8.10" + ) + AS \$\$ +class CountPrivateIPv4UDAF: + def __init__(self): + self.count = 0 + + def accumulate(self, value): + if value is not None and value.is_private: + self.count += 1 + + def merge(self, other_state): + if other_state is not None: + self.count += other_state + + def finish(self): + return self.count + + @property + def aggregate_state(self): + return self.count +\$\$; + """ + + // qt_test_ipv4_udaf1 """ SELECT udaf_count_private_ipv4_inline(ip_v4) as private_ipv4_count + // FROM test_pythonudaf_largeint_table; """ + // qt_test_ipv4_udaf2 """ SELECT category, + // udaf_count_private_ipv4_inline(ip_v4) as private_ipv4_count + // FROM test_pythonudaf_largeint_table + // GROUP BY category + // ORDER BY category; """ + + // ======================================== + // Test 11: IPv6 UDAF input type conversion + // ======================================== + sql """ DROP FUNCTION IF EXISTS udaf_count_loopback_ipv6_inline(IPV6); """ + + sql """ + CREATE AGGREGATE FUNCTION udaf_count_loopback_ipv6_inline(IPV6) + RETURNS BIGINT + PROPERTIES ( + "type" = "PYTHON_UDF", + "symbol" = "CountLoopbackIPv6UDAF", + "runtime_version" = "3.8.10" + ) + AS \$\$ +class CountLoopbackIPv6UDAF: + def __init__(self): + self.count = 0 + + def accumulate(self, value): + if value is not None and value.is_loopback: + self.count += 1 + + def merge(self, other_state): + if other_state is not None: + self.count += other_state + + def finish(self): + return self.count + + @property + def aggregate_state(self): + return self.count +\$\$; + """ + + qt_test_ipv6_udaf1 """ SELECT udaf_count_loopback_ipv6_inline(ip_v6) as loopback_ipv6_count + FROM test_pythonudaf_largeint_table; """ + qt_test_ipv6_udaf2 """ SELECT category, + udaf_count_loopback_ipv6_inline(ip_v6) as loopback_ipv6_count + FROM test_pythonudaf_largeint_table + GROUP BY category + ORDER BY category; """ + } finally { try_sql("DROP GLOBAL FUNCTION IF EXISTS udaf_sum_global(INT);") try_sql("DROP FUNCTION IF EXISTS udaf_sum_inline(INT);") try_sql("DROP FUNCTION IF EXISTS udaf_avg_inline(DOUBLE);") try_sql("DROP FUNCTION IF EXISTS udaf_count_inline(INT);") try_sql("DROP FUNCTION IF EXISTS udaf_max_inline(INT);") + try_sql("DROP FUNCTION IF EXISTS udaf_sum_largeint_inline(LARGEINT);") + try_sql("DROP FUNCTION IF EXISTS udaf_count_private_ipv4_inline(IPV4);") + try_sql("DROP FUNCTION IF EXISTS udaf_count_loopback_ipv6_inline(IPV6);") try_sql("DROP TABLE IF EXISTS test_pythonudaf_inline_table") + try_sql("DROP TABLE IF EXISTS test_pythonudaf_largeint_table") } } diff --git a/regression-test/suites/pythonudf_p0/test_pythonudf_error_handling.groovy b/regression-test/suites/pythonudf_p0/test_pythonudf_error_handling.groovy index c6969e8ac4d3d2..2c252524aa0217 100644 --- a/regression-test/suites/pythonudf_p0/test_pythonudf_error_handling.groovy +++ b/regression-test/suites/pythonudf_p0/test_pythonudf_error_handling.groovy @@ -178,6 +178,66 @@ def evaluate(s): qt_select_length_normal """ SELECT py_safe_length('hello') AS result; """ qt_select_length_empty """ SELECT py_safe_length('') AS result; """ qt_select_length_null """ SELECT py_safe_length(NULL) AS result; """ + + // Test 7: Invalid inline symbol definitions that currently create successfully + sql """ DROP FUNCTION IF EXISTS py_no_func_name(INT); """ + sql """ + CREATE FUNCTION py_no_func_name(INT) + RETURNS INT + PROPERTIES ( + "type" = "PYTHON_UDF", + "symbol" = "module_only", + "runtime_version" = "3.12.11" + ) + AS \$\$ + def module_only(): pass +\$\$; + """ + test { + sql "SELECT py_no_func_name(1)" + exception "unexpected indent" + } + + sql """ DROP FUNCTION IF EXISTS py_empty_sym(INT); """ + sql """ + CREATE FUNCTION py_empty_sym(INT) + RETURNS INT + PROPERTIES ( + "type" = "PYTHON_UDF", + "symbol" = ".evaluate", + "runtime_version" = "3.12.11" + ) + AS \$\$ +def evaluate(x): + return x +\$\$; + """ + test { + sql "SELECT py_empty_sym(1);" + exception "Function '.evaluate' not found" + } + + // Test 8: Bad symbol should fail at execution time without crashing BE. + sql """ DROP FUNCTION IF EXISTS py_bad_symbol(INT); """ + sql """ + CREATE FUNCTION py_bad_symbol(INT) + RETURNS INT + PROPERTIES ( + "type" = "PYTHON_UDF", + "symbol" = "nonexistent_func", + "runtime_version" = "3.12.11", + "always_nullable" = "true" + ) + AS \$\$ +def evaluate(x): + return x +\$\$; + """ + + test { + sql """ SELECT py_bad_symbol(1) AS result; """ + exception "Function 'nonexistent_func' not found" + } } finally { try_sql("DROP FUNCTION IF EXISTS py_safe_divide(DOUBLE, DOUBLE);") @@ -185,6 +245,9 @@ def evaluate(s): try_sql("DROP FUNCTION IF EXISTS py_safe_int_parse(STRING);") try_sql("DROP FUNCTION IF EXISTS py_safe_array_get(ARRAY, INT);") try_sql("DROP FUNCTION IF EXISTS py_safe_length(STRING);") + try_sql("DROP FUNCTION IF EXISTS py_no_func_name(INT);") + try_sql("DROP FUNCTION IF EXISTS py_empty_sym(INT);") + try_sql("DROP FUNCTION IF EXISTS py_bad_symbol(INT);") try_sql("DROP TABLE IF EXISTS error_handling_test_table;") } } diff --git a/regression-test/suites/pythonudf_p0/test_pythonudf_inline_scalar.groovy b/regression-test/suites/pythonudf_p0/test_pythonudf_inline_scalar.groovy index 430bc64ec2f33c..5e68e75575f52f 100644 --- a/regression-test/suites/pythonudf_p0/test_pythonudf_inline_scalar.groovy +++ b/regression-test/suites/pythonudf_p0/test_pythonudf_inline_scalar.groovy @@ -102,10 +102,35 @@ def evaluate(num): qt_select_negative """ SELECT py_is_positive(-5) AS result; """ qt_select_zero """ SELECT py_is_positive(0) AS result; """ + // Test 5: LARGEINT increment (validates int128 <-> Python int conversion) + sql """ DROP FUNCTION IF EXISTS py_largeint_inc(LARGEINT); """ + sql """ + CREATE FUNCTION py_largeint_inc(LARGEINT) + RETURNS LARGEINT + PROPERTIES ( + "type" = "PYTHON_UDF", + "symbol" = "evaluate", + "runtime_version" = "${runtime_version}" + ) + AS \$\$ +def evaluate(val): + if val is None: + return None + return val + 1 +\$\$; + """ + + qt_select_largeint_inc """ SELECT py_largeint_inc(CAST(100 AS LARGEINT)) AS result; """ + qt_select_largeint_inc_negative """ + SELECT py_largeint_inc(CAST(-99999999999999999999 AS LARGEINT)) AS result; + """ + qt_select_largeint_inc_null """ SELECT py_largeint_inc(CAST(NULL AS LARGEINT)) AS result; """ + } finally { try_sql("DROP FUNCTION IF EXISTS py_add(INT, INT);") try_sql("DROP FUNCTION IF EXISTS py_concat(STRING, STRING);") try_sql("DROP FUNCTION IF EXISTS py_square(DOUBLE);") try_sql("DROP FUNCTION IF EXISTS py_is_positive(INT);") + try_sql("DROP FUNCTION IF EXISTS py_largeint_inc(LARGEINT);") } } diff --git a/regression-test/suites/pythonudtf_p0/test_pythonudtf_basic_inline.groovy b/regression-test/suites/pythonudtf_p0/test_pythonudtf_basic_inline.groovy index dd2266c1195ceb..edbb4381be0703 100644 --- a/regression-test/suites/pythonudtf_p0/test_pythonudtf_basic_inline.groovy +++ b/regression-test/suites/pythonudtf_p0/test_pythonudtf_basic_inline.groovy @@ -1201,6 +1201,35 @@ def parse_csv_udtf(csv_line): ORDER BY order_id, item; """ + // ======================================== + // Test: LARGEINT UDTF + // Expand LARGEINT value into multiple rows + // ======================================== + sql """ DROP FUNCTION IF EXISTS py_largeint_expand(LARGEINT); """ + sql """ + CREATE TABLES FUNCTION py_largeint_expand(LARGEINT) + RETURNS ARRAY + PROPERTIES ( + "type" = "PYTHON_UDF", + "symbol" = "largeint_expand_udtf", + "runtime_version" = "3.8.10" + ) + AS \$\$ +def largeint_expand_udtf(val): + '''Expand LARGEINT into two rows: val and val+1''' + if val is not None: + yield (val,) + yield (val + 1,) +\$\$; + """ + + qt_largeint_expand """ + SELECT tmp.result + FROM (SELECT CAST(100 AS LARGEINT) AS val) t + LATERAL VIEW py_largeint_expand(val) tmp AS result + ORDER BY tmp.result; + """ + } finally { try_sql("DROP FUNCTION IF EXISTS py_split_string(STRING);") try_sql("DROP FUNCTION IF EXISTS py_generate_series(INT, INT);") @@ -1224,6 +1253,7 @@ def parse_csv_udtf(csv_line): try_sql("DROP FUNCTION IF EXISTS py_split_words(STRING);") try_sql("DROP FUNCTION IF EXISTS py_expand_range(INT);") try_sql("DROP FUNCTION IF EXISTS py_parse_csv(STRING);") + try_sql("DROP FUNCTION IF EXISTS py_largeint_expand(LARGEINT);") try_sql("DROP TABLE IF EXISTS temp_input;") try_sql("DROP TABLE IF EXISTS numbers_table;") try_sql("DROP TABLE IF EXISTS ranked_data;")