diff --git a/be/src/exprs/function/function_regexp.cpp b/be/src/exprs/function/function_regexp.cpp index 8a9871c8eb3d6c..cfce51063ef4e3 100644 --- a/be/src/exprs/function/function_regexp.cpp +++ b/be/src/exprs/function/function_regexp.cpp @@ -33,11 +33,13 @@ #include "core/block/column_numbers.h" #include "core/block/column_with_type_and_name.h" #include "core/column/column.h" +#include "core/column/column_array.h" #include "core/column/column_const.h" #include "core/column/column_nullable.h" #include "core/column/column_string.h" #include "core/column/column_vector.h" #include "core/data_type/data_type.h" +#include "core/data_type/data_type_array.h" #include "core/data_type/data_type_nullable.h" #include "core/data_type/data_type_number.h" #include "core/data_type/data_type_string.h" @@ -410,8 +412,10 @@ class FunctionRegexpReplace : public IFunction { } }; +template struct RegexpReplaceImpl { - static constexpr auto name = "regexp_replace"; + static constexpr auto name = ReplaceOne ? "regexp_replace_one" : "regexp_replace"; + static void execute_impl(FunctionContext* context, ColumnPtr argument_columns[], const StringRef& options_value, size_t input_rows_count, ColumnString::Chars& result_data, ColumnString::Offsets& result_offset, @@ -421,14 +425,11 @@ struct RegexpReplaceImpl { const auto* replace_col = check_and_get_column(argument_columns[2].get()); for (size_t i = 0; i < input_rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, result_data, result_offset, null_map); - continue; - } _execute_inner_loop(context, str_col, pattern_col, replace_col, options_value, result_data, result_offset, null_map, i); } } + static void execute_impl_const_args(FunctionContext* context, ColumnPtr argument_columns[], const StringRef& options_value, size_t input_rows_count, ColumnString::Chars& result_data, @@ -438,14 +439,11 @@ struct RegexpReplaceImpl { const auto* replace_col = check_and_get_column(argument_columns[2].get()); for (size_t i = 0; i < 
input_rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, result_data, result_offset, null_map); - continue; - } _execute_inner_loop(context, str_col, pattern_col, replace_col, options_value, result_data, result_offset, null_map, i); } } + template static void _execute_inner_loop(FunctionContext* context, const ColumnString* str_col, const ColumnString* pattern_col, @@ -473,133 +471,93 @@ struct RegexpReplaceImpl { replace_col->get_data_at(index_check_const(index_now, Const)).to_string_view()); std::string result_str(str_col->get_data_at(index_now).to_string()); - re2::RE2::GlobalReplace(&result_str, *re, replace_str); + if constexpr (ReplaceOne) { + re2::RE2::Replace(&result_str, *re, replace_str); + } else { + re2::RE2::GlobalReplace(&result_str, *re, replace_str); + } StringOP::push_value_string(result_str, index_now, result_data, result_offset); } }; -struct RegexpReplaceOneImpl { - static constexpr auto name = "regexp_replace_one"; +template +struct RegexpExtractImpl { + static constexpr auto name = ReturnNull ? 
"regexp_extract_or_null" : "regexp_extract"; + static constexpr size_t num_args = 3; + static constexpr size_t PATTERN_ARG_IDX = 1; - static void execute_impl(FunctionContext* context, ColumnPtr argument_columns[], - const StringRef& options_value, size_t input_rows_count, - ColumnString::Chars& result_data, ColumnString::Offsets& result_offset, - NullMap& null_map) { - const auto* str_col = check_and_get_column(argument_columns[0].get()); - const auto* pattern_col = check_and_get_column(argument_columns[1].get()); - const auto* replace_col = check_and_get_column(argument_columns[2].get()); - // 3 args - for (size_t i = 0; i < input_rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, result_data, result_offset, null_map); - continue; - } - _execute_inner_loop(context, str_col, pattern_col, replace_col, options_value, - result_data, result_offset, null_map, i); - } - } + static DataTypePtr return_type() { return make_nullable(std::make_shared()); } - static void execute_impl_const_args(FunctionContext* context, ColumnPtr argument_columns[], - const StringRef& options_value, size_t input_rows_count, - ColumnString::Chars& result_data, - ColumnString::Offsets& result_offset, NullMap& null_map) { - const auto* str_col = check_and_get_column(argument_columns[0].get()); - const auto* pattern_col = check_and_get_column(argument_columns[1].get()); - const auto* replace_col = check_and_get_column(argument_columns[2].get()); - // 3 args - for (size_t i = 0; i < input_rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, result_data, result_offset, null_map); - continue; - } - _execute_inner_loop(context, str_col, pattern_col, replace_col, options_value, - result_data, result_offset, null_map, i); - } - } - template - static void _execute_inner_loop(FunctionContext* context, const ColumnString* str_col, - const ColumnString* pattern_col, - const ColumnString* replace_col, const StringRef& options_value, - ColumnString::Chars& result_data, 
- ColumnString::Offsets& result_offset, NullMap& null_map, - const size_t index_now) { - re2::RE2* re = reinterpret_cast( - context->get_function_state(FunctionContext::THREAD_LOCAL)); - std::unique_ptr scoped_re; // destroys re if state->re is nullptr - if (re == nullptr) { - std::string error_str; - const auto& pattern = pattern_col->get_data_at(index_check_const(index_now, Const)); - bool st = StringFunctions::compile_regex(pattern, &error_str, StringRef(), - options_value, scoped_re); - if (!st) { - context->add_warning(error_str.c_str()); - StringOP::push_null_string(index_now, result_data, result_offset, null_map); - return; - } - re = scoped_re.get(); + static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) { + bool col_const[3]; + ColumnPtr argument_columns[3]; + for (int i = 0; i < 3; ++i) { + col_const[i] = is_column_const(*block.get_by_position(arguments[i]).column); } + argument_columns[0] = col_const[0] ? static_cast( + *block.get_by_position(arguments[0]).column) + .convert_to_full_column() + : block.get_by_position(arguments[0]).column; - re2::StringPiece replace_str = re2::StringPiece( - replace_col->get_data_at(index_check_const(index_now, Const)).to_string_view()); + auto result_null_map = ColumnUInt8::create(input_rows_count, 0); + auto result_data_column = ColumnString::create(); + auto& result_data = result_data_column->get_chars(); + auto& result_offset = result_data_column->get_offsets(); + result_offset.resize(input_rows_count); + auto& null_map = result_null_map->get_data(); - std::string result_str(str_col->get_data_at(index_now).to_string()); - re2::RE2::Replace(&result_str, *re, replace_str); - StringOP::push_value_string(result_str, index_now, result_data, result_offset); - } -}; + default_preprocess_parameter_columns(argument_columns, col_const, {1, 2}, block, arguments); -template -struct RegexpExtractImpl { - static constexpr auto name = ReturnNull ? 
"regexp_extract_or_null" : "regexp_extract"; - // 3 args - static void execute_impl(FunctionContext* context, ColumnPtr argument_columns[], - size_t input_rows_count, ColumnString::Chars& result_data, - ColumnString::Offsets& result_offset, NullMap& null_map) { - const auto* str_col = check_and_get_column(argument_columns[0].get()); - const auto* pattern_col = check_and_get_column(argument_columns[1].get()); - const auto* index_col = check_and_get_column(argument_columns[2].get()); - for (size_t i = 0; i < input_rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, result_data, result_offset, null_map); - continue; - } - const auto& index_data = index_col->get_int(i); - if (index_data < 0) { - ReturnNull ? StringOP::push_null_string(i, result_data, result_offset, null_map) - : StringOP::push_empty_string(i, result_data, result_offset); - continue; - } - _execute_inner_loop(context, str_col, pattern_col, index_data, result_data, - result_offset, null_map, i); + if (col_const[1] && col_const[2]) { + _execute_loop(context, argument_columns, input_rows_count, result_data, + result_offset, null_map); + } else { + _execute_loop(context, argument_columns, input_rows_count, result_data, + result_offset, null_map); } + + block.get_by_position(result).column = + ColumnNullable::create(std::move(result_data_column), std::move(result_null_map)); + return Status::OK(); } - static void execute_impl_const_args(FunctionContext* context, ColumnPtr argument_columns[], - size_t input_rows_count, ColumnString::Chars& result_data, - ColumnString::Offsets& result_offset, NullMap& null_map) { +private: + template + static void _execute_loop(FunctionContext* context, ColumnPtr argument_columns[], + size_t input_rows_count, ColumnString::Chars& result_data, + ColumnString::Offsets& result_offset, NullMap& null_map) { const auto* str_col = check_and_get_column(argument_columns[0].get()); const auto* pattern_col = check_and_get_column(argument_columns[1].get()); const auto* 
index_col = check_and_get_column(argument_columns[2].get()); - const auto& index_data = index_col->get_int(0); - if (index_data < 0) { + if constexpr (Const) { + const auto& index_data = index_col->get_int(0); + if (index_data < 0) { + for (size_t i = 0; i < input_rows_count; ++i) { + ReturnNull ? StringOP::push_null_string(i, result_data, result_offset, null_map) + : StringOP::push_empty_string(i, result_data, result_offset); + } + return; + } for (size_t i = 0; i < input_rows_count; ++i) { - ReturnNull ? StringOP::push_null_string(i, result_data, result_offset, null_map) - : StringOP::push_empty_string(i, result_data, result_offset); + _execute_inner_loop(context, str_col, pattern_col, index_data, result_data, + result_offset, null_map, i); } - return; - } - - for (size_t i = 0; i < input_rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, result_data, result_offset, null_map); - continue; + } else { + for (size_t i = 0; i < input_rows_count; ++i) { + const auto& index_data = index_col->get_int(i); + if (index_data < 0) { + ReturnNull ? 
StringOP::push_null_string(i, result_data, result_offset, null_map) + : StringOP::push_empty_string(i, result_data, result_offset); + continue; + } + _execute_inner_loop(context, str_col, pattern_col, index_data, result_data, + result_offset, null_map, i); } - - _execute_inner_loop(context, str_col, pattern_col, index_data, result_data, - result_offset, null_map, i); } } + template static void _execute_inner_loop(FunctionContext* context, const ColumnString* str_col, const ColumnString* pattern_col, const Int64 index_data, @@ -648,66 +606,183 @@ struct RegexpExtractImpl { } }; -struct RegexpExtractAllImpl { - static constexpr auto name = "regexp_extract_all"; +// Output handler for existing string-formatted result: "['a','b']" +struct RegexpExtractAllStringOutput { + static constexpr const char* func_name = "regexp_extract_all"; + static DataTypePtr return_type() { return make_nullable(std::make_shared()); } - size_t get_number_of_arguments() const { return 2; } + ColumnString::Chars& result_data; + ColumnString::Offsets& result_offset; - static void execute_impl(FunctionContext* context, ColumnPtr argument_columns[], - size_t input_rows_count, ColumnString::Chars& result_data, - ColumnString::Offsets& result_offset, NullMap& null_map) { - const auto* str_col = check_and_get_column(argument_columns[0].get()); - const auto* pattern_col = check_and_get_column(argument_columns[1].get()); - for (int i = 0; i < input_rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, result_data, result_offset, null_map); - continue; + void push_empty(size_t index) { + StringOP::push_empty_string(index, result_data, result_offset); + } + void push_null(size_t index, NullMap& null_map) { + StringOP::push_null_string(index, result_data, result_offset, null_map); + } + void push_matches(size_t index, const std::vector& matches) { + size_t total_size = 2; // '[' and ']' + for (const auto& m : matches) { + total_size += m.size() + 3; // "'xxx'," + } + + size_t old_size = 
result_data.size(); + result_data.resize(old_size + total_size); + char* pos = reinterpret_cast(&result_data[old_size]); + + *pos++ = '['; + for (size_t j = 0; j < matches.size(); ++j) { + if (j > 0) { + *pos++ = ','; } - _execute_inner_loop(context, str_col, pattern_col, result_data, result_offset, - null_map, i); + *pos++ = '\''; + memcpy(pos, matches[j].data(), matches[j].size()); + pos += matches[j].size(); + *pos++ = '\''; } + *pos++ = ']'; + + result_data.resize(old_size + static_cast(pos - reinterpret_cast( + &result_data[old_size]))); + result_offset[index] = static_cast(result_data.size()); } - static void execute_impl_const_args(FunctionContext* context, ColumnPtr argument_columns[], - size_t input_rows_count, ColumnString::Chars& result_data, - ColumnString::Offsets& result_offset, NullMap& null_map) { + struct State { + ColumnString::MutablePtr data_column; + explicit State(size_t rows) : data_column(ColumnString::create()) { + data_column->get_offsets().resize(rows); + } + RegexpExtractAllStringOutput create_handler() { + return {.result_data = data_column->get_chars(), + .result_offset = data_column->get_offsets()}; + } + ColumnPtr finalize(ColumnUInt8::MutablePtr null_map) { + return ColumnNullable::create(std::move(data_column), std::move(null_map)); + } + }; +}; + +// Output handler for proper Array> result +struct RegexpExtractAllArrayOutput { + static constexpr const char* func_name = "regexp_extract_all_array"; + static DataTypePtr return_type() { + return make_nullable( + std::make_shared(make_nullable(std::make_shared()))); + } + + ColumnString& nested_col; + ColumnArray::Offsets64& array_offsets; + NullMap& nested_null_map; + UInt64 current_offset = 0; + + void push_empty(size_t index) { array_offsets.push_back(current_offset); } + void push_null(size_t index, NullMap& null_map) { + null_map[index] = 1; + array_offsets.push_back(current_offset); + } + void push_matches(size_t index, const std::vector& matches) { + for (const auto& m : 
matches) { + nested_col.insert_data(m.data(), m.size()); + nested_null_map.push_back(0); + current_offset++; + } + array_offsets.push_back(current_offset); + } + + struct State { + ColumnArray::MutablePtr array_column; + ColumnNullable* nested_nullable; + explicit State(size_t /*rows*/) { + auto nullable_str = make_nullable(std::make_shared()); + array_column = ColumnArray::create(nullable_str->create_column(), + ColumnArray::ColumnOffsets::create()); + nested_nullable = assert_cast(&array_column->get_data()); + } + RegexpExtractAllArrayOutput create_handler() { + return {.nested_col = assert_cast(nested_nullable->get_nested_column()), + .array_offsets = array_column->get_offsets(), + .nested_null_map = nested_nullable->get_null_map_data()}; + } + ColumnPtr finalize(ColumnUInt8::MutablePtr null_map) { + return ColumnNullable::create(std::move(array_column), std::move(null_map)); + } + }; +}; + +// Handler controls return type & column layout +template +struct RegexpExtractAllImpl { + static constexpr auto name = Handler::func_name; + static constexpr size_t num_args = 2; + static constexpr size_t PATTERN_ARG_IDX = 1; + + static DataTypePtr return_type() { return Handler::return_type(); } + + static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) { + bool col_const[2]; + ColumnPtr argument_columns[2]; + for (int i = 0; i < 2; ++i) { + col_const[i] = is_column_const(*block.get_by_position(arguments[i]).column); + } + argument_columns[0] = col_const[0] ? 
static_cast( + *block.get_by_position(arguments[0]).column) + .convert_to_full_column() + : block.get_by_position(arguments[0]).column; + + default_preprocess_parameter_columns(argument_columns, col_const, {1}, block, arguments); + const auto* str_col = check_and_get_column(argument_columns[0].get()); const auto* pattern_col = check_and_get_column(argument_columns[1].get()); - for (int i = 0; i < input_rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, result_data, result_offset, null_map); - continue; - } - _execute_inner_loop(context, str_col, pattern_col, result_data, result_offset, - null_map, i); - } + + auto outer_null_map = ColumnUInt8::create(input_rows_count, 0); + auto& null_map_data = outer_null_map->get_data(); + + typename Handler::State state(input_rows_count); + auto handler = state.create_handler(); + + std::visit( + [&](auto is_const) { + for (size_t i = 0; i < input_rows_count; ++i) { + if (null_map_data[i]) { + handler.push_null(i, null_map_data); + continue; + } + regexp_extract_all_inner_loop(context, str_col, pattern_col, + handler, null_map_data, i); + } + }, + make_bool_variant(col_const[1])); + + block.get_by_position(result).column = state.finalize(std::move(outer_null_map)); + return Status::OK(); } - template - static void _execute_inner_loop(FunctionContext* context, const ColumnString* str_col, - const ColumnString* pattern_col, - ColumnString::Chars& result_data, - ColumnString::Offsets& result_offset, NullMap& null_map, - const size_t index_now) { + +private: + template + static void regexp_extract_all_inner_loop(FunctionContext* context, const ColumnString* str_col, + const ColumnString* pattern_col, Handler& handler, + NullMap& null_map, const size_t index_now) { auto* engine = reinterpret_cast( context->get_function_state(FunctionContext::THREAD_LOCAL)); std::unique_ptr scoped_engine; if (engine == nullptr) { std::string error_str; - const auto& pattern = pattern_col->get_data_at(index_check_const(index_now, 
Const)); + const auto& pattern = pattern_col->get_data_at(index_check_const(index_now, is_const)); scoped_engine = std::make_unique(); bool st = RegexpExtractEngine::compile(pattern, &error_str, *scoped_engine, context->state()->enable_extended_regex()); if (!st) { context->add_warning(error_str.c_str()); - StringOP::push_null_string(index_now, result_data, result_offset, null_map); + handler.push_null(index_now, null_map); return; } engine = scoped_engine.get(); } if (engine->number_of_capturing_groups() == 0) { - StringOP::push_empty_string(index_now, result_data, result_offset); + handler.push_empty(index_now); return; } const auto& str = str_col->get_data_at(index_now); @@ -715,19 +790,10 @@ struct RegexpExtractAllImpl { engine->match_all_and_extract(str.data, str.size, res_matches); if (res_matches.empty()) { - StringOP::push_empty_string(index_now, result_data, result_offset); + handler.push_empty(index_now); return; } - - std::string res = "["; - for (int j = 0; j < res_matches.size(); ++j) { - res += "'" + res_matches[j] + "'"; - if (j < res_matches.size() - 1) { - res += ","; - } - } - res += "]"; - StringOP::push_value_string(std::string_view(res), index_now, result_data, result_offset); + handler.push_matches(index_now, res_matches); } }; @@ -741,22 +807,18 @@ class FunctionRegexpFunctionality : public IFunction { String get_name() const override { return name; } - size_t get_number_of_arguments() const override { - if constexpr (std::is_same_v) { - return 2; - } - return 3; - } + size_t get_number_of_arguments() const override { return Impl::num_args; } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - return make_nullable(std::make_shared()); + return Impl::return_type(); } Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override { if (scope == FunctionContext::THREAD_LOCAL) { - if (context->is_col_constant(1)) { + if (context->is_col_constant(Impl::PATTERN_ARG_IDX)) { 
DCHECK(!context->get_function_state(scope)); - const auto pattern_col = context->get_constant_col(1)->column_ptr; + const auto pattern_col = + context->get_constant_col(Impl::PATTERN_ARG_IDX)->column_ptr; const auto& pattern = pattern_col->get_data_at(0); if (pattern.size == 0) { return Status::OK(); @@ -766,6 +828,7 @@ class FunctionRegexpFunctionality : public IFunction { auto engine = std::make_shared(); bool st = RegexpExtractEngine::compile(pattern, &error_str, *engine, context->state()->enable_extended_regex()); + if (!st) { context->set_error(error_str.c_str()); return Status::InvalidArgument(error_str); @@ -778,65 +841,21 @@ class FunctionRegexpFunctionality : public IFunction { Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, uint32_t result, size_t input_rows_count) const override { - size_t argument_size = arguments.size(); - - auto result_null_map = ColumnUInt8::create(input_rows_count, 0); - auto result_data_column = ColumnString::create(); - auto& result_data = result_data_column->get_chars(); - auto& result_offset = result_data_column->get_offsets(); - result_offset.resize(input_rows_count); - - bool col_const[3]; - ColumnPtr argument_columns[3]; - for (int i = 0; i < argument_size; ++i) { - col_const[i] = is_column_const(*block.get_by_position(arguments[i]).column); - } - argument_columns[0] = col_const[0] ? 
static_cast( - *block.get_by_position(arguments[0]).column) - .convert_to_full_column() - : block.get_by_position(arguments[0]).column; - if constexpr (std::is_same_v) { - default_preprocess_parameter_columns(argument_columns, col_const, {1}, block, - arguments); - } else { - default_preprocess_parameter_columns(argument_columns, col_const, {1, 2}, block, - arguments); - } - - if constexpr (std::is_same_v) { - if (col_const[1]) { - Impl::execute_impl_const_args(context, argument_columns, input_rows_count, - result_data, result_offset, - result_null_map->get_data()); - } else { - Impl::execute_impl(context, argument_columns, input_rows_count, result_data, - result_offset, result_null_map->get_data()); - } - } else { - if (col_const[1] && col_const[2]) { - Impl::execute_impl_const_args(context, argument_columns, input_rows_count, - result_data, result_offset, - result_null_map->get_data()); - } else { - Impl::execute_impl(context, argument_columns, input_rows_count, result_data, - result_offset, result_null_map->get_data()); - } - } - - block.get_by_position(result).column = - ColumnNullable::create(std::move(result_data_column), std::move(result_null_map)); - return Status::OK(); + return Impl::execute(context, block, arguments, result, input_rows_count); } }; void register_function_regexp_extract(SimpleFunctionFactory& factory) { - factory.register_function>(); - factory.register_function>(); - factory.register_function>(); - factory.register_function>(); + factory.register_function, ThreeParamTypes>>(); + factory.register_function, FourParamTypes>>(); + factory.register_function, ThreeParamTypes>>(); + factory.register_function, FourParamTypes>>(); factory.register_function>>(); factory.register_function>>(); - factory.register_function>(); + factory.register_function< + FunctionRegexpFunctionality>>(); + factory.register_function< + FunctionRegexpFunctionality>>(); factory.register_function(); } diff --git a/be/test/exprs/function/function_like_test.cpp 
b/be/test/exprs/function/function_like_test.cpp index 881886831ec43b..82618a790e99e7 100644 --- a/be/test/exprs/function/function_like_test.cpp +++ b/be/test/exprs/function/function_like_test.cpp @@ -20,6 +20,7 @@ #include "core/column/column_string.h" #include "core/column/column_vector.h" +#include "core/data_type/data_type_array.h" #include "core/data_type/data_type_nullable.h" #include "core/data_type/data_type_number.h" #include "core/data_type/data_type_string.h" @@ -248,6 +249,156 @@ TEST(FunctionLikeTest, regexp_extract_all) { } } +TEST(FunctionLikeTest, regexp_extract_all_array) { + std::string func_name = "regexp_extract_all_array"; + auto str_type = std::make_shared(); + auto return_type = make_nullable( + std::make_shared(make_nullable(std::make_shared()))); + + auto run_case = [&](const std::string& str, const std::string& pattern, + const std::string& expected, bool expect_null = false) { + auto col_str = ColumnString::create(); + col_str->insert_data(str.data(), str.size()); + auto col_pattern = ColumnString::create(); + col_pattern->insert_data(pattern.data(), pattern.size()); + + Block block; + block.insert({std::move(col_str), str_type, "str"}); + block.insert({ColumnConst::create(std::move(col_pattern), 1), str_type, "pattern"}); + block.insert({nullptr, return_type, "result"}); + + ColumnsWithTypeAndName arg_cols = {block.get_by_position(0), block.get_by_position(1)}; + auto func = + SimpleFunctionFactory::instance().get_function(func_name, arg_cols, return_type); + ASSERT_TRUE(func != nullptr); + + std::vector arg_types = {str_type, str_type}; + FunctionUtils fn_utils({}, arg_types, false); + auto* fn_ctx = fn_utils.get_fn_ctx(); + fn_ctx->set_constant_cols( + {nullptr, std::make_shared(block.get_by_position(1).column)}); + + ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL)); + ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::THREAD_LOCAL)); + ASSERT_EQ(Status::OK(), func->execute(fn_ctx, block, {0, 1}, 2, 
1)); + + auto result_col = block.get_by_position(2).column; + ASSERT_TRUE(result_col.get() != nullptr); + if (expect_null) { + EXPECT_TRUE(result_col->is_null_at(0)); + } else { + ASSERT_FALSE(result_col->is_null_at(0)); + auto result_str = return_type->to_string(*result_col, 0); + EXPECT_EQ(expected, result_str) + << "input: '" << str << "', pattern: '" << pattern << "'"; + } + + static_cast(func->close(fn_ctx, FunctionContext::THREAD_LOCAL)); + static_cast(func->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL)); + }; + + run_case("x=a3&x=18abc&x=2&y=3&x=4&x=17bcd", "x=([0-9]+)([a-z]+)", "[\"18\", \"17\"]"); + run_case("x=a3&x=18abc&x=2&y=3&x=4", "^x=([a-z]+)([0-9]+)", "[\"a\"]"); + run_case("http://a.m.baidu.com/i41915173660.htm", "i([0-9]+)", "[\"41915173660\"]"); + run_case("http://a.m.baidu.com/i41915i73660.htm", "i([0-9]+)", "[\"41915\", \"73660\"]"); + run_case("hitdecisiondlist", "(i)(.*?)(e)", "[\"i\"]"); + run_case("no_match_here", "x=([0-9]+)", "[]"); + run_case("abc", "([a-z]+)", "[\"abc\"]"); + + // Helper for testing null input propagation + auto nullable_str_type = make_nullable(str_type); + auto run_null_case = [&](bool null_str, bool null_pattern) { + ColumnPtr col_str; + DataTypePtr str_col_type; + if (null_str) { + auto col = ColumnNullable::create(ColumnString::create(), ColumnUInt8::create()); + col->insert_default(); + col_str = std::move(col); + str_col_type = nullable_str_type; + } else { + auto col = ColumnString::create(); + col->insert_data("abc", 3); + col_str = std::move(col); + str_col_type = str_type; + } + + ColumnPtr col_pattern; + DataTypePtr pattern_col_type; + if (null_pattern) { + auto col = ColumnNullable::create(ColumnString::create(), ColumnUInt8::create()); + col->insert_default(); + col_pattern = ColumnConst::create(std::move(col), 1); + pattern_col_type = nullable_str_type; + } else { + auto col = ColumnString::create(); + col->insert_data("([a-z]+)", 8); + col_pattern = ColumnConst::create(std::move(col), 1); + 
pattern_col_type = str_type; + } + + Block block; + block.insert({col_str, str_col_type, "str"}); + block.insert({col_pattern, pattern_col_type, "pattern"}); + block.insert({nullptr, return_type, "result"}); + + ColumnsWithTypeAndName arg_cols = {block.get_by_position(0), block.get_by_position(1)}; + auto func = + SimpleFunctionFactory::instance().get_function(func_name, arg_cols, return_type); + ASSERT_TRUE(func != nullptr); + + std::vector arg_types = {str_col_type, pattern_col_type}; + FunctionUtils fn_utils({}, arg_types, false); + auto* fn_ctx = fn_utils.get_fn_ctx(); + fn_ctx->set_constant_cols( + {nullptr, std::make_shared(block.get_by_position(1).column)}); + + ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL)); + ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::THREAD_LOCAL)); + ASSERT_EQ(Status::OK(), func->execute(fn_ctx, block, {0, 1}, 2, 1)); + + EXPECT_TRUE(block.get_by_position(2).column->is_null_at(0)) + << "Expected null for null_str=" << null_str << " null_pattern=" << null_pattern; + + static_cast(func->close(fn_ctx, FunctionContext::THREAD_LOCAL)); + static_cast(func->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL)); + }; + + // NULL input string → null result + run_null_case(true, false); + // NULL pattern → null result + run_null_case(false, true); + + // Invalid const pattern → open() should fail + { + auto col_str = ColumnString::create(); + col_str->insert_data("abc", 3); + auto col_pattern = ColumnString::create(); + col_pattern->insert_data("(", 1); + Block block; + block.insert({std::move(col_str), str_type, "str"}); + block.insert({ColumnConst::create(std::move(col_pattern), 1), str_type, "pattern"}); + block.insert({nullptr, return_type, "result"}); + + ColumnsWithTypeAndName arg_cols = {block.get_by_position(0), block.get_by_position(1)}; + auto func = + SimpleFunctionFactory::instance().get_function(func_name, arg_cols, return_type); + ASSERT_TRUE(func != nullptr); + + std::vector arg_types = 
{str_type, str_type}; + FunctionUtils fn_utils({}, arg_types, false); + auto* fn_ctx = fn_utils.get_fn_ctx(); + fn_ctx->set_constant_cols( + {nullptr, std::make_shared(block.get_by_position(1).column)}); + + ASSERT_EQ(Status::OK(), func->open(fn_ctx, FunctionContext::FRAGMENT_LOCAL)); + // Invalid pattern should cause open() to fail for THREAD_LOCAL scope + EXPECT_NE(Status::OK(), func->open(fn_ctx, FunctionContext::THREAD_LOCAL)); + + static_cast(func->close(fn_ctx, FunctionContext::THREAD_LOCAL)); + static_cast(func->close(fn_ctx, FunctionContext::FRAGMENT_LOCAL)); + } +} + TEST(FunctionLikeTest, regexp_replace) { std::string func_name = "regexp_replace"; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index c8719dbaeb5832..ba50340077b54a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -416,6 +416,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpCount; import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtract; import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAll; +import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAllArray; import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractOrNull; import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplace; import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplaceOne; @@ -986,6 +987,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(RegexpCount.class, "regexp_count"), scalar(RegexpExtract.class, "regexp_extract"), scalar(RegexpExtractAll.class, "regexp_extract_all"), + scalar(RegexpExtractAllArray.class, "regexp_extract_all_array"), scalar(RegexpExtractOrNull.class, 
"regexp_extract_or_null"), scalar(RegexpReplace.class, "regexp_replace"), scalar(RegexpReplaceOne.class, "regexp_replace_one"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/RegexpExtractAllArray.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/RegexpExtractAllArray.java new file mode 100644 index 00000000000000..58f395bb71652f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/RegexpExtractAllArray.java @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
+import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.StringType;
+import org.apache.doris.nereids.types.VarcharType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * ScalarFunction 'regexp_extract_all_array'.
+ * Returns all matches of a regex pattern as an {@code Array<String>} instead of a string-formatted array.
+ */
+public class RegexpExtractAllArray extends ScalarFunction
+        implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable, PropagateNullLiteral {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(ArrayType.of(VarcharType.SYSTEM_DEFAULT))
+                    .args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
+            FunctionSignature.ret(ArrayType.of(StringType.INSTANCE))
+                    .args(StringType.INSTANCE, StringType.INSTANCE)
+    );
+
+    /**
+     * constructor with 2 arguments: the input string (arg0) and the regex pattern (arg1).
+     */
+    public RegexpExtractAllArray(Expression arg0, Expression arg1) {
+        super("regexp_extract_all_array", arg0, arg1);
+    }
+
+    /** constructor for withChildren and reuse signature */
+    private RegexpExtractAllArray(ScalarFunctionParams functionParams) {
+        super(functionParams);
+    }
+
+    /**
+     * withChildren: rebuilds this function from exactly two children (checked by a precondition).
+     */
+    @Override
+    public RegexpExtractAllArray withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new RegexpExtractAllArray(getFunctionParams(children));
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitRegexpExtractAllArray(this, context);
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
index c32d2157d50c50..6c78bb5826d7fa 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
@@ -437,6 +437,7 @@
 import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpCount;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtract;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAll;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractAllArray;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtractOrNull;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplace;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplaceOne;
@@ -2160,6 +2161,10 @@ default R visitRegexpExtractAll(RegexpExtractAll regexpExtractAll, C context) {
         return visitScalarFunction(regexpExtractAll, context);
     }
 
+    default R visitRegexpExtractAllArray(RegexpExtractAllArray regexpExtractAllArray, C context) {
+        return visitScalarFunction(regexpExtractAllArray, context);
+    }
+
     default R visitRegexpExtractOrNull(RegexpExtractOrNull regexpExtractOrNull, C context) {
         return visitScalarFunction(regexpExtractOrNull, context);
     }
diff --git
a/regression-test/data/query_p0/sql_functions/string_functions/test_string_function_regexp.out b/regression-test/data/query_p0/sql_functions/string_functions/test_string_function_regexp.out index d7422abb0e7472..eedf54b0fa4c9e 100644 --- a/regression-test/data/query_p0/sql_functions/string_functions/test_string_function_regexp.out +++ b/regression-test/data/query_p0/sql_functions/string_functions/test_string_function_regexp.out @@ -262,6 +262,62 @@ aXb -- !sql_regexp_extract_all_10 -- ['aXb','cXd'] +-- !regexp_extract_all_array_1 -- +["18", "17"] + +-- !regexp_extract_all_array_2 -- +["41915", "73660"] + +-- !regexp_extract_all_array_3 -- +["abc", "def", "ghi"] + +-- !regexp_extract_all_array_4 -- +[] + +-- !regexp_extract_all_array_5 -- +\N + +-- !regexp_extract_all_array_6 -- +\N + +-- !regexp_extract_all_array_7 -- +["ab", "c", "c", "c"] + +-- !regexp_extract_all_array_8 -- +\N +[] +["Emmy", "eillish"] +["It", "s", "ok"] +["It", "s", "true"] +["billie", "eillish"] +["billie", "eillish"] + +-- !regexp_extract_all_array_9 -- +\N +[] +["mmy", "eillish"] +["t", "s", "ok"] +["t", "s", "true"] +["billie", "eillish"] +["billie", "eillish"] + +-- !regexp_extract_all_array_10 -- +\N 5 \N + 6 [] +Emmy eillish 3 ["Emmy", "eillish"] +It's ok 2 ["It", "s", "ok"] +It's true 4 ["It", "s", "true"] +billie eillish \N ["billie", "eillish"] +billie eillish 1 ["billie", "eillish"] + +-- !regexp_extract_all_array_11 -- +[] +[] +[] +[] +[] +[] + -- !sql -- a-b-c diff --git a/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function_regexp.groovy b/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function_regexp.groovy index 3a219a2e619b38..2f22e319535a09 100644 --- a/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function_regexp.groovy +++ b/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function_regexp.groovy @@ -128,6 +128,18 @@ suite("test_string_function_regexp") { 
qt_sql_regexp_extract_all_9 "SELECT REGEXP_EXTRACT_ALL(concat('aXb', char(10), 'cXd'), '(?-s)(\\\\w.\\\\w)');" qt_sql_regexp_extract_all_10 "SELECT REGEXP_EXTRACT_ALL(concat('aXb', char(10), 'cXd'), '(\\\\w.\\\\w)');" + qt_regexp_extract_all_array_1 "SELECT regexp_extract_all_array('x=a3&x=18abc&x=2&y=3&x=4&x=17bcd', 'x=([0-9]+)([a-z]+)');" + qt_regexp_extract_all_array_2 "SELECT regexp_extract_all_array('http://a.m.baidu.com/i41915i73660.htm', 'i([0-9]+)');" + qt_regexp_extract_all_array_3 "SELECT regexp_extract_all_array('abc=111, def=222, ghi=333', '(\"[^\"]+\"|\\\\w+)=(\"[^\"]+\"|\\\\w+)');" + qt_regexp_extract_all_array_4 "select regexp_extract_all_array('xxfs','f');" + qt_regexp_extract_all_array_5 "select regexp_extract_all_array(NULL, 'pattern');" + qt_regexp_extract_all_array_6 "select regexp_extract_all_array('text', NULL);" + qt_regexp_extract_all_array_7 "select regexp_extract_all_array('abcdfesscca', '(ab|c|)');" + qt_regexp_extract_all_array_8 "SELECT regexp_extract_all_array(k, '(\\\\w+)') from test_string_function_regexp ORDER BY k;" + qt_regexp_extract_all_array_9 "SELECT regexp_extract_all_array(k, '([a-z]+)') from test_string_function_regexp ORDER BY k;" + qt_regexp_extract_all_array_10 "SELECT k, v, regexp_extract_all_array(k, '(\\\\w+)') from test_string_function_regexp ORDER BY k;" + qt_regexp_extract_all_array_11 "SELECT regexp_extract_all_array(k, concat('^', k)) from test_string_function_regexp WHERE k IS NOT NULL ORDER BY k;" + qt_sql "SELECT regexp_replace('a b c', \" \", \"-\");" qt_sql "SELECT regexp_replace('a b c','(b)','<\\\\1>');"